Create app.py
app.py
ADDED
@@ -0,0 +1,458 @@
import os
import torch
import numpy as np
import gradio as gr
import logging
import sys
from typing import Optional, Literal
from pydantic import BaseModel
from transformers import pipeline
from pyannote.audio import Pipeline
from huggingface_hub import HfApi
from torchaudio import functional as F # For resampling and audio processing

# Set up logging
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Configuration ---
# You will need a Hugging Face token for pyannote/speaker-diarization-3.1.
# 1. Go to https://huggingface.co/settings/tokens to create a new token.
# 2. Make sure you have accepted the user conditions on the model page:
#    https://huggingface.co/pyannote/speaker-diarization-3.1
# 3. Set your token as an environment variable before running this script:
#    export HF_TOKEN="hf_YOUR_TOKEN_HERE"
# Alternatively, replace os.getenv("HF_TOKEN") with your actual token string:
# HF_TOKEN = "hf_YOUR_TOKEN_HERE"
HF_TOKEN = os.getenv("HF_TOKEN")

# Model names
ASR_MODEL = "openai/whisper-small" # Smaller, faster Whisper model for demo
DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
# Speculative decoding (assistant model) is explicitly excluded as per requirements.

# --- Inference Configuration (Pydantic Model for validation) ---
class InferenceConfig(BaseModel):
    task: Literal["transcribe", "translate"] = "transcribe"
    batch_size: int = 24
    chunk_length_s: int = 30
    language: Optional[str] = None
    num_speakers: Optional[int] = None
    min_speakers: Optional[int] = None
    max_speakers: Optional[int] = None
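
# Illustrative example of the per-request config that the Gradio callback below constructs
# (values are made up): InferenceConfig(task="transcribe", language="French", num_speakers=2)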

# --- Global Models and Device ---
models = {}
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device.type}")
torch_dtype = torch.float32 if device.type == "cpu" else torch.float16 # Use float16 on GPU for efficiency

# --- Model Loading Function ---
def load_models():
    """
    Loads the ASR and Diarization models into the global `models` dictionary.
    Handles device placement and Hugging Face token authentication.
    """
    logger.info("Loading ASR pipeline...")
    # The ASR pipeline can directly take a numpy array for inference.
    models["asr_pipeline"] = pipeline(
        "automatic-speech-recognition",
        model=ASR_MODEL,
        torch_dtype=torch_dtype,
        device=device
    )
    logger.info("ASR pipeline loaded.")

    if DIARIZATION_MODEL:
        logger.info(f"Loading Diarization pipeline: {DIARIZATION_MODEL}...")
        if not HF_TOKEN:
            raise ValueError(
                "HF_TOKEN environment variable or HF_TOKEN constant not set. "
                "Pyannote models require a Hugging Face token for authentication. "
                "Get it from https://huggingface.co/settings/tokens and ensure you accept "
                "the user conditions on the model page: "
                "https://huggingface.co/pyannote/speaker-diarization-3.1"
            )
        try:
            # Verify token and load pyannote pipeline
            HfApi().whoami(token=HF_TOKEN) # Check token validity
            models["diarization_pipeline"] = Pipeline.from_pretrained(
                checkpoint_path=DIARIZATION_MODEL,
                use_auth_token=HF_TOKEN,
            )
            models["diarization_pipeline"].to(device)
            logger.info("Diarization pipeline loaded.")
        except Exception as e:
            logger.error(f"Failed to load diarization pipeline: {e}")
            raise
    else:
        models["diarization_pipeline"] = None
        logger.info("Diarization model not specified, diarization will be skipped.")

# Load models once when the script starts
try:
    load_models()
except Exception as e:
    logger.critical(f"Failed to load models. Please check your HF_TOKEN and model names. Exiting: {e}")
    sys.exit(1)

# --- Diarization Utility Functions (adapted from original `model-server/app/utils/diarization_utils.py`) ---

def preprocess_audio_for_diarization(sampling_rate_in: int, audio_array_in: np.ndarray) -> tuple[torch.Tensor, int]:
    """
    Preprocesses audio for the diarization pipeline.
    Resamples to 16kHz and ensures single channel float32 torch tensor.
    """
    if audio_array_in is None or audio_array_in.size == 0:
        raise ValueError("Audio array is empty for diarization preprocessing.")

    # Convert to float32 if not already (pyannote expects float32)
    if audio_array_in.dtype != np.float32:
        audio_array_in = audio_array_in.astype(np.float32)

    # If stereo, take one channel (pyannote expects single channel)
    if len(audio_array_in.shape) > 1:
        audio_array_in = audio_array_in[:, 0] # Take the first channel

    # Resample to 16kHz if necessary, as pyannote models are typically trained on 16kHz audio.
    if sampling_rate_in != 16000:
        audio_array_in = F.resample(
            torch.from_numpy(audio_array_in), sampling_rate_in, 16000
        ).numpy()
        sampling_rate_in = 16000 # Update SR to reflect resampling

    # Diarization model expects float32 torch tensor of shape `(channels, seq_len)`
    diarizer_inputs = torch.from_numpy(audio_array_in).float()
    diarizer_inputs = diarizer_inputs.unsqueeze(0) # Add channel dimension (1, seq_len)

    return diarizer_inputs, sampling_rate_in
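
# For example, 10 s of 44.1 kHz stereo input comes back as a (1, 160000) float32 tensor at 16 kHz,
# i.e. mono, resampled, with the channel dimension added.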

def diarize_audio(diarizer_inputs: torch.Tensor, diarization_pipeline: Pipeline, parameters: InferenceConfig) -> list:
    """
    Performs diarization using the pyannote pipeline and combines consecutive speaker segments.
    """
    # Run the diarization pipeline
    diarization = diarization_pipeline(
        {"waveform": diarizer_inputs, "sample_rate": 16000}, # Always pass 16kHz to diarizer
        num_speakers=parameters.num_speakers,
        min_speakers=parameters.min_speakers,
        max_speakers=parameters.max_speakers,
    )

    raw_segments = []
    # pyannote.audio returns segments as `Segment(start=X, end=Y)`
    for segment, _, label in diarization.itertracks(yield_label=True):
        raw_segments.append(
            {
                "segment": {"start": segment.start, "end": segment.end},
                "label": label,
            }
        )

    # Combine consecutive segments from the same speaker
    combined_segments = []
    if not raw_segments:
        return combined_segments

    # Initialize with the first segment
    current_speaker_segment = {
        "speaker": raw_segments[0]["label"],
        "segment": {"start": raw_segments[0]["segment"]["start"], "end": raw_segments[0]["segment"]["end"]},
    }

    for i in range(1, len(raw_segments)):
        next_segment = raw_segments[i]

        # If the speaker changes
        if next_segment["label"] != current_speaker_segment["speaker"]:
            # Add the accumulated segment for the previous speaker
            combined_segments.append(current_speaker_segment)
            # Start a new segment accumulation with the current speaker
            current_speaker_segment = {
                "speaker": next_segment["label"],
                "segment": {"start": next_segment["segment"]["start"], "end": next_segment["segment"]["end"]},
            }
        else:
            # Same speaker, extend the end time of the current accumulated segment
            current_speaker_segment["segment"]["end"] = next_segment["segment"]["end"]

    # Add the very last accumulated segment after the loop finishes
    combined_segments.append(current_speaker_segment)

    return combined_segments
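
# The combined result is a list of per-speaker turns, e.g. (illustrative values):
# [{"speaker": "SPEAKER_00", "segment": {"start": 0.0, "end": 5.2}},
#  {"speaker": "SPEAKER_01", "segment": {"start": 5.2, "end": 9.8}}]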

def post_process_segments_and_transcripts(combined_diarization_segments: list, asr_transcript_chunks: list) -> list:
    """
    Aligns combined diarization segments with ASR transcript chunks.
    This logic closely follows the provided `diarization_utils.py`'s `post_process_segments_and_transcripts`
    function, which uses `argmin` for alignment and slicing for chunk consumption.
    """
    if not asr_transcript_chunks:
        return []

    # Get the end timestamps for each ASR chunk
    # Use sys.float_info.max for None to ensure `argmin` works
    asr_end_timestamps = np.array(
        [chunk["timestamp"][1] if chunk["timestamp"][1] is not None else sys.float_info.max for chunk in asr_transcript_chunks]
    )

    # Create mutable copies to slice from
    current_asr_chunks = list(asr_transcript_chunks)
    current_asr_end_timestamps = asr_end_timestamps.copy()

    final_segmented_transcript = []

    for diar_segment in combined_diarization_segments:
        if not current_asr_chunks:
            break # No more ASR chunks to process

        diar_start = diar_segment["segment"]["start"]
        diar_end = diar_segment["segment"]["end"]
        speaker = diar_segment["speaker"]

        # Find the index in `current_asr_end_timestamps` whose value is closest to `diar_end`.
        # This `upto_idx_relative` determines how many ASR chunks from `current_asr_chunks`
        # will be associated with the current `diar_segment`.
        upto_idx_relative = np.argmin(np.abs(current_asr_end_timestamps - diar_end))

        # Select the ASR chunks up to and including this `upto_idx_relative`.
        chunks_for_this_diar_segment = current_asr_chunks[:upto_idx_relative + 1]

        if not chunks_for_this_diar_segment:
            continue # No ASR chunks found for this diarization segment, skip

        # Combine the text from the selected ASR chunks.
        combined_text = "".join([chunk["text"] for chunk in chunks_for_this_diar_segment]).strip()

        # Determine the start and end timestamp for the combined ASR text.
        # This will be the min start and max end of the involved ASR chunks.
        asr_min_start = min(chunk["timestamp"][0] for chunk in chunks_for_this_diar_segment if chunk["timestamp"][0] is not None)
        asr_max_end = max(chunk["timestamp"][1] for chunk in chunks_for_this_diar_segment if chunk["timestamp"][1] is not None)

        # Final timestamp for the output segment should be clamped by the diarization segment's boundaries
        # to ensure it doesn't extend beyond what the diarizer indicated.
        final_segment_start = max(diar_start, asr_min_start)
        final_segment_end = min(diar_end, asr_max_end)

        final_segmented_transcript.append(
            {
                "speaker": speaker,
                "text": combined_text,
                "timestamp": (final_segment_start, final_segment_end),
            }
        )

        # Remove the processed ASR chunks from the lists for the next iteration.
        current_asr_chunks = current_asr_chunks[upto_idx_relative + 1:]
        current_asr_end_timestamps = current_asr_end_timestamps[upto_idx_relative + 1:]

    return final_segmented_transcript
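
# Worked example with made-up numbers: if the remaining ASR chunk end times are [2.0, 5.0, 9.0] and
# the current diarization segment ends at 5.3, argmin(|[2.0, 5.0, 9.0] - 5.3|) = 1, so chunks 0..1
# are merged into one entry for that speaker and then sliced off before the next segment is handled.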

def diarize_and_align_transcript(diarization_pipeline: Pipeline, original_sampling_rate: int,
                                 audio_numpy_array: np.ndarray, parameters: InferenceConfig, asr_outputs: dict) -> list:
    """
    Orchestrates the entire diarization and transcript alignment process.
    """
    # 1. Preprocess audio for the diarization model (resample to 16kHz, ensure mono, convert to torch.Tensor)
    diarizer_input_tensor, processed_sampling_rate = preprocess_audio_for_diarization(
        original_sampling_rate, audio_numpy_array
    )

    # 2. Perform diarization to get speaker segments
    # Update parameters with the processed sampling rate for diarization model's internal use.
    diarization_params_for_pipeline = parameters.model_copy(update={"sampling_rate": processed_sampling_rate})
    combined_diarization_segments = diarize_audio(
        diarizer_input_tensor,
        diarization_pipeline,
        diarization_params_for_pipeline
    )

    # 3. Align diarization segments with ASR transcript chunks
    aligned_transcript = post_process_segments_and_transcripts(
        combined_diarization_segments, asr_outputs["chunks"]
    )

    return aligned_transcript
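
# Each entry of the aligned transcript is shaped like (illustrative values):
# {"speaker": "SPEAKER_00", "text": "Hello there.", "timestamp": (0.0, 5.2)}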

# --- Main Prediction Function for Gradio Interface ---
def predict_audio(
    audio_file_tuple: tuple[int, np.ndarray],
    batch_size: int,
    chunk_length_s: int,
    language: str,
    num_speakers: Optional[int],
    min_speakers: Optional[int],
    max_speakers: Optional[int]
) -> tuple[str, str, str]:
    """
    Gradio-compatible function to perform ASR and optionally speaker diarization.

    Args:
        audio_file_tuple: A tuple (sampling_rate, numpy_array) from Gradio's gr.Audio input.
        batch_size: Batch size for ASR inference.
        chunk_length_s: Chunk length for ASR inference in seconds.
        language: Language for ASR (e.g., "English", "Auto-detect").
        num_speakers: Expected number of speakers for diarization (optional).
        min_speakers: Minimum number of speakers for diarization (optional).
        max_speakers: Maximum number of speakers for diarization (optional).

    Returns:
        A tuple containing:
        - formatted_diarized_text: A string with the diarized transcript.
        - full_transcript_text: A string with the full ASR transcript.
        - status_message: A message indicating success or failure.
    """
    if audio_file_tuple is None:
        return "", "", "Please upload an audio file."

    sampling_rate, audio_numpy_array = audio_file_tuple

    if audio_numpy_array is None or audio_numpy_array.size == 0:
        return "", "", "Audio file is empty. Please upload a valid audio."

    # Ensure audio_numpy_array is float32 as expected by transformers pipeline
    if audio_numpy_array.dtype != np.float32:
        audio_numpy_array = audio_numpy_array.astype(np.float32)

    # If stereo, convert to mono for consistent processing (e.g., take the first channel)
    if len(audio_numpy_array.shape) > 1:
        audio_numpy_array = audio_numpy_array[:, 0]

    # Create an InferenceConfig object from Gradio inputs for internal validation and use.
    try:
        parameters = InferenceConfig(
            batch_size=batch_size,
            chunk_length_s=chunk_length_s,
            language=language if language != "Auto-detect" else None, # Convert "Auto-detect" to None for model
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )
    except Exception as e:
        logger.error(f"Error validating parameters: {e}")
        return "", "", f"Error validating input parameters: {e}"

    logger.info(f"Inference parameters: {parameters.model_dump_json()}")
    logger.info(f"Audio sampling rate: {sampling_rate} Hz, Audio shape: {audio_numpy_array.shape}")

    asr_pipeline = models.get("asr_pipeline")
    diarization_pipeline = models.get("diarization_pipeline")

    if not asr_pipeline:
        return "", "", "ASR model not loaded. Please restart the application."

    # Prepare ASR generation arguments
    generate_kwargs = {
        "task": parameters.task,
        "language": parameters.language,
        "assistant_model": None # Speculative decoding is disabled
    }

    asr_outputs = None
    try:
        logger.info("Starting ASR inference...")
        asr_outputs = asr_pipeline(
            # Pass the raw array together with its original sampling rate so the pipeline can
            # resample internally; the pipeline call does not accept a bare `sampling_rate=` kwarg.
            {"raw": audio_numpy_array, "sampling_rate": sampling_rate},
            chunk_length_s=parameters.chunk_length_s,
            batch_size=parameters.batch_size,
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logger.info("ASR inference completed.")
    except Exception as e:
        logger.error(f"ASR inference error: {str(e)}")
        return "", "", f"ASR inference error: {str(e)}"
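
    # With return_timestamps=True, asr_outputs is a dict shaped like (illustrative text):
    # {"text": "full transcript ...", "chunks": [{"text": " Hello.", "timestamp": (0.0, 1.5)}, ...]}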

    final_transcript_data = []
    status_message = ""

    if diarization_pipeline:
        try:
            logger.info("Starting Diarization inference and alignment...")
            final_transcript_data = diarize_and_align_transcript(
                diarization_pipeline, sampling_rate, audio_numpy_array, parameters, asr_outputs
            )
            status_message = "Diarization and ASR successful!"
            logger.info("Diarization and alignment completed.")
        except Exception as e:
            logger.error(f"Diarization inference error: {str(e)}")
            # If diarization fails, still provide the full ASR transcript
            final_transcript_data = [] # Clear any partial diarization
            status_message = f"Diarization failed: {str(e)}. Displaying full ASR transcript only."
    else:
        logger.info("Diarization pipeline not loaded, skipping diarization and returning raw ASR chunks.")
        # If no diarization, format ASR chunks as if they were from a single "Speaker"
        for chunk in asr_outputs["chunks"]:
            final_transcript_data.append({
                "speaker": "Speaker", # Generic label
                "text": chunk["text"],
                "timestamp": chunk["timestamp"]
            })
        status_message = "Diarization not enabled. Displaying full ASR transcript by chunk."

    # Format the output for Gradio display
    formatted_diarized_text_output = []
    for entry in final_transcript_data:
        start_time = f"{entry['timestamp'][0]:.2f}" if entry['timestamp'][0] is not None else "0.00"
        end_time = f"{entry['timestamp'][1]:.2f}" if entry['timestamp'][1] is not None else "End"
        formatted_diarized_text_output.append(
            f"[{start_time} - {end_time}] {entry['speaker']}: {entry['text'].strip()}"
        )

    full_asr_text_output = asr_outputs["text"] if asr_outputs else "No ASR transcript generated."

    return (
        "\n".join(formatted_diarized_text_output),
        full_asr_text_output,
        status_message
    )

# --- Gradio Interface Definition ---

# List of languages supported by OpenAI Whisper models
WHISPER_LANGUAGES = [
    "Auto-detect", "English", "Chinese", "German", "Spanish", "Russian", "Korean", "French", "Japanese", "Portuguese",
    "Turkish", "Polish", "Catalan", "Dutch", "Arabic", "Swedish", "Italian", "Indonesian", "Hindi", "Finnish",
    "Vietnamese", "Hebrew", "Ukrainian", "Greek", "Malay", "Czech", "Romanian", "Danish", "Hungarian", "Tamil",
    "Norwegian", "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian", "Latin", "Maori", "Malayalam", "Afrikaans",
    "Welsh", "Belarusian", "Gujarati", "Kannada", "Armenian", "Azerbaijani", "Serbian", "Slovenian", "Estonian",
    "Burmese", "Galician", "Mongolian", "Lao", "Kazakh", "Georgian", "Amharic", "Nepali", "Bosnian", "Luxembourgish",
    "Pashto", "Tagalog", "Malagasy", "Albanian", "Sindhi", "Kurdish", "Somali", "Telugu", "Tajik", "Swahili",
    "Kashmiri"
]

demo = gr.Interface(
    fn=predict_audio,
    inputs=[
        gr.Audio(type="numpy", label="Upload Audio File (WAV, MP3, FLAC, etc.)"),
        gr.Slider(minimum=1, maximum=32, value=24, step=1, label="ASR Batch Size"),
        gr.Slider(minimum=1, maximum=60, value=30, step=1, label="ASR Chunk Length (seconds)"),
        gr.Dropdown(WHISPER_LANGUAGES, value="Auto-detect", label="ASR Language"),
        gr.Number(label="Diarization: Number of Speakers (optional)", value=None, precision=0, info="Expected total number of speakers."),
        gr.Number(label="Diarization: Min Speakers (optional)", value=None, precision=0, info="Minimum number of speakers to detect."),
        gr.Number(label="Diarization: Max Speakers (optional)", value=None, precision=0, info="Maximum number of speakers to detect.")
    ],
    outputs=[
        gr.Textbox(label="Diarized Transcript", lines=10, interactive=False),
        gr.Textbox(label="Full ASR Transcript", lines=5, interactive=False),
        gr.Textbox(label="Status Message", lines=1, interactive=False)
    ],
    title="Whisper ASR with Pyannote Speaker Diarization",
    description=(
        "Upload an audio file to get a transcript with speaker diarization. "
        "This demo uses `openai/whisper-small` for ASR and `pyannote/speaker-diarization-3.1` for diarization. "
        "A Hugging Face token with access to `pyannote/speaker-diarization-3.1` is required. "
        "Please set it as an `HF_TOKEN` environment variable before launching (see script comments)."
        "<br><b>Note:</b> For long audios or high concurrent usage, consider using a GPU and models like `whisper-large-v3`."
    ),
    allow_flagging="never", # Disable Gradio flagging feature
    # Example audio path assumes you are running from the cloned repository root.
    # If not, download a small WAV file (e.g., from Common Voice) and update this path.
    examples=[
        [os.path.join(os.path.dirname(__file__), "model-server", "app", "tests", "polyai-minds14-0.wav"), 24, 30, "Auto-detect", None, None, None]
    ]
)

if __name__ == "__main__":
    demo.launch()
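
# To run this app locally, one possible setup (the Space's requirements.txt is not shown in this
# commit, so the package list below is only inferred from the imports above):
#   pip install torch torchaudio transformers pyannote.audio pydantic gradio huggingface_hub numpy
#   export HF_TOKEN="hf_YOUR_TOKEN_HERE"
#   python app.py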