cstr committed · Commit 4b50bd3 · verified · 1 parent: e922c51

Update app.py

Files changed (1):
  1. app.py +107 -90
app.py CHANGED
@@ -42,11 +42,18 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
 def download_audio(url, method_choice):
     parsed_url = urlparse(url)
     logging.info(f"Downloading audio from URL: {url} using method: {method_choice}")
-    if parsed_url.netloc in ['www.youtube.com', 'youtu.be', 'youtube.com']:
-        return download_youtube_audio(url, method_choice)
-    else:
-        return download_direct_audio(url, method_choice)
-
+    try:
+        if parsed_url.netloc in ['www.youtube.com', 'youtu.be', 'youtube.com']:
+            audio_file = download_youtube_audio(url, method_choice)
+        else:
+            audio_file = download_direct_audio(url, method_choice)
+        if not audio_file or not os.path.exists(audio_file):
+            raise Exception(f"Failed to download audio from {url}")
+        return audio_file
+    except Exception as e:
+        logging.error(f"Error downloading audio: {str(e)}")
+        return f"Error: {str(e)}"
+
 def download_youtube_audio(url, method_choice):
     methods = {
         'yt-dlp': youtube_dl_method,
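
Note: the reworked download_audio funnels every failure into an "Error: ..." string instead of raising, so callers can branch on the return value. A minimal sketch of a caller honoring that contract (the URL is a placeholder):

    audio_path = download_audio("https://youtu.be/example", "yt-dlp")
    if not audio_path or audio_path.startswith("Error"):
        print(f"Download failed: {audio_path}")  # surface the message instead of crashing
    else:
        print(f"Audio saved to {audio_path}")
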
@@ -66,19 +73,24 @@ def download_youtube_audio(url, method_choice):
 
 def youtube_dl_method(url):
     logging.info("Using yt-dlp method")
-    ydl_opts = {
-        'format': 'bestaudio/best',
-        'postprocessors': [{
-            'key': 'FFmpegExtractAudio',
-            'preferredcodec': 'mp3',
-            'preferredquality': '192',
-        }],
-        'outtmpl': '%(id)s.%(ext)s',
-    }
-    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
-        info = ydl.extract_info(url, download=True)
-        logging.info(f"Downloaded YouTube audio: {info['id']}.mp3")
-        return f"{info['id']}.mp3"
+    try:
+        ydl_opts = {
+            'format': 'bestaudio/best',
+            'postprocessors': [{
+                'key': 'FFmpegExtractAudio',
+                'preferredcodec': 'mp3',
+                'preferredquality': '192',
+            }],
+            'outtmpl': '%(id)s.%(ext)s',
+        }
+        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+            info = ydl.extract_info(url, download=True)
+            output_file = f"{info['id']}.mp3"
+            logging.info(f"Downloaded YouTube audio: {output_file}")
+            return output_file
+    except Exception as e:
+        logging.error(f"Error in youtube_dl_method: {str(e)}")
+        return None
 
 def pytube_method(url):
     logging.info("Using pytube method")
@@ -183,11 +195,11 @@ def trim_audio(audio_path, start_time, end_time):
 
     # Validate times
     if start_time < 0 or end_time < 0:
-        raise ValueError("Start time and end time must be non-negative.")
+        raise gr.Error("Start time and end time must be non-negative.")
     if start_time >= end_time:
-        raise gr.Error("End time must be greater than start time.")
+        raise gr.Error("End time must be greater than start time.")
     if start_time > audio_duration:
-        raise ValueError("Start time exceeds audio duration.")
+        raise gr.Error("Start time exceeds audio duration.")
 
     trimmed_audio = audio[start_time * 1000:end_time * 1000]
     trimmed_audio_path = tempfile.mktemp(suffix='.wav')
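
Note: trim_audio slices with pydub, which indexes AudioSegment objects in milliseconds, hence the * 1000 on the second-based bounds; raising gr.Error makes Gradio display the message in the UI rather than a stack trace. The slicing in isolation (file names are placeholders):

    from pydub import AudioSegment

    audio = AudioSegment.from_file("input.wav")
    trimmed = audio[10 * 1000:25 * 1000]  # keep 10 s .. 25 s; pydub slices in ms
    trimmed.export("trimmed.wav", format="wav")
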
@@ -212,67 +224,40 @@ def get_model_options(pipeline_type):
     else:
         return []
 
+loaded_models = {}
+
 def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
     try:
+        if verbose:
+            logging.getLogger().setLevel(logging.INFO)
+        else:
+            logging.getLogger().setLevel(logging.WARNING)
+
+        logging.info(f"Transcription parameters: pipeline_type={pipeline_type}, model_id={model_id}, dtype={dtype}, batch_size={batch_size}, download_method={download_method}")
+        verbose_messages = f"Starting transcription with parameters:\nPipeline Type: {pipeline_type}\nModel ID: {model_id}\nData Type: {dtype}\nBatch Size: {batch_size}\nDownload Method: {download_method}\n"
+
+        if verbose:
+            yield verbose_messages, "", None
+
         # Determine if input_source is a URL or file
         if isinstance(input_source, str):
             if input_source.startswith('http://') or input_source.startswith('https://'):
                 audio_path = download_audio(input_source, download_method)
-                # Handle potential errors during download
                 if not audio_path or audio_path.startswith("Error"):
                     yield f"Error: {audio_path}", "", None
                     return
-        else:
-            # Assume input_source is an uploaded file object
+            else:
+                # Assume it's a local file path
+                audio_path = input_source
+        elif input_source is not None:
+            # Uploaded file object
             audio_path = input_source.name
             logging.info(f"Using uploaded audio file: {audio_path}")
-
-        try:
-            logging.info(f"Transcription parameters: pipeline_type={pipeline_type}, model_id={model_id}, dtype={dtype}, batch_size={batch_size}, download_method={download_method}")
-            verbose_messages = f"Starting transcription with parameters:\nPipeline Type: {pipeline_type}\nModel ID: {model_id}\nData Type: {dtype}\nBatch Size: {batch_size}\nDownload Method: {download_method}\n"
-
-            if verbose:
-                yield verbose_messages, "", None
-
-            if pipeline_type == "faster-batched":
-                model = WhisperModel(model_id, device="auto", compute_type=dtype)
-                pipeline = BatchedInferencePipeline(model=model)
-            elif pipeline_type == "faster-sequenced":
-                model = WhisperModel(model_id)
-                pipeline = model.transcribe
-            elif pipeline_type == "transformers":
-                torch_dtype = torch.float16 if dtype == "float16" else torch.float32
-                model = AutoModelForSpeechSeq2Seq.from_pretrained(
-                    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
-                )
-                model.to(device)
-                processor = AutoProcessor.from_pretrained(model_id)
-                pipeline = pipeline(
-                    "automatic-speech-recognition",
-                    model=model,
-                    tokenizer=processor.tokenizer,
-                    feature_extractor=processor.feature_extractor,
-                    chunk_length_s=30,
-                    batch_size=batch_size,
-                    return_timestamps=True,
-                    torch_dtype=torch_dtype,
-                    device=device,
-                )
         else:
-            raise ValueError("Invalid pipeline type")
-
-        if isinstance(input_source, str) and (input_source.startswith('http://') or input_source.startswith('https://')):
-            audio_path = download_audio(input_source, download_method)
-            verbose_messages += f"Audio file downloaded: {audio_path}\n"
-            if verbose:
-                yield verbose_messages, "", None
-
-            if not audio_path or audio_path.startswith("Error"):
-                yield f"Error: {audio_path}", "", None
-                return
-        else:
-            audio_path = input_source
+            yield "No audio source provided.", "", None
+            return
 
+        # Convert start_time and end_time to float or None
         start_time = float(start_time) if start_time else None
        end_time = float(end_time) if end_time else None
 
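
Note: the rewritten input handling is a four-way branch: an http(s) string is downloaded, any other string is treated as a local path, a non-string is assumed to be a Gradio upload object exposing .name, and None aborts early. Condensed into one hypothetical helper (resolve_audio_path does not exist in app.py; it only restates the branch):

    def resolve_audio_path(input_source, download_method):
        if isinstance(input_source, str):
            if input_source.startswith(('http://', 'https://')):
                return download_audio(input_source, download_method)  # may be "Error: ..."
            return input_source  # already a local file path
        if input_source is not None:
            return input_source.name  # Gradio file object
        return None  # nothing provided
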
@@ -283,11 +268,47 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
         if verbose:
             yield verbose_messages, "", None
 
+        # Model caching
+        model_key = (pipeline_type, model_id, dtype)
+        if model_key in loaded_models:
+            model_or_pipeline = loaded_models[model_key]
+            logging.info("Loaded model from cache")
+        else:
+            if pipeline_type == "faster-batched":
+                model = WhisperModel(model_id, device=device, compute_type=dtype)
+                model_or_pipeline = BatchedInferencePipeline(model=model)
+            elif pipeline_type == "faster-sequenced":
+                model = WhisperModel(model_id, device=device, compute_type=dtype)
+                model_or_pipeline = model
+            elif pipeline_type == "transformers":
+                torch_dtype = torch.float16 if dtype == "float16" else torch.float32
+                model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
+                )
+                model.to(device)
+                processor = AutoProcessor.from_pretrained(model_id)
+                model_or_pipeline = pipeline(
+                    "automatic-speech-recognition",
+                    model=model,
+                    tokenizer=processor.tokenizer,
+                    feature_extractor=processor.feature_extractor,
+                    chunk_length_s=30,
+                    batch_size=batch_size,
+                    return_timestamps=True,
+                    torch_dtype=torch_dtype,
+                    device=device,
+                )
+            else:
+                raise ValueError("Invalid pipeline type")
+            loaded_models[model_key] = model_or_pipeline  # Cache the model
+
         start_time_perf = time.time()
-        if pipeline_type in ["faster-batched", "faster-sequenced"]:
-            segments, info = pipeline(audio_path, batch_size=batch_size)
+        if pipeline_type == "faster-batched":
+            segments, info = model_or_pipeline.transcribe(audio_path, batch_size=batch_size)
+        elif pipeline_type == "faster-sequenced":
+            segments, info = model_or_pipeline.transcribe(audio_path)
         else:
-            result = pipeline(audio_path)
+            result = model_or_pipeline(audio_path)
             segments = result["chunks"]
         end_time_perf = time.time()
 
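
Note: loaded_models is a plain get-or-load cache keyed on (pipeline_type, model_id, dtype), i.e. on everything that affects how weights are loaded, so repeat runs reuse the hot model while changing the model or dtype triggers a reload. The pattern in isolation (get_or_load and loader are illustrative names, not app.py functions):

    loaded_models = {}

    def get_or_load(pipeline_type, model_id, dtype, loader):
        key = (pipeline_type, model_id, dtype)
        if key not in loaded_models:
            loaded_models[key] = loader()  # the expensive load happens once per key
        return loaded_models[key]
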
@@ -305,11 +326,10 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
         transcription = ""
 
         for segment in segments:
-            transcription_segment = (
-                f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}\n"
-                if pipeline_type in ["faster-batched", "faster-sequenced"] else
-                f"[{segment['timestamp'][0]:.2f}s -> {segment['timestamp'][1]:.2f}s] {segment['text']}\n"
-            )
+            if pipeline_type in ["faster-batched", "faster-sequenced"]:
+                transcription_segment = f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}\n"
+            else:
+                transcription_segment = f"[{segment['timestamp'][0]:.2f}s -> {segment['timestamp'][1]:.2f}s] {segment['text']}\n"
             transcription += transcription_segment
             if verbose:
                 yield verbose_messages + metrics_output, transcription, None
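
Note: the explicit if/else replaces the old conditional expression because the two back ends emit differently shaped segments, roughly (values illustrative):

    # faster-whisper:   Segment(start=0.0, end=2.5, text=" Hello")   -> attribute access
    # transformers ASR: {'timestamp': (0.0, 2.5), 'text': ' Hello'}  -> dict access
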
@@ -322,23 +342,21 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
         yield f"An error occurred: {str(e)}", "", None
 
     finally:
-        # Remove downloaded audio file
+        # Clean up temporary files
         if audio_path and os.path.exists(audio_path):
             os.remove(audio_path)
-        # Remove trimmed audio file
         if 'trimmed_audio_path' in locals() and os.path.exists(trimmed_audio_path):
             os.remove(trimmed_audio_path)
-        # Remove transcription file if needed
-        if transcription_file and os.path.exists(transcription_file):
+        if 'transcription_file' in locals() and os.path.exists(transcription_file):
             os.remove(transcription_file)
-
 
 with gr.Blocks() as iface:
     gr.Markdown("# Multi-Pipeline Transcription")
     gr.Markdown("Transcribe audio using multiple pipelines and models.")
 
     with gr.Row():
-        input_source = gr.File(label="Audio Source (Upload a file or enter a URL/YouTube URL)")
+        #input_source = gr.File(label="Audio Source (Upload a file or enter a URL/YouTube URL)")
+        input_source = gr.Textbox(label="Audio Source (Upload a file or enter a URL/YouTube URL)")
         pipeline_type = gr.Dropdown(
             choices=["faster-batched", "faster-sequenced", "transformers"],
             label="Pipeline Type",
@@ -375,7 +393,6 @@ with gr.Blocks() as iface:
         try:
             model_choices = get_model_options(pipeline_type)
             logging.info(f"Model choices for {pipeline_type}: {model_choices}")
-
             if model_choices:
                 return gr.update(choices=model_choices, value=model_choices[0], visible=True)
             else:
@@ -383,9 +400,9 @@ with gr.Blocks() as iface:
         except Exception as e:
             logging.error(f"Error in update_model_dropdown: {str(e)}")
             return gr.update(choices=["Error"], value="Error", visible=True)
-
-    #pipeline_type.change(update_model_dropdown, inputs=pipeline_type, outputs=model_id)
-    pipeline_type.change(update_model_dropdown, inputs=[pipeline_type], outputs=model_id)
+
+    # Event handler for pipeline_type change
+    pipeline_type.change(update_model_dropdown, inputs=[pipeline_type], outputs=[model_id])
 
     def transcribe_with_progress(*args):
         for result in transcribe_audio(*args):
@@ -399,9 +416,9 @@ with gr.Blocks() as iface:
 
     gr.Examples(
         examples=[
-            ["https://www.youtube.com/watch?v=daQ_hqA6HDo", "faster-batched", "cstr/whisper-large-v3-turbo-int8_float32", "int8", 16, "yt-dlp", None, None, True],
-            ["https://mcdn.podbean.com/mf/web/dir5wty678b6g4vg/HoP_453_-_The_Price_is_Right_-_Law_and_Economics_in_the_Second_Scholastic5yxzh.mp3", "faster-sequenced", "deepdml/faster-whisper-large-v3-turbo-ct2", "float16", 1, "ffmpeg", 0, 300, True],
-            [None, "transformers", "openai/whisper-large-v3", "float16", 16, "yt-dlp", 60, 180, True]
+            ["https://www.youtube.com/watch?v=daQ_hqA6HDo", "faster-batched", "cstr/whisper-large-v3-turbo-int8_float32", "int8", 16, "yt-dlp", None, None, True],
+            ["https://mcdn.podbean.com/mf/web/dir5wty678b6g4vg/HoP_453_-_The_Price_is_Right_-_Law_and_Economics_in_the_Second_Scholastic5yxzh.mp3", "faster-sequenced", "deepdml/faster-whisper-large-v3-turbo-ct2", "float16", 1, "ffmpeg", 0, 300, True],
+            ["path/to/local/audio.mp3", "transformers", "openai/whisper-large-v3", "float16", 16, "yt-dlp", 60, 180, True]
         ],
         inputs=[input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time, end_time, verbose],
     )
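
Note: the diff never shows the app's entry point; presumably app.py ends with the usual Gradio launcher (an assumption, not part of this commit):

    if __name__ == "__main__":
        iface.launch()  # assumed entry point; not visible in this diff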