Update app.py
app.py CHANGED
@@ -11,6 +11,46 @@ from pyannote.audio import Pipeline
from huggingface_hub import HfApi
from torchaudio import functional as F # For resampling and audio processing

+# To run this Gradio demo, first ensure you have Python 3.9+ installed.
+# Then, create a virtual environment and install the required packages.
+#
+# 1. Create and activate a virtual environment (recommended):
+#    python3 -m venv venv
+#    source venv/bin/activate  # On Linux/macOS
+#    venv\Scripts\activate     # On Windows
+#
+# 2. Install the necessary packages:
+#    pip install gradio==4.20.1 \
+#        torch==2.2.1 \
+#        torchaudio==2.2.1 \
+#        transformers==4.38.2 \
+#        pyannote-audio==3.1.1 \
+#        numpy==1.26.4 \
+#        fastapi==0.110.0 \
+#        uvicorn==0.27.1 \
+#        python-multipart==0.0.9 \
+#        pydantic==2.6.1 \
+#        soundfile==0.12.1  # Required by torchaudio and pyannote for certain audio formats
+#
+# # If you have a CUDA-compatible GPU, install the CUDA version of PyTorch instead:
+# # For CUDA 12.1 (adjust 'cu121' to your CUDA version, e.g., 'cu118' for CUDA 11.8):
+# # pip install torch==2.2.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
+#
+# 3. Set your Hugging Face Token (required for pyannote/speaker-diarization-3.1).
+#    You must accept the user conditions on the model page: https://huggingface.co/pyannote/speaker-diarization-3.1
+#    export HF_TOKEN="hf_YOUR_TOKEN_HERE"
+#    # Or directly in the script (not recommended for security):
+#    # HF_TOKEN = "hf_YOUR_TOKEN_HERE"
+#
+# 4. Save this file as, for example, `app.py`.
+#
+# 5. Run the application using uvicorn (as requested):
+#    uvicorn app:demo --host 0.0.0.0 --port 7860
+#
+# # If you just want to run it via Python script (Gradio's default, without uvicorn directly):
+# # python app.py
+
+
# Set up logging
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
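Step 5 of the comment above points uvicorn at `app:demo`. A Gradio `Interface` is usually exposed to uvicorn by mounting it on a FastAPI app with `gr.mount_gradio_app`; a minimal sketch, not part of this commit (the `serve.py` file name and the root mount path are illustrative):

# serve.py -- illustrative sketch only; assumes `demo` is the gr.Interface built in app.py below.
import gradio as gr
from fastapi import FastAPI

from app import demo  # the Gradio Interface defined in this file

app = FastAPI()
# Mount the Gradio UI on the FastAPI app so uvicorn can serve it as ASGI:
app = gr.mount_gradio_app(app, demo, path="/")

# Run with:
#   uvicorn serve:app --host 0.0.0.0 --port 7860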
@@ -27,14 +67,14 @@ logger = logging.getLogger(__name__)
HF_TOKEN = os.getenv("HF_TOKEN")

# Model names
-ASR_MODEL = "openai/whisper-
+ASR_MODEL = "openai/whisper-large-v3-turbo" # Smaller, faster Whisper model for demo
DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
# Speculative decoding (assistant model) is explicitly excluded as per requirements.

# --- Inference Configuration (Pydantic Model for validation) ---
class InferenceConfig(BaseModel):
    task: Literal["transcribe", "translate"] = "transcribe"
-    batch_size: int =
+    batch_size: int = 1
    chunk_length_s: int = 30
    language: Optional[str] = None
    num_speakers: Optional[int] = None
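For context on the renamed checkpoint: `predict_audio` later reads `models.get("asr_pipeline")` and `models.get("diarization_pipeline")`, but the loader itself is outside this diff. A minimal sketch of how such a `models` dict is typically populated from `ASR_MODEL` and `DIARIZATION_MODEL` (the dict keys and device handling here are assumptions):

# Illustrative sketch only -- the model-loading code is not part of this diff.
import os
import torch
from transformers import pipeline
from pyannote.audio import Pipeline

ASR_MODEL = "openai/whisper-large-v3-turbo"
DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
HF_TOKEN = os.getenv("HF_TOKEN")

device = "cuda:0" if torch.cuda.is_available() else "cpu"

models = {
    # transformers ASR pipeline with chunked long-form decoding
    "asr_pipeline": pipeline(
        "automatic-speech-recognition",
        model=ASR_MODEL,
        device=device,
    ),
    # pyannote diarization pipeline; requires the gated model's terms to be accepted and an HF token
    "diarization_pipeline": Pipeline.from_pretrained(
        DIARIZATION_MODEL, use_auth_token=HF_TOKEN
    ),
}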
@@ -210,29 +250,46 @@ def post_process_segments_and_transcripts(combined_diarization_segments: list, a
        diar_end = diar_segment["segment"]["end"]
        speaker = diar_segment["speaker"]

-        # Find the index
-        #
-
+        # Find the index of the ASR chunk whose end timestamp is closest to diar_end
+        # Ensure argmin operates on a non-empty array
+        if current_asr_end_timestamps.size == 0:
+            logger.warning("No ASR end timestamps left to align with diarization segment. Breaking alignment.")
+            break # No more ASR chunks to align
+
        upto_idx_relative = np.argmin(np.abs(current_asr_end_timestamps - diar_end))

-        # Select the ASR chunks up to and including this `upto_idx_relative`.
        chunks_for_this_diar_segment = current_asr_chunks[:upto_idx_relative + 1]

        if not chunks_for_this_diar_segment:
-
+            logger.warning(f"No ASR chunks selected for diarization segment [{diar_start:.2f}-{diar_end:.2f}] {speaker}. Skipping.")
+            continue

-        #
-
+        # Initialize with extreme values to find min/max correctly, handling None timestamps
+        asr_min_start_val = float('inf')
+        asr_max_end_val = float('-inf')
+
+        all_text = []
+
+        for chunk in chunks_for_this_diar_segment:
+            all_text.append(chunk["text"])
+            if chunk["timestamp"] and chunk["timestamp"][0] is not None:
+                asr_min_start_val = min(asr_min_start_val, chunk["timestamp"][0])
+            if chunk["timestamp"] and chunk["timestamp"][1] is not None:
+                asr_max_end_val = max(asr_max_end_val, chunk["timestamp"][1])

-
-
-
-
+        combined_text = "".join(all_text).strip()
+
+        # If no valid timestamps were found in the selected ASR chunks, fall back to diarization segment's bounds
+        if asr_min_start_val == float('inf'):
+            logger.warning(f"No valid start timestamps in ASR chunks for segment [{diar_start:.2f}-{diar_end:.2f}] {speaker}. Using diarization start.")
+            asr_min_start_val = diar_start
+        if asr_max_end_val == float('-inf'):
+            logger.warning(f"No valid end timestamps in ASR chunks for segment [{diar_start:.2f}-{diar_end:.2f}] {speaker}. Using diarization end.")
+            asr_max_end_val = diar_end

-        #
-
-
-        final_segment_end = min(diar_end, asr_max_end)
+        # Ensure final timestamp range makes sense and is clamped by diarization segment
+        final_segment_start = max(diar_start, asr_min_start_val)
+        final_segment_end = min(diar_end, asr_max_end_val)

        final_segmented_transcript.append(
            {
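A self-contained toy run of the alignment rule in this hunk (made-up timestamps, not from the app): each diarization turn takes the ASR chunks up to the one whose end time is closest to the turn's end, then the consumed chunks are cropped away, mirroring the list slicing above.

import numpy as np

asr_chunks = [
    {"text": " hello", "timestamp": (0.0, 1.2)},
    {"text": " there", "timestamp": (1.2, 2.4)},
    {"text": " how are you", "timestamp": (2.4, 4.1)},
]
end_times = np.array([c["timestamp"][1] for c in asr_chunks])

diar_end = 2.5  # end of the current speaker turn, in seconds
upto = np.argmin(np.abs(end_times - diar_end))    # index 1: 2.4 is closest to 2.5
taken = asr_chunks[: upto + 1]                     # chunks attributed to this turn
print("".join(c["text"] for c in taken).strip())   # -> "hello there"

# Crop for the next diarization turn
asr_chunks = asr_chunks[upto + 1:]
end_times = end_times[upto + 1:]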
@@ -242,12 +299,13 @@ def post_process_segments_and_transcripts(combined_diarization_segments: list, a
            }
        )

-        #
+        # Crop the transcripts and timestamp lists according to the latest timestamp
        current_asr_chunks = current_asr_chunks[upto_idx_relative + 1:]
        current_asr_end_timestamps = current_asr_end_timestamps[upto_idx_relative + 1:]

    return final_segmented_transcript

+
def diarize_and_align_transcript(diarization_pipeline: Pipeline, original_sampling_rate: int,
                                 audio_numpy_array: np.ndarray, parameters: InferenceConfig, asr_outputs: dict) -> list:
    """
@@ -303,12 +361,12 @@ def predict_audio(
    - status_message: A message indicating success or failure.
    """
    if audio_file_tuple is None:
-        return "", "", "Please upload an audio file."
+        return "", "", gr.Warning("Please upload an audio file.")

    sampling_rate, audio_numpy_array = audio_file_tuple

    if audio_numpy_array is None or audio_numpy_array.size == 0:
-        return "", "", "Audio file is empty. Please upload a valid audio."
+        return "", "", gr.Warning("Audio file is empty. Please upload a valid audio.")

    # Ensure audio_numpy_array is float32 as expected by transformers pipeline
    if audio_numpy_array.dtype != np.float32:
@@ -318,19 +376,34 @@ def predict_audio(
    if len(audio_numpy_array.shape) > 1:
        audio_numpy_array = audio_numpy_array[:, 0]

+    # Process speaker parameters: convert 0 or negative values to None for pyannote compatibility
+    processed_num_speakers = num_speakers if num_speakers is not None and num_speakers > 0 else None
+    processed_min_speakers = min_speakers if min_speakers is not None and min_speakers > 0 else None
+    processed_max_speakers = max_speakers if max_speakers is not None and max_speakers > 0 else None
+
+    # Validation logic for min/max speakers
+    if processed_min_speakers is not None and processed_max_speakers is not None and processed_min_speakers > processed_max_speakers:
+        return "", "", gr.Warning("Diarization: Min Speakers cannot be greater than Max Speakers.")
+    if processed_num_speakers is not None:
+        if processed_min_speakers is not None and processed_num_speakers < processed_min_speakers:
+            return "", "", gr.Warning("Diarization: Number of Speakers cannot be less than Min Speakers.")
+        if processed_max_speakers is not None and processed_num_speakers > processed_max_speakers:
+            return "", "", gr.Warning("Diarization: Number of Speakers cannot be greater than Max Speakers.")
+
+
    # Create an InferenceConfig object from Gradio inputs for internal validation and use.
    try:
        parameters = InferenceConfig(
            batch_size=batch_size,
            chunk_length_s=chunk_length_s,
            language=language if language != "Auto-detect" else None, # Convert "Auto-detect" to None for model
-            num_speakers=
-            min_speakers=
-            max_speakers=
+            num_speakers=processed_num_speakers,
+            min_speakers=processed_min_speakers,
+            max_speakers=processed_max_speakers,
        )
    except Exception as e:
        logger.error(f"Error validating parameters: {e}")
-        return "", "", f"Error validating input parameters: {e}"
+        return "", "", gr.Error(f"Error validating input parameters: {e}") # Use gr.Error for critical validation failures

    logger.info(f"Inference parameters: {parameters.model_dump_json()}")
    logger.info(f"Audio sampling rate: {sampling_rate} Hz, Audio shape: {audio_numpy_array.shape}")
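The speaker-count handling added above, restated as a standalone helper with quick checks (a sketch for clarity only; the app inlines this logic and returns gr.Warning values instead of raising):

# Illustrative restatement, not part of the commit: 0, negative, or None become None
# so pyannote auto-detects the speaker count, and inconsistent combinations are rejected.
from typing import Optional, Tuple


def normalize_speaker_counts(
    num: Optional[int], mn: Optional[int], mx: Optional[int]
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    num = num if num is not None and num > 0 else None
    mn = mn if mn is not None and mn > 0 else None
    mx = mx if mx is not None and mx > 0 else None
    if mn is not None and mx is not None and mn > mx:
        raise ValueError("Min Speakers cannot be greater than Max Speakers.")
    if num is not None:
        if mn is not None and num < mn:
            raise ValueError("Number of Speakers cannot be less than Min Speakers.")
        if mx is not None and num > mx:
            raise ValueError("Number of Speakers cannot be greater than Max Speakers.")
    return num, mn, mx


assert normalize_speaker_counts(0, None, -1) == (None, None, None)
assert normalize_speaker_counts(2, 1, 4) == (2, 1, 4)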
@@ -339,7 +412,14 @@ def predict_audio(
    diarization_pipeline = models.get("diarization_pipeline")

    if not asr_pipeline:
-        return "", "", "ASR model not loaded. Please restart the application."
+        return "", "", gr.Error("ASR model not loaded. Please restart the application.")
+
+    # ASR language and batch size conflict warning/error
+    if parameters.language is None and parameters.batch_size > 1:
+        return "", "", gr.Warning(
+            "ASR: 'Auto-detect' language is not supported with batch size > 1. "
+            "Please select a specific language or set batch size to 1."
+        )

    # Prepare ASR generation arguments
    generate_kwargs = {
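The contents of `generate_kwargs` are outside this diff; with the transformers Whisper pipeline, `task` and `language` from the validated config are commonly forwarded as below. This is an assumption, shown only to make the context line concrete.

# Hypothetical sketch of the elided dict; `parameters` is the InferenceConfig
# instance validated above. Omitting "language" lets Whisper auto-detect it.
generate_kwargs = {"task": parameters.task}
if parameters.language is not None:
    generate_kwargs["language"] = parameters.language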
@@ -357,12 +437,12 @@
            batch_size=parameters.batch_size,
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
-            #sampling_rate=sampling_rate # Pass original sampling rate to pipeline
+            # sampling_rate=sampling_rate # Pass original sampling rate to pipeline
        )
        logger.info("ASR inference completed.")
    except Exception as e:
        logger.error(f"ASR inference error: {str(e)}")
-        return "", "", f"ASR inference error: {str(e)}"
+        return "", "", gr.Error(f"ASR inference error: {str(e)}")

    final_transcript_data = []
    status_message = ""
@@ -426,12 +506,12 @@ demo = gr.Interface(
    fn=predict_audio,
    inputs=[
        gr.Audio(type="numpy", label="Upload Audio File (WAV, MP3, FLAC, etc.)"),
-        gr.Slider(minimum=1, maximum=32, value=
-        gr.Slider(minimum=1, maximum=
-        gr.Dropdown(WHISPER_LANGUAGES, value="
-        gr.Number(label="Diarization: Number of Speakers (optional)", value=None, precision=0, info="Expected total number of speakers."),
-        gr.Number(label="Diarization: Min Speakers (optional)", value=None, precision=0, info="Minimum number of speakers to detect."),
-        gr.Number(label="Diarization: Max Speakers (optional)", value=None, precision=0, info="Maximum number of speakers to detect.")
+        gr.Slider(minimum=1, maximum=32, value=1, step=1, label="ASR Batch Size"),
+        gr.Slider(minimum=1, maximum=30, value=30, step=1, label="ASR Chunk Length (seconds)"),
+        gr.Dropdown(WHISPER_LANGUAGES, value="Chinese", label="ASR Language"),
+        gr.Number(label="Diarization: Number of Speakers (optional)", value=None, precision=0, info="Expected total number of speakers (positive integer, or leave empty for auto-detect)."),
+        gr.Number(label="Diarization: Min Speakers (optional)", value=None, precision=0, info="Minimum number of speakers to detect (positive integer, or leave empty for auto-detect)."),
+        gr.Number(label="Diarization: Max Speakers (optional)", value=None, precision=0, info="Maximum number of speakers to detect (positive integer, or leave empty for auto-detect).")
    ],
    outputs=[
        gr.Textbox(label="Diarized Transcript", lines=10, interactive=False),
@@ -447,13 +527,15 @@ demo = gr.Interface(
        "<br><b>Note:</b> For long audios or high concurrent usage, consider using a GPU and models like `whisper-large-v3`."
    ),
    allow_flagging="never", # Disable Gradio flagging feature
-    # Example audio path assumes you are running from the cloned repository root.
-    # If not, download a small WAV file (e.g., from Common Voice) and update this path.
    examples=[
+        # Adjust this path if the `model-server/app/tests/` directory is not alongside your `app.py`
+        # For example, if app.py is in the root, and the audio is in a tests/ subdirectory,
+        # you might use: ["tests/polyai-minds14-0.wav", 24, 30, "Auto-detect", None, None, None]
        [os.path.join(os.path.dirname(__file__), "model-server", "app", "tests", "polyai-minds14-0.wav"), 24, 30, "Auto-detect", None, None, None]
    ],
-    cache_examples=False
+    cache_examples=False # Disable caching of examples to prevent InvalidPathError
)

if __name__ == "__main__":
+    logger.info("Starting Gradio demo...")
    demo.launch()
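A quick local smoke test of `predict_audio`, assuming the models loaded successfully. The argument order mirrors the gr.Interface inputs above (audio tuple, batch size, chunk length, language, num/min/max speakers); the variable names for the three outputs are illustrative, since only the first output's label appears in this diff.

# Illustrative smoke test, not part of the commit.
import numpy as np

sr = 16_000
t = np.linspace(0, 1.0, sr, endpoint=False)
tone = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)  # 1 s of a 440 Hz tone

# Gradio's numpy audio format is (sampling_rate, samples).
diarized_text, second_output, status = predict_audio(
    (sr, tone),
    1,                 # ASR batch size
    30,                # ASR chunk length in seconds
    "Auto-detect",     # language; allowed here because batch size is 1
    None, None, None,  # num / min / max speakers
)
print(status)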