yentinglin committed
Commit 7608be3 · verified · 1 Parent(s): 563150c

Update app.py

Files changed (1):
  app.py  +118 -36
app.py CHANGED
@@ -11,6 +11,46 @@ from pyannote.audio import Pipeline
 from huggingface_hub import HfApi
 from torchaudio import functional as F # For resampling and audio processing
 
+# To run this Gradio demo, first ensure you have Python 3.9+ installed.
+# Then, create a virtual environment and install the required packages.
+#
+# 1. Create and activate a virtual environment (recommended):
+#    python3 -m venv venv
+#    source venv/bin/activate   # On Linux/macOS
+#    venv\Scripts\activate      # On Windows
+#
+# 2. Install the necessary packages:
+#    pip install gradio==4.20.1 \
+#                torch==2.2.1 \
+#                torchaudio==2.2.1 \
+#                transformers==4.38.2 \
+#                pyannote-audio==3.1.1 \
+#                numpy==1.26.4 \
+#                fastapi==0.110.0 \
+#                uvicorn==0.27.1 \
+#                python-multipart==0.0.9 \
+#                pydantic==2.6.1 \
+#                soundfile==0.12.1   # Required by torchaudio and pyannote for certain audio formats
+#
+# # If you have a CUDA-compatible GPU, install the CUDA version of PyTorch instead:
+# # For CUDA 12.1 (adjust 'cu121' to your CUDA version, e.g., 'cu118' for CUDA 11.8):
+# # pip install torch==2.2.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
+#
+# 3. Set your Hugging Face Token (required for pyannote/speaker-diarization-3.1).
+#    You must accept the user conditions on the model page: https://huggingface.co/pyannote/speaker-diarization-3.1
+#    export HF_TOKEN="hf_YOUR_TOKEN_HERE"
+#    # Or directly in the script (not recommended for security):
+#    # HF_TOKEN = "hf_YOUR_TOKEN_HERE"
+#
+# 4. Save this file as, for example, `app.py`.
+#
+# 5. Run the application using uvicorn (as requested):
+#    uvicorn app:demo --host 0.0.0.0 --port 7860
+#
+# # If you just want to run it via Python script (Gradio's default, without uvicorn directly):
+# # python app.py
+
+
 # Set up logging
 logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
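Note on step 3 above: the diarization checkpoint is gated, so the token must be available before the pyannote pipeline can be instantiated. A minimal standalone sketch of that loading step, assuming pyannote.audio 3.1 (the app's actual model-loading code is outside this diff):

    import os
    import torch
    from pyannote.audio import Pipeline

    # Hypothetical loader mirroring this file's constants.
    HF_TOKEN = os.getenv("HF_TOKEN")
    DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"

    diarization_pipeline = Pipeline.from_pretrained(
        DIARIZATION_MODEL,
        use_auth_token=HF_TOKEN,  # gated model: requires accepting the license on the model page
    )
    if torch.cuda.is_available():
        diarization_pipeline.to(torch.device("cuda"))  # optional GPU placement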
@@ -27,14 +67,14 @@ logger = logging.getLogger(__name__)
 HF_TOKEN = os.getenv("HF_TOKEN")
 
 # Model names
-ASR_MODEL = "openai/whisper-small" # Smaller, faster Whisper model for demo
+ASR_MODEL = "openai/whisper-large-v3-turbo" # Smaller, faster Whisper model for demo
 DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
 # Speculative decoding (assistant model) is explicitly excluded as per requirements.
 
 # --- Inference Configuration (Pydantic Model for validation) ---
 class InferenceConfig(BaseModel):
     task: Literal["transcribe", "translate"] = "transcribe"
-    batch_size: int = 24
+    batch_size: int = 1
     chunk_length_s: int = 30
     language: Optional[str] = None
     num_speakers: Optional[int] = None
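This hunk switches ASR_MODEL to openai/whisper-large-v3-turbo and drops the default batch size to 1. The pipeline construction itself is not part of the diff; a minimal sketch of the usual transformers setup for these constants (hypothetical, not the app's own loading code):

    import torch
    from transformers import pipeline

    # Hypothetical construction matching ASR_MODEL above.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3-turbo",
        torch_dtype=torch.float16 if device != "cpu" else torch.float32,
        device=device,
    )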
@@ -210,29 +250,46 @@ def post_process_segments_and_transcripts(combined_diarization_segments: list, a
         diar_end = diar_segment["segment"]["end"]
         speaker = diar_segment["speaker"]
 
-        # Find the index in `current_asr_end_timestamps` whose value is closest to `diar_end`.
-        # This `upto_idx_relative` determines how many ASR chunks from `current_asr_chunks`
-        # will be associated with the current `diar_segment`.
+        # Find the index of the ASR chunk whose end timestamp is closest to diar_end
+        # Ensure argmin operates on a non-empty array
+        if current_asr_end_timestamps.size == 0:
+            logger.warning("No ASR end timestamps left to align with diarization segment. Breaking alignment.")
+            break # No more ASR chunks to align
+
         upto_idx_relative = np.argmin(np.abs(current_asr_end_timestamps - diar_end))
 
-        # Select the ASR chunks up to and including this `upto_idx_relative`.
         chunks_for_this_diar_segment = current_asr_chunks[:upto_idx_relative + 1]
 
         if not chunks_for_this_diar_segment:
-            continue # No ASR chunks found for this diarization segment, skip
+            logger.warning(f"No ASR chunks selected for diarization segment [{diar_start:.2f}-{diar_end:.2f}] {speaker}. Skipping.")
+            continue
 
-        # Combine the text from the selected ASR chunks.
-        combined_text = "".join([chunk["text"] for chunk in chunks_for_this_diar_segment]).strip()
+        # Initialize with extreme values to find min/max correctly, handling None timestamps
+        asr_min_start_val = float('inf')
+        asr_max_end_val = float('-inf')
+
+        all_text = []
+
+        for chunk in chunks_for_this_diar_segment:
+            all_text.append(chunk["text"])
+            if chunk["timestamp"] and chunk["timestamp"][0] is not None:
+                asr_min_start_val = min(asr_min_start_val, chunk["timestamp"][0])
+            if chunk["timestamp"] and chunk["timestamp"][1] is not None:
+                asr_max_end_val = max(asr_max_end_val, chunk["timestamp"][1])
 
-        # Determine the start and end timestamp for the combined ASR text.
-        # This will be the min start and max end of the involved ASR chunks.
-        asr_min_start = min(chunk["timestamp"][0] for chunk in chunks_for_this_diar_segment if chunk["timestamp"][0] is not None)
-        asr_max_end = max(chunk["timestamp"][1] for chunk in chunks_for_this_diar_segment if chunk["timestamp"][1] is not None)
+        combined_text = "".join(all_text).strip()
+
+        # If no valid timestamps were found in the selected ASR chunks, fall back to diarization segment's bounds
+        if asr_min_start_val == float('inf'):
+            logger.warning(f"No valid start timestamps in ASR chunks for segment [{diar_start:.2f}-{diar_end:.2f}] {speaker}. Using diarization start.")
+            asr_min_start_val = diar_start
+        if asr_max_end_val == float('-inf'):
+            logger.warning(f"No valid end timestamps in ASR chunks for segment [{diar_start:.2f}-{diar_end:.2f}] {speaker}. Using diarization end.")
+            asr_max_end_val = diar_end
 
-        # Final timestamp for the output segment should be clamped by the diarization segment's boundaries
-        # to ensure it doesn't extend beyond what the diarizer indicated.
-        final_segment_start = max(diar_start, asr_min_start)
-        final_segment_end = min(diar_end, asr_max_end)
+        # Ensure final timestamp range makes sense and is clamped by diarization segment
+        final_segment_start = max(diar_start, asr_min_start_val)
+        final_segment_end = min(diar_end, asr_max_end_val)
 
         final_segmented_transcript.append(
             {
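The reworked loop above matches each diarization segment to the ASR chunks whose end timestamps fall closest to the segment boundary, then (in the next hunk) consumes those chunks before moving on. A self-contained toy illustration of the np.argmin matching, with hypothetical chunk data:

    import numpy as np

    # Hypothetical ASR chunks in the transformers format: {"text", "timestamp": (start, end)}.
    asr_chunks = [
        {"text": " hello", "timestamp": (0.0, 1.2)},
        {"text": " there", "timestamp": (1.2, 2.4)},
        {"text": " how are you", "timestamp": (2.4, 4.0)},
    ]
    end_timestamps = np.array([c["timestamp"][1] for c in asr_chunks])
    diar_end = 2.5  # end of the current diarization segment

    # Index of the chunk whose end time is closest to the diarization boundary.
    upto_idx = int(np.argmin(np.abs(end_timestamps - diar_end)))
    selected = asr_chunks[: upto_idx + 1]
    print("".join(c["text"] for c in selected).strip())  # -> "hello there"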
@@ -242,12 +299,13 @@ def post_process_segments_and_transcripts(combined_diarization_segments: list, a
             }
         )
 
-        # Remove the processed ASR chunks from the lists for the next iteration.
+        # Crop the transcripts and timestamp lists according to the latest timestamp
        current_asr_chunks = current_asr_chunks[upto_idx_relative + 1:]
         current_asr_end_timestamps = current_asr_end_timestamps[upto_idx_relative + 1:]
 
     return final_segmented_transcript
 
+
 def diarize_and_align_transcript(diarization_pipeline: Pipeline, original_sampling_rate: int,
                                  audio_numpy_array: np.ndarray, parameters: InferenceConfig, asr_outputs: dict) -> list:
     """
@@ -303,12 +361,12 @@ def predict_audio(
     - status_message: A message indicating success or failure.
     """
     if audio_file_tuple is None:
-        return "", "", "Please upload an audio file."
+        return "", "", gr.Warning("Please upload an audio file.")
 
     sampling_rate, audio_numpy_array = audio_file_tuple
 
     if audio_numpy_array is None or audio_numpy_array.size == 0:
-        return "", "", "Audio file is empty. Please upload a valid audio."
+        return "", "", gr.Warning("Audio file is empty. Please upload a valid audio.")
 
     # Ensure audio_numpy_array is float32 as expected by transformers pipeline
     if audio_numpy_array.dtype != np.float32:
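This hunk and the following ones surface failures through gr.Warning and gr.Error instead of bare strings. For reference, a minimal sketch of how those helpers are typically used in a Gradio 4.x handler: gr.Warning(...) shows a toast when called, while gr.Error is an exception that is raised to abort the event (hypothetical handler, not the app's own function):

    import gradio as gr

    def handler(audio):
        if audio is None:
            gr.Warning("Please upload an audio file.")  # non-fatal toast; execution continues
            return ""
        if not isinstance(audio, tuple):
            raise gr.Error("Unexpected input type.")    # fatal; aborts the event and shows an error toast
        return "ok"

    demo = gr.Interface(fn=handler, inputs=gr.Audio(type="numpy"), outputs=gr.Textbox())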
@@ -318,19 +376,34 @@
     if len(audio_numpy_array.shape) > 1:
         audio_numpy_array = audio_numpy_array[:, 0]
 
+    # Process speaker parameters: convert 0 or negative values to None for pyannote compatibility
+    processed_num_speakers = num_speakers if num_speakers is not None and num_speakers > 0 else None
+    processed_min_speakers = min_speakers if min_speakers is not None and min_speakers > 0 else None
+    processed_max_speakers = max_speakers if max_speakers is not None and max_speakers > 0 else None
+
+    # Validation logic for min/max speakers
+    if processed_min_speakers is not None and processed_max_speakers is not None and processed_min_speakers > processed_max_speakers:
+        return "", "", gr.Warning("Diarization: Min Speakers cannot be greater than Max Speakers.")
+    if processed_num_speakers is not None:
+        if processed_min_speakers is not None and processed_num_speakers < processed_min_speakers:
+            return "", "", gr.Warning("Diarization: Number of Speakers cannot be less than Min Speakers.")
+        if processed_max_speakers is not None and processed_num_speakers > processed_max_speakers:
+            return "", "", gr.Warning("Diarization: Number of Speakers cannot be greater than Max Speakers.")
+
+
     # Create an InferenceConfig object from Gradio inputs for internal validation and use.
     try:
         parameters = InferenceConfig(
             batch_size=batch_size,
             chunk_length_s=chunk_length_s,
             language=language if language != "Auto-detect" else None, # Convert "Auto-detect" to None for model
-            num_speakers=num_speakers,
-            min_speakers=min_speakers,
-            max_speakers=max_speakers,
+            num_speakers=processed_num_speakers,
+            min_speakers=processed_min_speakers,
+            max_speakers=processed_max_speakers,
         )
     except Exception as e:
         logger.error(f"Error validating parameters: {e}")
-        return "", "", f"Error validating input parameters: {e}"
+        return "", "", gr.Error(f"Error validating input parameters: {e}") # Use gr.Error for critical validation failures
 
     logger.info(f"Inference parameters: {parameters.model_dump_json()}")
     logger.info(f"Audio sampling rate: {sampling_rate} Hz, Audio shape: {audio_numpy_array.shape}")
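The sanitized speaker counts above are the optional keyword arguments that pyannote's speaker-diarization pipeline accepts at call time. A minimal sketch of that call shape, assuming a loaded pyannote 3.1 pipeline and 16 kHz mono audio (placeholder waveform, not the app's data):

    import os
    import torch
    from pyannote.audio import Pipeline

    diarization_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_TOKEN")
    )
    waveform = torch.zeros(1, 16000 * 10)  # (channels, samples): 10 s of placeholder silence
    diarization = diarization_pipeline(
        {"waveform": waveform, "sample_rate": 16000},
        num_speakers=None,   # exact count, if known
        min_speakers=None,   # lower bound, if known
        max_speakers=None,   # upper bound, if known
    )
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        print(f"{speaker}: {turn.start:.2f}-{turn.end:.2f}")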
@@ -339,7 +412,14 @@
     diarization_pipeline = models.get("diarization_pipeline")
 
     if not asr_pipeline:
-        return "", "", "ASR model not loaded. Please restart the application."
+        return "", "", gr.Error("ASR model not loaded. Please restart the application.")
+
+    # ASR language and batch size conflict warning/error
+    if parameters.language is None and parameters.batch_size > 1:
+        return "", "", gr.Warning(
+            "ASR: 'Auto-detect' language is not supported with batch size > 1. "
+            "Please select a specific language or set batch size to 1."
+        )
 
     # Prepare ASR generation arguments
     generate_kwargs = {
@@ -357,12 +437,12 @@
             batch_size=parameters.batch_size,
             generate_kwargs=generate_kwargs,
             return_timestamps=True,
-            #sampling_rate=sampling_rate # Pass original sampling rate to pipeline
+            # sampling_rate=sampling_rate # Pass original sampling rate to pipeline
         )
         logger.info("ASR inference completed.")
     except Exception as e:
         logger.error(f"ASR inference error: {str(e)}")
-        return "", "", f"ASR inference error: {str(e)}"
+        return "", "", gr.Error(f"ASR inference error: {str(e)}")
 
     final_transcript_data = []
     status_message = ""
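The try block above calls the chunked long-form ASR pipeline with return_timestamps=True, which is what produces the "chunks" list (text plus (start, end) timestamps) consumed by the alignment code earlier. A minimal sketch of that call and of the output shape it yields, assuming the pipeline from the earlier sketch and a placeholder audio path:

    # `asr_pipeline` as constructed in the earlier sketch; "sample.wav" is a placeholder path.
    asr_outputs = asr_pipeline(
        "sample.wav",
        chunk_length_s=30,
        batch_size=1,
        generate_kwargs={"task": "transcribe", "language": "chinese"},
        return_timestamps=True,
    )
    print(asr_outputs["text"])            # full transcript
    for chunk in asr_outputs["chunks"]:   # e.g. {"timestamp": (0.0, 5.2), "text": " ..."}
        start, end = chunk["timestamp"]
        print(f"[{start} - {end}] {chunk['text']}")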
@@ -426,12 +506,12 @@ demo = gr.Interface(
     fn=predict_audio,
     inputs=[
         gr.Audio(type="numpy", label="Upload Audio File (WAV, MP3, FLAC, etc.)"),
-        gr.Slider(minimum=1, maximum=32, value=24, step=1, label="ASR Batch Size"),
-        gr.Slider(minimum=1, maximum=60, value=30, step=1, label="ASR Chunk Length (seconds)"),
-        gr.Dropdown(WHISPER_LANGUAGES, value="Auto-detect", label="ASR Language"),
-        gr.Number(label="Diarization: Number of Speakers (optional)", value=None, precision=0, info="Expected total number of speakers."),
-        gr.Number(label="Diarization: Min Speakers (optional)", value=None, precision=0, info="Minimum number of speakers to detect."),
-        gr.Number(label="Diarization: Max Speakers (optional)", value=None, precision=0, info="Maximum number of speakers to detect.")
+        gr.Slider(minimum=1, maximum=32, value=1, step=1, label="ASR Batch Size"),
+        gr.Slider(minimum=1, maximum=30, value=30, step=1, label="ASR Chunk Length (seconds)"),
+        gr.Dropdown(WHISPER_LANGUAGES, value="Chinese", label="ASR Language"),
+        gr.Number(label="Diarization: Number of Speakers (optional)", value=None, precision=0, info="Expected total number of speakers (positive integer, or leave empty for auto-detect)."),
+        gr.Number(label="Diarization: Min Speakers (optional)", value=None, precision=0, info="Minimum number of speakers to detect (positive integer, or leave empty for auto-detect)."),
+        gr.Number(label="Diarization: Max Speakers (optional)", value=None, precision=0, info="Maximum number of speakers to detect (positive integer, or leave empty for auto-detect).")
     ],
     outputs=[
         gr.Textbox(label="Diarized Transcript", lines=10, interactive=False),
@@ -447,13 +527,15 @@ demo = gr.Interface(
         "<br><b>Note:</b> For long audios or high concurrent usage, consider using a GPU and models like `whisper-large-v3`."
     ),
     allow_flagging="never", # Disable Gradio flagging feature
-    # Example audio path assumes you are running from the cloned repository root.
-    # If not, download a small WAV file (e.g., from Common Voice) and update this path.
     examples=[
+        # Adjust this path if the `model-server/app/tests/` directory is not alongside your `app.py`
+        # For example, if app.py is in the root, and the audio is in a tests/ subdirectory,
+        # you might use: ["tests/polyai-minds14-0.wav", 24, 30, "Auto-detect", None, None, None]
         [os.path.join(os.path.dirname(__file__), "model-server", "app", "tests", "polyai-minds14-0.wav"), 24, 30, "Auto-detect", None, None, None]
     ],
-    cache_examples=False,
+    cache_examples=False # Disable caching of examples to prevent InvalidPathError
 )
 
 if __name__ == "__main__":
+    logger.info("Starting Gradio demo...")
     demo.launch()
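The setup comments added in this commit suggest serving on port 7860 through uvicorn. When uvicorn is used instead of demo.launch(), a common Gradio pattern is to mount the interface on a FastAPI app; a minimal sketch under that assumption (the module name app_server and the echo handler are hypothetical stand-ins):

    import gradio as gr
    from fastapi import FastAPI

    def echo(text: str) -> str:
        return text  # placeholder handler standing in for predict_audio

    demo = gr.Interface(fn=echo, inputs=gr.Textbox(), outputs=gr.Textbox())

    app = FastAPI()
    app = gr.mount_gradio_app(app, demo, path="/")

    # Run with:  uvicorn app_server:app --host 0.0.0.0 --port 7860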
 