kjysmu commited on
Commit
b1b35b0
·
verified ·
1 Parent(s): 01af6a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -6
app.py CHANGED
@@ -203,7 +203,8 @@ def resample_waveform(waveform, original_sample_rate, target_sample_rate):
203
 
204
  # return segments
205
 
206
- def split_audio(waveform, sample_rate):
 
207
  segment_samples = segment_duration * sample_rate
208
  total_samples = waveform.size(0)
209
 
@@ -213,14 +214,33 @@ def split_audio(waveform, sample_rate):
213
  if end <= total_samples:
214
  segment = waveform[start:end]
215
  segments.append(segment)
216
-
217
- # In case audio length is shorter than segment length.
218
- if len(segments) == 0:
219
- segment = waveform
220
- segments.append(segment)
 
221
 
222
  return segments
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
  def safe_remove_dir(directory):
226
  """
 
203
 
204
  # return segments
205
 
206
+
207
+ def split_audio(waveform, sample_rate, segment_duration=10):
208
  segment_samples = segment_duration * sample_rate
209
  total_samples = waveform.size(0)
210
 
 
214
  if end <= total_samples:
215
  segment = waveform[start:end]
216
  segments.append(segment)
217
+
218
+ # If no full segments were created, pad the short waveform
219
+ if len(segments) == 0:
220
+ pad_length = segment_samples - total_samples
221
+ padded_waveform = torch.nn.functional.pad(waveform, (0, pad_length))
222
+ segments.append(padded_waveform)
223
 
224
  return segments
225
 
226
+ # def split_audio(waveform, sample_rate):
227
+ # segment_samples = segment_duration * sample_rate
228
+ # total_samples = waveform.size(0)
229
+
230
+ # segments = []
231
+ # for start in range(0, total_samples, segment_samples):
232
+ # end = start + segment_samples
233
+ # if end <= total_samples:
234
+ # segment = waveform[start:end]
235
+ # segments.append(segment)
236
+
237
+ # # In case audio length is shorter than segment length.
238
+ # if len(segments) == 0:
239
+ # segment = waveform
240
+ # segments.append(segment)
241
+
242
+ # return segments
243
+
244
 
245
  def safe_remove_dir(directory):
246
  """