Porjaz committed
Commit 81581ab · verified · 1 Parent(s): d5b661b

Update custom_interface_app.py

Files changed (1)
  1. custom_interface_app.py +125 -128
custom_interface_app.py CHANGED
@@ -41,7 +41,7 @@ class ASR(Pretrained):
        # Forward encoder + decoder
        tokens = torch.tensor([[1, 1]]) * self.mods.whisper.config.decoder_start_token_id
        tokens = tokens.to(device)
-        enc_out, logits, _ = self.mods.whisper(wavs.detach(), tokens.detach())
+        enc_out, logits, _ = self.mods.whisper(wavs, tokens)
        log_probs = self.hparams.log_softmax(logits)

        hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)
@@ -128,30 +128,53 @@ class ASR(Pretrained):

    def classify_file_w2v2(self, waveform, device):
        # Load the audio file
+        # path = "long_sample.wav"
        # waveform, sr = librosa.load(path, sr=16000)

+        # increase the volume if needed
+        # waveform = self.increase_volume(waveform)
+
        # Get audio length in seconds
-        audio_length = len(waveform) / 16000
+        sr = 16000
+        audio_length = len(waveform) / sr

-        if audio_length >= 30:
-            # split audio every 20 seconds
+        if audio_length >= 20:
+            print(f"Audio is too long ({audio_length:.2f} seconds), splitting into segments")
+            # Detect non-silent segments
+
+            non_silent_intervals = librosa.effects.split(waveform, top_db=20) # Adjust top_db for sensitivity
+
            segments = []
-            all_segments = []
-            max_duration = 30 * 16000 # Maximum segment duration in samples (20 seconds)
-            num_segments = int(np.ceil(len(waveform) / max_duration))
-            start = 0
-            for i in range(num_segments):
-                end = start + max_duration
-                if end > len(waveform):
-                    end = len(waveform)
+            current_segment = []
+            current_length = 0
+            max_duration = 20 * sr # Maximum segment duration in samples (20 seconds)
+
+            for interval in non_silent_intervals:
+                start, end = interval
                segment_part = waveform[start:end]
-                segment_len = len(segment_part) / 16000
-                if segment_len < 1:
-                    continue
-                segments.append(segment_part)
-                start = end

-            for segment in segments:
+                # If adding the next part exceeds max duration, store the segment and start a new one
+                if current_length + len(segment_part) > max_duration:
+                    segments.append(np.concatenate(current_segment))
+                    current_segment = []
+                    current_length = 0
+
+                current_segment.append(segment_part)
+                current_length += len(segment_part)
+
+            # Append the last segment if it's not empty
+            if current_segment:
+                segments.append(np.concatenate(current_segment))
+
+            # Process each segment
+            outputs = []
+            for i, segment in enumerate(segments):
+                print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
+
+                # import soundfile as sf
+                # sf.write(f"outputs/segment_{i}.wav", segment, sr)
+
                segment_tensor = torch.tensor(segment).to(device)

                # Fake a batch for the segment
@@ -159,171 +182,145 @@ class ASR(Pretrained):
                rel_length = torch.tensor([1.0]).to(device) # Adjust if necessary

                # Pass the segment through the ASR model
-                segment_output = self.encode_batch_w2v2(device, batch, rel_length)
-                segment_output = [" ".join(segment) for segment in segment_output]
-                all_segments.append(segment_output)
-
-            segments = ""
-            for segment in all_segments:
-                segment = segment[0]
-                segments += segment + " "
-            return [segments]
+                result = " ".join(self.encode_batch_w2v2(device, batch, rel_length)[0])
+                outputs.append(result)
+            return outputs
        else:
            waveform = torch.tensor(waveform).to(device)
            waveform = waveform.to(device)
            # Fake a batch:
            batch = waveform.unsqueeze(0)
            rel_length = torch.tensor([1.0]).to(device)
-            outputs = self.encode_batch_w2v2(device, batch, rel_length)
-            return [" ".join(out) for out in outputs]
-
+            outputs = " ".join(self.encode_batch_w2v2(device, batch, rel_length)[0])
+            return [outputs]
+

    def classify_file_whisper_mkd(self, waveform, device):
        # Load the audio file
+        # path = "long_sample.wav"
        # waveform, sr = librosa.load(path, sr=16000)

+        # increase the volume if needed
+        # waveform = self.increase_volume(waveform)
+
        # Get audio length in seconds
-        audio_length = len(waveform) / 16000
+        sr = 16000
+        audio_length = len(waveform) / sr

-        if audio_length >= 30:
-            # split audio every 20 seconds
+        if audio_length >= 20:
+            print(f"Audio is too long ({audio_length:.2f} seconds), splitting into segments")
+            # Detect non-silent segments
+
+            non_silent_intervals = librosa.effects.split(waveform, top_db=20) # Adjust top_db for sensitivity
+
            segments = []
-            all_segments = []
-            max_duration = 30 * 16000 # Maximum segment duration in samples (20 seconds)
-            num_segments = int(np.ceil(len(waveform) / max_duration))
-            start = 0
-            for i in range(num_segments):
-                end = start + max_duration
-                if end > len(waveform):
-                    end = len(waveform)
+            current_segment = []
+            current_length = 0
+            max_duration = 20 * sr # Maximum segment duration in samples (20 seconds)
+
+            for interval in non_silent_intervals:
+                start, end = interval
                segment_part = waveform[start:end]
-                segment_len = len(segment_part) / 16000
-                if segment_len < 1:
-                    continue
-                segments.append(segment_part)
-                start = end

-            for segment in segments:
-                segment_tensor = torch.tensor(segment).to(device)
-
-                # Fake a batch for the segment
-                batch = segment_tensor.unsqueeze(0).to(device)
-                rel_length = torch.tensor([1.0]).to(device)
-
-                # Pass the segment through the ASR model
-                segment_output = self.encode_batch_whisper(device, batch, rel_length)
-                # segment_output = [" ".join(segment) for segment in segment_output]
-                all_segments.append(segment_output)
-
-            segments = ""
-            for segment in all_segments:
-                segment = segment[0]
-                segments += segment + " "
-            return [segments]
-        else:
-            waveform = torch.tensor(waveform).to(device)
-            waveform = waveform.to(device)
-            batch = waveform.unsqueeze(0)
-            rel_length = torch.tensor([1.0]).to(device)
-            outputs = self.encode_batch_whisper(device, batch, rel_length)
-            return outputs
-
-
-    def classify_file_whisper_mkd_streaming(self, waveform, device):
-        # Load the audio file
-        # waveform, sr = librosa.load(path, sr=16000)
-
-        # Get audio length in seconds
-        audio_length = len(waveform) / 16000
-
-        if audio_length >= 30:
-            # split audio every 30 seconds
-            segments = []
-            max_duration = 30 * 16000 # Maximum segment duration in samples (20 seconds)
-            num_segments = int(np.ceil(len(waveform) / max_duration))
-            start = 0
-            for i in range(num_segments):
-                end = start + max_duration
-                if end > len(waveform):
-                    end = len(waveform)
-                segment_part = waveform[start:end]
-                segment_len = len(segment_part) / 16000
-                if segment_len < 1:
-                    continue
-                segments.append(segment_part)
-                start = end
-
-            for segment in segments:
+                # If adding the next part exceeds max duration, store the segment and start a new one
+                if current_length + len(segment_part) > max_duration:
+                    segments.append(np.concatenate(current_segment))
+                    current_segment = []
+                    current_length = 0
+
+                current_segment.append(segment_part)
+                current_length += len(segment_part)
+
+            # Append the last segment if it's not empty
+            if current_segment:
+                segments.append(np.concatenate(current_segment))
+
+            # Process each segment
+            outputs = []
+            for i, segment in enumerate(segments):
+                print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
+
+                # import soundfile as sf
+                # sf.write(f"outputs/segment_{i}.wav", segment, sr)
+
                segment_tensor = torch.tensor(segment).to(device)

                # Fake a batch for the segment
                batch = segment_tensor.unsqueeze(0).to(device)
-                rel_length = torch.tensor([1.0]).to(device)
+                rel_length = torch.tensor([1.0]).to(device) # Adjust if necessary

                # Pass the segment through the ASR model
                segment_output = self.encode_batch_whisper(device, batch, rel_length)
-                yield segment_output
+                outputs.append(segment_output)
+            return outputs
        else:
            waveform = torch.tensor(waveform).to(device)
            waveform = waveform.to(device)
+            # Fake a batch:
            batch = waveform.unsqueeze(0)
            rel_length = torch.tensor([1.0]).to(device)
-            outputs = self.encode_batch_whisper(device, batch, rel_length)
-            yield outputs
+            outputs.append(self.encode_batch_whisper(device, batch, rel_length))
+            return outputs


-    def classify_file_whisper(self, waveform, pipe, device):
-        # waveform, sr = librosa.load(path, sr=16000)
+    def classify_file_whisper(self, path, pipe, device):
+        waveform, sr = librosa.load(path, sr=16000)
        transcription = pipe(waveform, generate_kwargs={"language": "macedonian"})["text"]
-        return [transcription]
+        return transcription


-    def classify_file_mms(self, waveform, processor, model, device):
+    def classify_file_mms(self, path, processor, model, device):
        # Load the audio file
-        # waveform, sr = librosa.load(path, sr=16000)
+        waveform, sr = librosa.load(path, sr=16000)

        # Get audio length in seconds
-        audio_length = len(waveform) / 16000
+        audio_length = len(waveform) / sr

-        if audio_length >= 30:
-            # split audio every 20 seconds
+        if audio_length >= 20:
+            print(f"MMS Audio is too long ({audio_length:.2f} seconds), splitting into segments")
+            # Detect non-silent segments
+            non_silent_intervals = librosa.effects.split(waveform, top_db=20) # Adjust top_db for sensitivity
+
            segments = []
-            all_segments = []
-            max_duration = 30 * 16000 # Maximum segment duration in samples (20 seconds)
-            num_segments = int(np.ceil(len(waveform) / max_duration))
-            start = 0
-            for i in range(num_segments):
-                end = start + max_duration
-                if end > len(waveform):
-                    end = len(waveform)
+            current_segment = []
+            current_length = 0
+            max_duration = 20 * sr # Maximum segment duration in samples (20 seconds)
+
+            for interval in non_silent_intervals:
+                start, end = interval
                segment_part = waveform[start:end]
-                segment_len = len(segment_part) / 16000
-                if segment_len < 1:
-                    continue
-                segments.append(segment_part)
-                start = end

-            for segment in segments:
+                # If adding the next part exceeds max duration, store the segment and start a new one
+                if current_length + len(segment_part) > max_duration:
+                    segments.append(np.concatenate(current_segment))
+                    current_segment = []
+                    current_length = 0
+
+                current_segment.append(segment_part)
+                current_length += len(segment_part)
+
+            # Append the last segment if it's not empty
+            if current_segment:
+                segments.append(np.concatenate(current_segment))
+
+            # Process each segment
+            outputs = []
+            for i, segment in enumerate(segments):
+                print(f"MMS Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
+
                segment_tensor = torch.tensor(segment).to(device)

                # Pass the segment through the ASR model
                inputs = processor(segment_tensor, sampling_rate=16_000, return_tensors="pt").to(device)
-                inputs['input_values'] = inputs['input_values']
                outputs = model(**inputs).logits
                ids = torch.argmax(outputs, dim=-1)[0]
                segment_output = processor.decode(ids)
-                # segment_output = [" ".join(segment) for segment in segment_output]
-                all_segments.append(segment_output)
-
-            segments = ""
-            for segment in all_segments:
-                segments += segment + " "
-            return [segments]
+                yield segment_output
        else:
            waveform = torch.tensor(waveform).to(device)
            inputs = processor(waveform, sampling_rate=16_000, return_tensors="pt").to(device)
-            inputs['input_values'] = inputs['input_values']
            outputs = model(**inputs).logits
            ids = torch.argmax(outputs, dim=-1)[0]
            transcription = processor.decode(ids)
-            return [transcription]
+            yield transcription
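
The main functional change in this commit is the switch from fixed 30-second chunking to silence-aware segmentation with a roughly 20-second budget per chunk. As a reading aid, here is a minimal standalone sketch of that idea; it is not repository code — the function name chunk_on_silence and the example file name are illustrative, and the top_db=20 threshold and 20-second budget simply mirror the values used in the diff.

# Hypothetical helper sketching the silence-aware chunking used above (not part of the repo).
import librosa
import numpy as np

def chunk_on_silence(waveform, sr=16000, max_seconds=20, top_db=20):
    """Group non-silent intervals into chunks of at most max_seconds each."""
    max_samples = max_seconds * sr
    # librosa.effects.split returns (start, end) sample indices of non-silent regions
    intervals = librosa.effects.split(waveform, top_db=top_db)
    chunks, current, current_len = [], [], 0
    for start, end in intervals:
        part = waveform[start:end]
        # Flush the running chunk before it would exceed the sample budget
        if current and current_len + len(part) > max_samples:
            chunks.append(np.concatenate(current))
            current, current_len = [], 0
        current.append(part)
        current_len += len(part)
    if current:
        chunks.append(np.concatenate(current))
    return chunks

# Example usage (file name is illustrative):
# waveform, _ = librosa.load("long_sample.wav", sr=16000)
# for chunk in chunk_on_silence(waveform):
#     print(f"{len(chunk) / 16000:.2f} seconds")

Note also that classify_file_mms now yields one transcript per segment instead of returning a single list, so callers are expected to iterate over the generator and concatenate (or stream) the yielded strings.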