kjysmu committed on
Commit
a46ea05
·
verified ·
1 Parent(s): 346d95d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -30
app.py CHANGED
@@ -202,23 +202,24 @@ def resample_waveform(waveform, original_sample_rate, target_sample_rate):
202
  # segments.append(waveform)
203
 
204
  # return segments
205
- def split_audio(waveform, sample_rate):
206
- segment_samples = segment_duration * sample_rate
207
- total_samples = waveform.size(0)
208
 
209
- # Pad if shorter than one segment
210
- if total_samples < segment_samples:
211
- pad_size = segment_samples - total_samples
212
- waveform = torch.nn.functional.pad(waveform, (0, pad_size))
213
 
214
- segments = []
215
- for start in range(0, waveform.size(0), segment_samples):
216
- end = start + segment_samples
217
- if end <= waveform.size(0):
218
- segment = waveform[start:end]
219
- segments.append(segment)
220
 
221
- return segments
 
 
 
 
 
 
 
222
 
223
  # def split_audio(waveform, sample_rate, segment_duration=10):
224
  # segment_samples = segment_duration * sample_rate
@@ -239,23 +240,23 @@ def split_audio(waveform, sample_rate):
239
 
240
  # return segments
241
 
242
- # def split_audio(waveform, sample_rate):
243
- # segment_samples = segment_duration * sample_rate
244
- # total_samples = waveform.size(0)
245
 
246
- # segments = []
247
- # for start in range(0, total_samples, segment_samples):
248
- # end = start + segment_samples
249
- # if end <= total_samples:
250
- # segment = waveform[start:end]
251
- # segments.append(segment)
252
 
253
- # # In case audio length is shorter than segment length.
254
- # if len(segments) == 0:
255
- # segment = waveform
256
- # segments.append(segment)
257
 
258
- # return segments
259
 
260
 
261
  def safe_remove_dir(directory):
@@ -380,8 +381,14 @@ class Music2emo:
380
  waveform = waveform.mean(dim=0).unsqueeze(0)
381
  waveform = waveform.squeeze()
382
  waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate)
383
-
384
- if is_split:
 
 
 
 
 
 
385
  segments = split_audio(waveform, sample_rate)
386
  for i, segment in enumerate(segments):
387
  segment_save_path = os.path.join(mert_dir, f"segment_{i}.npy")
@@ -389,6 +396,15 @@ class Music2emo:
389
  else:
390
  segment_save_path = os.path.join(mert_dir, f"segment_0.npy")
391
  self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)
 
 
 
 
 
 
 
 
 
392
 
393
  embeddings = []
394
  layers_to_extract = [5,6]
 
202
  # segments.append(waveform)
203
 
204
  # return segments
 
 
 
205
 
206
+ # def split_audio(waveform, sample_rate):
207
+ # segment_samples = segment_duration * sample_rate
208
+ # total_samples = waveform.size(0)
 
209
 
210
+ # # Pad if shorter than one segment
211
+ # if total_samples < segment_samples:
212
+ # pad_size = segment_samples - total_samples
213
+ # waveform = torch.nn.functional.pad(waveform, (0, pad_size))
 
 
214
 
215
+ # segments = []
216
+ # for start in range(0, waveform.size(0), segment_samples):
217
+ # end = start + segment_samples
218
+ # if end <= waveform.size(0):
219
+ # segment = waveform[start:end]
220
+ # segments.append(segment)
221
+
222
+ # return segments
223
 
224
  # def split_audio(waveform, sample_rate, segment_duration=10):
225
  # segment_samples = segment_duration * sample_rate
 
240
 
241
  # return segments
242
 
243
def split_audio(waveform, sample_rate):
    """Cut ``waveform`` into consecutive, equally sized chunks along dim 0.

    The chunk length in samples is ``segment_duration * sample_rate``, where
    ``segment_duration`` is a name resolved from the enclosing module scope
    (presumably a duration in seconds -- confirm where it is defined).

    Only complete chunks are kept: leftover samples past the last full chunk
    are dropped. When not even one full chunk fits, the unmodified waveform
    is returned as the sole element.

    Args:
        waveform: Object exposing ``size(0)`` and slicing along its first
            dimension (presumably a 1-D ``torch.Tensor`` -- confirm against
            the caller).
        sample_rate: Number of samples per unit of ``segment_duration``.

    Returns:
        A non-empty list of waveform slices.
    """
    chunk_len = segment_duration * sample_rate
    n_samples = waveform.size(0)

    # Keep a window only when it lies entirely inside the waveform.
    chunks = [
        waveform[offset:offset + chunk_len]
        for offset in range(0, n_samples, chunk_len)
        if offset + chunk_len <= n_samples
    ]

    # Audio shorter than one chunk: hand back the whole waveform instead.
    if not chunks:
        chunks.append(waveform)

    return chunks
260
 
261
 
262
  def safe_remove_dir(directory):
 
381
  waveform = waveform.mean(dim=0).unsqueeze(0)
382
  waveform = waveform.squeeze()
383
  waveform, sample_rate = resample_waveform(waveform, sample_rate, resample_rate)
384
+
385
+
386
+ # 🔍 Check duration
387
+ duration_sec = waveform.shape[-1] / sample_rate
388
+ is_split = duration_sec <= 30.0
389
+ print(f"Audio duration: {duration_sec:.2f} seconds | is_split = {is_split}")
390
+
391
+ if is_split:
392
  segments = split_audio(waveform, sample_rate)
393
  for i, segment in enumerate(segments):
394
  segment_save_path = os.path.join(mert_dir, f"segment_{i}.npy")
 
396
  else:
397
  segment_save_path = os.path.join(mert_dir, f"segment_0.npy")
398
  self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)
399
+
400
+ # if is_split:
401
+ # segments = split_audio(waveform, sample_rate)
402
+ # for i, segment in enumerate(segments):
403
+ # segment_save_path = os.path.join(mert_dir, f"segment_{i}.npy")
404
+ # self.feature_extractor.extract_features_from_segment(segment, sample_rate, segment_save_path)
405
+ # else:
406
+ # segment_save_path = os.path.join(mert_dir, f"segment_0.npy")
407
+ # self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)
408
 
409
  embeddings = []
410
  layers_to_extract = [5,6]