yentinglin committed (verified)
Commit: b07562d
Parent: 4fa00bc

Create app.py

Files changed (1): app.py (+458, -0)
app.py ADDED
@@ -0,0 +1,458 @@
+ import os
+ import torch
+ import numpy as np
+ import gradio as gr
+ import logging
+ import sys
+ from typing import Optional, Literal
+ from pydantic import BaseModel
+ from transformers import pipeline
+ from pyannote.audio import Pipeline
+ from huggingface_hub import HfApi
+ from torchaudio import functional as F  # For resampling and audio processing
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ # --- Configuration ---
+ # You will need a Hugging Face token for pyannote/speaker-diarization-3.1.
+ # 1. Go to https://huggingface.co/settings/tokens to create a new token.
+ # 2. Make sure you have accepted the user conditions on the model page:
+ #    https://huggingface.co/pyannote/speaker-diarization-3.1
+ # 3. Set your token as an environment variable before running this script:
+ #    export HF_TOKEN="hf_YOUR_TOKEN_HERE"
+ # Alternatively, replace os.getenv("HF_TOKEN") with your actual token string:
+ #    HF_TOKEN = "hf_YOUR_TOKEN_HERE"
+ HF_TOKEN = os.getenv("HF_TOKEN")
+
+ # Model names
+ ASR_MODEL = "openai/whisper-small"  # Smaller, faster Whisper model for demo
+ DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
+ # Speculative decoding (assistant model) is explicitly excluded as per requirements.
+
+ # --- Inference Configuration (Pydantic Model for validation) ---
+ class InferenceConfig(BaseModel):
+     task: Literal["transcribe", "translate"] = "transcribe"
+     batch_size: int = 24
+     chunk_length_s: int = 30
+     language: Optional[str] = None
+     num_speakers: Optional[int] = None
+     min_speakers: Optional[int] = None
+     max_speakers: Optional[int] = None
+
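+ # Illustrative example (comment only, not executed): Pydantic validates and coerces
+ # the raw Gradio inputs, e.g. a float from gr.Number becomes an int and an unknown
+ # task name is rejected with a validation error:
+ #     InferenceConfig(batch_size=16, num_speakers=2.0, language="en")
+ #     # -> task='transcribe', batch_size=16, chunk_length_s=30, language='en',
+ #     #    num_speakers=2, min_speakers=None, max_speakers=None
+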
+ # --- Global Models and Device ---
+ models = {}
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ logger.info(f"Using device: {device.type}")
+ torch_dtype = torch.float32 if device.type == "cpu" else torch.float16  # Use float16 on GPU for efficiency
+
+ # --- Model Loading Function ---
+ def load_models():
+     """
+     Loads the ASR and Diarization models into the global `models` dictionary.
+     Handles device placement and Hugging Face token authentication.
+     """
+     logger.info("Loading ASR pipeline...")
+     # The ASR pipeline can directly take a numpy array for inference.
+     models["asr_pipeline"] = pipeline(
+         "automatic-speech-recognition",
+         model=ASR_MODEL,
+         torch_dtype=torch_dtype,
+         device=device
+     )
+     logger.info("ASR pipeline loaded.")
+
+     if DIARIZATION_MODEL:
+         logger.info(f"Loading Diarization pipeline: {DIARIZATION_MODEL}...")
+         if not HF_TOKEN:
+             raise ValueError(
+                 "HF_TOKEN environment variable or HF_TOKEN constant not set. "
+                 "Pyannote models require a Hugging Face token for authentication. "
+                 "Get it from https://huggingface.co/settings/tokens and ensure you accept "
+                 "the user conditions on the model page: "
+                 "https://huggingface.co/pyannote/speaker-diarization-3.1"
+             )
+         try:
+             # Verify token and load pyannote pipeline
+             HfApi().whoami(token=HF_TOKEN)  # Check token validity
+             models["diarization_pipeline"] = Pipeline.from_pretrained(
+                 checkpoint_path=DIARIZATION_MODEL,
+                 use_auth_token=HF_TOKEN,
+             )
+             models["diarization_pipeline"].to(device)
+             logger.info("Diarization pipeline loaded.")
+         except Exception as e:
+             logger.error(f"Failed to load diarization pipeline: {e}")
+             raise
+     else:
+         models["diarization_pipeline"] = None
+         logger.info("Diarization model not specified, diarization will be skipped.")
+
+ # Load models once when the script starts
+ try:
+     load_models()
+ except Exception as e:
+     logger.critical(f"Failed to load models. Please check your HF_TOKEN and model names. Exiting: {e}")
+     sys.exit(1)
+
+ # --- Diarization Utility Functions (adapted from original `model-server/app/utils/diarization_utils.py`) ---
+
+ def preprocess_audio_for_diarization(sampling_rate_in: int, audio_array_in: np.ndarray) -> tuple[torch.Tensor, int]:
+     """
+     Preprocesses audio for the diarization pipeline.
+     Resamples to 16kHz and ensures a single-channel float32 torch tensor.
+     """
+     if audio_array_in is None or audio_array_in.size == 0:
+         raise ValueError("Audio array is empty for diarization preprocessing.")
+
+     # Convert to float32 if not already (pyannote expects float32)
+     if audio_array_in.dtype != np.float32:
+         audio_array_in = audio_array_in.astype(np.float32)
+
+     # If stereo, take one channel (pyannote expects single channel)
+     if len(audio_array_in.shape) > 1:
+         audio_array_in = audio_array_in[:, 0]  # Take the first channel
+
+     # Resample to 16kHz if necessary, as pyannote models are typically trained on 16kHz audio.
+     if sampling_rate_in != 16000:
+         audio_array_in = F.resample(
+             torch.from_numpy(audio_array_in), sampling_rate_in, 16000
+         ).numpy()
+         sampling_rate_in = 16000  # Update SR to reflect resampling
+
+     # Diarization model expects float32 torch tensor of shape `(channels, seq_len)`
+     diarizer_inputs = torch.from_numpy(audio_array_in).float()
+     diarizer_inputs = diarizer_inputs.unsqueeze(0)  # Add channel dimension (1, seq_len)
+
+     return diarizer_inputs, sampling_rate_in
+
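+ # Illustrative sketch (comment only, not executed): a 5-second stereo clip at 44.1 kHz
+ # arriving as a float32 array of shape (220500, 2) leaves this function as a mono
+ # float32 tensor of shape (1, 80000) at 16 kHz, which is what pyannote expects:
+ #     inputs, sr = preprocess_audio_for_diarization(44100, np.zeros((220500, 2), dtype=np.float32))
+ #     # inputs.shape == (1, 80000), sr == 16000
+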
+ def diarize_audio(diarizer_inputs: torch.Tensor, diarization_pipeline: Pipeline, parameters: InferenceConfig) -> list:
+     """
+     Performs diarization using the pyannote pipeline and combines consecutive speaker segments.
+     """
+     # Run the diarization pipeline
+     diarization = diarization_pipeline(
+         {"waveform": diarizer_inputs, "sample_rate": 16000},  # Always pass 16kHz to diarizer
+         num_speakers=parameters.num_speakers,
+         min_speakers=parameters.min_speakers,
+         max_speakers=parameters.max_speakers,
+     )
+
+     raw_segments = []
+     # pyannote.audio returns segments as `Segment(start=X, end=Y)`
+     for segment, _, label in diarization.itertracks(yield_label=True):
+         raw_segments.append(
+             {
+                 "segment": {"start": segment.start, "end": segment.end},
+                 "label": label,
+             }
+         )
+
+     # Combine consecutive segments from the same speaker
+     combined_segments = []
+     if not raw_segments:
+         return combined_segments
+
+     # Initialize with the first segment
+     current_speaker_segment = {
+         "speaker": raw_segments[0]["label"],
+         "segment": {"start": raw_segments[0]["segment"]["start"], "end": raw_segments[0]["segment"]["end"]},
+     }
+
+     for i in range(1, len(raw_segments)):
+         next_segment = raw_segments[i]
+
+         # If the speaker changes
+         if next_segment["label"] != current_speaker_segment["speaker"]:
+             # Add the accumulated segment for the previous speaker
+             combined_segments.append(current_speaker_segment)
+             # Start a new segment accumulation with the current speaker
+             current_speaker_segment = {
+                 "speaker": next_segment["label"],
+                 "segment": {"start": next_segment["segment"]["start"], "end": next_segment["segment"]["end"]},
+             }
+         else:
+             # Same speaker, extend the end time of the current accumulated segment
+             current_speaker_segment["segment"]["end"] = next_segment["segment"]["end"]
+
+     # Add the very last accumulated segment after the loop finishes
+     combined_segments.append(current_speaker_segment)
+
+     return combined_segments
+
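+ # Illustrative example (comment only, not executed): consecutive turns from the same
+ # speaker are merged, so raw pyannote tracks like
+ #     [ (0.0-2.0, SPEAKER_00), (2.0-3.5, SPEAKER_00), (3.5-6.0, SPEAKER_01) ]
+ # come back from diarize_audio() as
+ #     [ {"speaker": "SPEAKER_00", "segment": {"start": 0.0, "end": 3.5}},
+ #       {"speaker": "SPEAKER_01", "segment": {"start": 3.5, "end": 6.0}} ]
+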
+ def post_process_segments_and_transcripts(combined_diarization_segments: list, asr_transcript_chunks: list) -> list:
+     """
+     Aligns combined diarization segments with ASR transcript chunks.
+     This logic closely follows the provided `diarization_utils.py`'s `post_process_segments_and_transcripts`
+     function, which uses `argmin` for alignment and slicing for chunk consumption.
+     """
+     if not asr_transcript_chunks:
+         return []
+
+     # Get the end timestamps for each ASR chunk
+     # Use sys.float_info.max for None to ensure `argmin` works
+     asr_end_timestamps = np.array(
+         [chunk["timestamp"][1] if chunk["timestamp"][1] is not None else sys.float_info.max for chunk in asr_transcript_chunks]
+     )
+
+     # Create mutable copies to slice from
+     current_asr_chunks = list(asr_transcript_chunks)
+     current_asr_end_timestamps = asr_end_timestamps.copy()
+
+     final_segmented_transcript = []
+
+     for diar_segment in combined_diarization_segments:
+         if not current_asr_chunks:
+             break  # No more ASR chunks to process
+
+         diar_start = diar_segment["segment"]["start"]
+         diar_end = diar_segment["segment"]["end"]
+         speaker = diar_segment["speaker"]
+
+         # Find the index in `current_asr_end_timestamps` whose value is closest to `diar_end`.
+         # This `upto_idx_relative` determines how many ASR chunks from `current_asr_chunks`
+         # will be associated with the current `diar_segment`.
+         upto_idx_relative = np.argmin(np.abs(current_asr_end_timestamps - diar_end))
+
+         # Select the ASR chunks up to and including this `upto_idx_relative`.
+         chunks_for_this_diar_segment = current_asr_chunks[:upto_idx_relative + 1]
+
+         if not chunks_for_this_diar_segment:
+             continue  # No ASR chunks found for this diarization segment, skip
+
+         # Combine the text from the selected ASR chunks.
+         combined_text = "".join([chunk["text"] for chunk in chunks_for_this_diar_segment]).strip()
+
+         # Determine the start and end timestamp for the combined ASR text.
+         # This will be the min start and max end of the involved ASR chunks.
+         # Whisper may return `None` for the final chunk's end time; fall back to the
+         # diarization boundaries when no numeric timestamps are available.
+         asr_min_start = min((chunk["timestamp"][0] for chunk in chunks_for_this_diar_segment if chunk["timestamp"][0] is not None), default=diar_start)
+         asr_max_end = max((chunk["timestamp"][1] for chunk in chunks_for_this_diar_segment if chunk["timestamp"][1] is not None), default=diar_end)
+
+         # Final timestamp for the output segment should be clamped by the diarization segment's boundaries
+         # to ensure it doesn't extend beyond what the diarizer indicated.
+         final_segment_start = max(diar_start, asr_min_start)
+         final_segment_end = min(diar_end, asr_max_end)
+
+         final_segmented_transcript.append(
+             {
+                 "speaker": speaker,
+                 "text": combined_text,
+                 "timestamp": (final_segment_start, final_segment_end),
+             }
+         )
+
+         # Remove the processed ASR chunks from the lists for the next iteration.
+         current_asr_chunks = current_asr_chunks[upto_idx_relative + 1:]
+         current_asr_end_timestamps = current_asr_end_timestamps[upto_idx_relative + 1:]
+
+     return final_segmented_transcript
+
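+ # Illustrative example (comment only, not executed): with a single diarization segment
+ #     {"speaker": "SPEAKER_00", "segment": {"start": 0.0, "end": 4.0}}
+ # and ASR chunks ending at 2.0 and 4.1 seconds, argmin(|[2.0, 4.1] - 4.0|) selects the
+ # second chunk, so both chunks are joined into one entry clamped to the diarizer's bounds:
+ #     {"speaker": "SPEAKER_00", "text": "...", "timestamp": (0.0, 4.0)}
+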
+ def diarize_and_align_transcript(diarization_pipeline: Pipeline, original_sampling_rate: int,
+                                  audio_numpy_array: np.ndarray, parameters: InferenceConfig, asr_outputs: dict) -> list:
+     """
+     Orchestrates the entire diarization and transcript alignment process.
+     """
+     # 1. Preprocess audio for the diarization model (resample to 16kHz, ensure mono, convert to torch.Tensor)
+     diarizer_input_tensor, processed_sampling_rate = preprocess_audio_for_diarization(
+         original_sampling_rate, audio_numpy_array
+     )
+
+     # 2. Perform diarization to get speaker segments
+     # Update parameters with the processed sampling rate for diarization model's internal use.
+     diarization_params_for_pipeline = parameters.model_copy(update={"sampling_rate": processed_sampling_rate})
+     combined_diarization_segments = diarize_audio(
+         diarizer_input_tensor,
+         diarization_pipeline,
+         diarization_params_for_pipeline
+     )
+
+     # 3. Align diarization segments with ASR transcript chunks
+     aligned_transcript = post_process_segments_and_transcripts(
+         combined_diarization_segments, asr_outputs["chunks"]
+     )
+
+     return aligned_transcript
+
+ # --- Main Prediction Function for Gradio Interface ---
+ def predict_audio(
+     audio_file_tuple: tuple[int, np.ndarray],
+     batch_size: int,
+     chunk_length_s: int,
+     language: str,
+     num_speakers: Optional[int],
+     min_speakers: Optional[int],
+     max_speakers: Optional[int]
+ ) -> tuple[str, str, str]:
+     """
+     Gradio-compatible function to perform ASR and optionally speaker diarization.
+
+     Args:
+         audio_file_tuple: A tuple (sampling_rate, numpy_array) from Gradio's gr.Audio input.
+         batch_size: Batch size for ASR inference.
+         chunk_length_s: Chunk length for ASR inference in seconds.
+         language: Language for ASR (e.g., "English", "Auto-detect").
+         num_speakers: Expected number of speakers for diarization (optional).
+         min_speakers: Minimum number of speakers for diarization (optional).
+         max_speakers: Maximum number of speakers for diarization (optional).
+
+     Returns:
+         A tuple containing:
+         - formatted_diarized_text: A string with the diarized transcript.
+         - full_transcript_text: A string with the full ASR transcript.
+         - status_message: A message indicating success or failure.
+     """
+     if audio_file_tuple is None:
+         return "", "", "Please upload an audio file."
+
+     sampling_rate, audio_numpy_array = audio_file_tuple
+
+     if audio_numpy_array is None or audio_numpy_array.size == 0:
+         return "", "", "Audio file is empty. Please upload a valid audio."
+
+     # Gradio's "numpy" audio type usually delivers integer PCM (e.g., int16); scale integer
+     # samples to [-1, 1] float32 as expected by the Whisper feature extractor and pyannote.
+     if np.issubdtype(audio_numpy_array.dtype, np.integer):
+         audio_numpy_array = audio_numpy_array.astype(np.float32) / np.iinfo(audio_numpy_array.dtype).max
+     elif audio_numpy_array.dtype != np.float32:
+         audio_numpy_array = audio_numpy_array.astype(np.float32)
+
+     # If stereo, convert to mono for consistent processing (e.g., take the first channel)
+     if len(audio_numpy_array.shape) > 1:
+         audio_numpy_array = audio_numpy_array[:, 0]
+
+     # Create an InferenceConfig object from Gradio inputs for internal validation and use.
+     try:
+         parameters = InferenceConfig(
+             batch_size=batch_size,
+             chunk_length_s=chunk_length_s,
+             # Convert "Auto-detect" to None and lowercase the name for Whisper's language map.
+             language=language.lower() if language and language != "Auto-detect" else None,
+             num_speakers=num_speakers,
+             min_speakers=min_speakers,
+             max_speakers=max_speakers,
+         )
+     except Exception as e:
+         logger.error(f"Error validating parameters: {e}")
+         return "", "", f"Error validating input parameters: {e}"
+
+     logger.info(f"Inference parameters: {parameters.model_dump_json()}")
+     logger.info(f"Audio sampling rate: {sampling_rate} Hz, Audio shape: {audio_numpy_array.shape}")
+
+     asr_pipeline = models.get("asr_pipeline")
+     diarization_pipeline = models.get("diarization_pipeline")
+
+     if not asr_pipeline:
+         return "", "", "ASR model not loaded. Please restart the application."
+
+     # Prepare ASR generation arguments
+     generate_kwargs = {
+         "task": parameters.task,
+         "language": parameters.language,
+         "assistant_model": None  # Speculative decoding is disabled
+     }
+
+     asr_outputs = None
+     try:
+         logger.info("Starting ASR inference...")
+         asr_outputs = asr_pipeline(
+             # Pass the raw waveform together with its sampling rate so the pipeline can
+             # resample to the model's expected rate; the ASR pipeline does not accept a
+             # bare `sampling_rate=` keyword at call time.
+             {"raw": audio_numpy_array, "sampling_rate": sampling_rate},
+             chunk_length_s=parameters.chunk_length_s,
+             batch_size=parameters.batch_size,
+             generate_kwargs=generate_kwargs,
+             return_timestamps=True,
+         )
+         logger.info("ASR inference completed.")
+     except Exception as e:
+         logger.error(f"ASR inference error: {str(e)}")
+         return "", "", f"ASR inference error: {str(e)}"
+
+     final_transcript_data = []
+     status_message = ""
+
+     if diarization_pipeline:
+         try:
+             logger.info("Starting Diarization inference and alignment...")
+             final_transcript_data = diarize_and_align_transcript(
+                 diarization_pipeline, sampling_rate, audio_numpy_array, parameters, asr_outputs
+             )
+             status_message = "Diarization and ASR successful!"
+             logger.info("Diarization and alignment completed.")
+         except Exception as e:
+             logger.error(f"Diarization inference error: {str(e)}")
+             # If diarization fails, still provide the full ASR transcript
+             final_transcript_data = []  # Clear any partial diarization
+             status_message = f"Diarization failed: {str(e)}. Displaying full ASR transcript only."
+     else:
+         logger.info("Diarization pipeline not loaded, skipping diarization and returning raw ASR chunks.")
+         # If no diarization, format ASR chunks as if they were from a single "Speaker"
+         for chunk in asr_outputs["chunks"]:
+             final_transcript_data.append({
+                 "speaker": "Speaker",  # Generic label
+                 "text": chunk["text"],
+                 "timestamp": chunk["timestamp"]
+             })
+         status_message = "Diarization not enabled. Displaying full ASR transcript by chunk."
+
+     # Format the output for Gradio display
+     formatted_diarized_text_output = []
+     for entry in final_transcript_data:
+         start_time = f"{entry['timestamp'][0]:.2f}" if entry['timestamp'][0] is not None else "0.00"
+         end_time = f"{entry['timestamp'][1]:.2f}" if entry['timestamp'][1] is not None else "End"
+         formatted_diarized_text_output.append(
+             f"[{start_time} - {end_time}] {entry['speaker']}: {entry['text'].strip()}"
+         )
+
+     full_asr_text_output = asr_outputs["text"] if asr_outputs else "No ASR transcript generated."
+
+     return (
+         "\n".join(formatted_diarized_text_output),
+         full_asr_text_output,
+         status_message
+     )
+
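+ # Illustrative output (not from a real run): the first textbox receives lines such as
+ #     [0.00 - 3.52] SPEAKER_00: Hello, thanks for calling.
+ #     [3.52 - 7.10] SPEAKER_01: Hi, I'd like to check my order status.
+ # while the second textbox holds the unsegmented Whisper transcript.
+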
+ # --- Gradio Interface Definition ---
+
+ # List of languages supported by OpenAI Whisper models
+ WHISPER_LANGUAGES = [
+     "Auto-detect", "English", "Chinese", "German", "Spanish", "Russian", "Korean", "French", "Japanese", "Portuguese",
+     "Turkish", "Polish", "Catalan", "Dutch", "Arabic", "Swedish", "Italian", "Indonesian", "Hindi", "Finnish",
+     "Vietnamese", "Hebrew", "Ukrainian", "Greek", "Malay", "Czech", "Romanian", "Danish", "Hungarian", "Tamil",
+     "Norwegian", "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian", "Latin", "Maori", "Malayalam", "Afrikaans",
+     "Welsh", "Belarusian", "Gujarati", "Kannada", "Armenian", "Azerbaijani", "Serbian", "Slovenian", "Estonian",
+     "Burmese", "Galician", "Mongolian", "Lao", "Kazakh", "Georgian", "Amharic", "Nepali", "Bosnian", "Luxembourgish",
+     "Pashto", "Tagalog", "Malagasy", "Albanian", "Sindhi", "Kurdish", "Somali", "Telugu", "Tajik", "Swahili",
+     "Kashmiri"
+ ]
+
+ demo = gr.Interface(
+     fn=predict_audio,
+     inputs=[
+         gr.Audio(type="numpy", label="Upload Audio File (WAV, MP3, FLAC, etc.)"),
+         gr.Slider(minimum=1, maximum=32, value=24, step=1, label="ASR Batch Size"),
+         gr.Slider(minimum=1, maximum=60, value=30, step=1, label="ASR Chunk Length (seconds)"),
+         gr.Dropdown(WHISPER_LANGUAGES, value="Auto-detect", label="ASR Language"),
+         gr.Number(label="Diarization: Number of Speakers (optional)", value=None, precision=0, info="Expected total number of speakers."),
+         gr.Number(label="Diarization: Min Speakers (optional)", value=None, precision=0, info="Minimum number of speakers to detect."),
+         gr.Number(label="Diarization: Max Speakers (optional)", value=None, precision=0, info="Maximum number of speakers to detect.")
+     ],
+     outputs=[
+         gr.Textbox(label="Diarized Transcript", lines=10, interactive=False),
+         gr.Textbox(label="Full ASR Transcript", lines=5, interactive=False),
+         gr.Textbox(label="Status Message", lines=1, interactive=False)
+     ],
+     title="Whisper ASR with Pyannote Speaker Diarization",
+     description=(
+         "Upload an audio file to get a transcript with speaker diarization. "
+         "This demo uses `openai/whisper-small` for ASR and `pyannote/speaker-diarization-3.1` for diarization. "
+         "A Hugging Face token with access to `pyannote/speaker-diarization-3.1` is required. "
+         "Please set it as an `HF_TOKEN` environment variable before launching (see script comments)."
+         "<br><b>Note:</b> For long audios or high concurrent usage, consider using a GPU and models like `whisper-large-v3`."
+     ),
+     allow_flagging="never",  # Disable Gradio flagging feature
+     # Example audio path assumes you are running from the cloned repository root.
+     # If not, download a small WAV file (e.g., from Common Voice) and update this path.
+     examples=[
+         [os.path.join(os.path.dirname(__file__), "model-server", "app", "tests", "polyai-minds14-0.wav"), 24, 30, "Auto-detect", None, None, None]
+     ]
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
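+     # Optional (not exercised in this demo): launch() also accepts server_name="0.0.0.0"
+     # and server_port=7860 for container deployments, or share=True for a temporary public link.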