ssolito committed on
Commit
d262e76
·
verified ·
1 Parent(s): bd4cc90

Update whisper_cs.py (#41)

Browse files

- Update whisper_cs.py (b7e436130fa58c70354776f4df83c4b894aef18b)

Files changed (1) hide show
  1. whisper_cs.py +94 -2
whisper_cs.py CHANGED
@@ -178,8 +178,99 @@ def transcribe_audio(model, audio_path: str) -> Dict:
178
  'error': str(e),
179
  'success': False
180
  }
181
-
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  def generate(audio_path, use_v2_fast):
184
 
185
  if DEBUG_MODE: print(f"Entering generate function...")
@@ -270,4 +361,5 @@ def generate(audio_path, use_v2_fast):
270
 
271
  if DEBUG_MODE: print(f"Exiting generate function...")
272
 
273
- return clean_output
 
 
178
  'error': str(e),
179
  'success': False
180
  }
 
181
 
182
+
183
+
184
def _render_transcript(left_segs, right_segs):
    """Merge two per-speaker segment lists into a single transcript string.

    Each segment is a ``(start, end, speaker_label, text)`` tuple. Segments
    are ordered by ``start``; a ``None`` start sorts last. The sort is
    stable, so on equal start times ``left_segs`` entries come first.

    Returns one ``"[<speaker>]: <text>\\n"`` line per segment, concatenated.
    """
    merged = sorted(
        left_segs + right_segs,
        key=lambda seg: float(seg[0]) if seg[0] is not None else float("inf"),
    )
    return "".join(f"[{speaker}]: {text}\n" for _, _, speaker, text in merged)


def generate(audio_path, use_v2_fast):
    """Transcribe a two-channel recording into a speaker-labelled transcript.

    The audio at ``audio_path`` is split into two mono temp files, each
    channel is transcribed separately, and the resulting segments are
    interleaved by start time.

    Args:
        audio_path: Path of the audio file handed to ``split_stereo_channels``.
        use_v2_fast: When truthy, transcribe with the module-level
            ``faster_model``; otherwise load ``MODEL_PATH_V2`` through
            ``load_whisper_model`` and use ``transcribe_audio``.

    Returns:
        The transcript as ``"[Speaker N]: text"`` lines joined by newlines,
        with surrounding whitespace stripped.
    """
    if DEBUG_MODE:
        print("Entering generate function...")
    if DEBUG_MODE:
        print(f"use_v2_fast: {use_v2_fast}")

    if use_v2_fast and torch.cuda.is_available():
        # Best effort: a failure to move the model is reported, not raised.
        try:
            faster_model.to("cuda")
            print("[INFO] Moved faster_model to CUDA")
        except Exception as e:
            print(f"[WARNING] Could not move model to CUDA: {e}")

    # NOTE(review): the "left" path is the speaker2 file but is labelled
    # "Speaker 1" below (and vice versa) -- presumably this mirrors what
    # split_stereo_channels writes; confirm against that helper.
    left_channel_path = "temp_mono_speaker2.wav"
    right_channel_path = "temp_mono_speaker1.wav"

    # The slow path loads its model before any temp file is created.
    model = None if use_v2_fast else load_whisper_model(MODEL_PATH_V2)

    try:
        split_stereo_channels(audio_path)

        left_waveform, _ = format_audio(left_channel_path)
        right_waveform, _ = format_audio(right_channel_path)

        if use_v2_fast:
            left_waveform = left_waveform.numpy().astype("float32")
            right_waveform = right_waveform.numpy().astype("float32")

            left_result, _ = faster_model.transcribe(left_waveform, beam_size=5, task="transcribe")
            right_result, _ = faster_model.transcribe(right_waveform, beam_size=5, task="transcribe")

            # Materialise the segment iterables before they are merged.
            left_result = list(left_result)
            right_result = list(right_result)

            def get_faster_segments(segments, speaker_label):
                return [
                    (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
                    for seg in segments if seg.text
                ]

            left_segs = get_faster_segments(left_result, "Speaker 1")
            right_segs = get_faster_segments(right_result, "Speaker 2")
        else:
            left_result = transcribe_audio(model, left_waveform)
            right_result = transcribe_audio(model, right_waveform)

            def get_segments(result, speaker_label):
                # A missing or empty "segments" entry yields no segments.
                segments = result.get("segments", [])
                if not segments:
                    return []
                return [
                    (seg.get("start", 0.0), seg.get("end", 0.0), speaker_label,
                     post_process_transcription(seg.get("text", "").strip()))
                    for seg in segments if seg.get("text")
                ]

            left_segs = get_segments(left_result, "Speaker 1")
            right_segs = get_segments(right_result, "Speaker 2")

        clean_output = _render_transcript(left_segs, right_segs)

        if use_v2_fast and DEBUG_MODE:
            print(f"clean_output: {clean_output}")
    finally:
        # Remove the temp channel files even when transcription fails.
        cleanup_temp_files("temp_mono_speaker1.wav", "temp_mono_speaker2.wav")

    if DEBUG_MODE:
        print("Exiting generate function...")
    return clean_output.strip()
271
+
272
+
273
+ '''
274
  def generate(audio_path, use_v2_fast):
275
 
276
  if DEBUG_MODE: print(f"Entering generate function...")
 
361
 
362
  if DEBUG_MODE: print(f"Exiting generate function...")
363
 
364
+ return clean_output
365
+ '''