import os
import torch
import numpy as np
import gradio as gr
import logging
import sys
from typing import Optional, Literal
from pydantic import BaseModel
from transformers import pipeline
from pyannote.audio import Pipeline
from huggingface_hub import HfApi
from torchaudio import functional as F  # For resampling and audio processing

# Set up logging
logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Configuration ---
# You will need a Hugging Face token for pyannote/speaker-diarization-3.1.
# 1. Go to https://huggingface.co/settings/tokens to create a new token.
# 2. Make sure you have accepted the user conditions on the model page:
#    https://huggingface.co/pyannote/speaker-diarization-3.1
# 3. Set your token as an environment variable before running this script:
#    export HF_TOKEN="hf_YOUR_TOKEN_HERE"
# Alternatively, replace os.getenv("HF_TOKEN") with your actual token string:
#    HF_TOKEN = "hf_YOUR_TOKEN_HERE"
HF_TOKEN = os.getenv("HF_TOKEN")

# Model names
ASR_MODEL = "openai/whisper-small"  # Smaller, faster Whisper model for demo
DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
# Speculative decoding (assistant model) is explicitly excluded as per requirements.

# --- Inference Configuration (Pydantic model for validation) ---
class InferenceConfig(BaseModel):
    task: Literal["transcribe", "translate"] = "transcribe"
    batch_size: int = 24
    chunk_length_s: int = 30
    language: Optional[str] = None
    num_speakers: Optional[int] = None
    min_speakers: Optional[int] = None
    max_speakers: Optional[int] = None
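
# Illustrative usage (not executed; the values below are hypothetical): constructing
# the config validates the inputs, e.g. an unknown `task` raises a pydantic
# ValidationError, and `model_dump_json()` is what gets logged further down.
#
#   cfg = InferenceConfig(task="transcribe", batch_size=16, language="en", num_speakers=2)
#   cfg.model_dump_json()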

# --- Global Models and Device ---
models = {}
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device.type}")
torch_dtype = torch.float32 if device.type == "cpu" else torch.float16  # Use float16 on GPU for efficiency

# --- Model Loading Function ---
def load_models():
    """
    Loads the ASR and Diarization models into the global `models` dictionary.
    Handles device placement and Hugging Face token authentication.
    """
    logger.info("Loading ASR pipeline...")
    # The ASR pipeline can directly take a numpy array for inference.
    models["asr_pipeline"] = pipeline(
        "automatic-speech-recognition",
        model=ASR_MODEL,
        torch_dtype=torch_dtype,
        device=device
    )
    logger.info("ASR pipeline loaded.")

    if DIARIZATION_MODEL:
        logger.info(f"Loading Diarization pipeline: {DIARIZATION_MODEL}...")
        if not HF_TOKEN:
            raise ValueError(
                "HF_TOKEN environment variable or HF_TOKEN constant not set. "
                "Pyannote models require a Hugging Face token for authentication. "
                "Get it from https://huggingface.co/settings/tokens and ensure you accept "
                "the user conditions on the model page: "
                "https://huggingface.co/pyannote/speaker-diarization-3.1"
            )
        try:
            # Verify token and load pyannote pipeline
            HfApi().whoami(token=HF_TOKEN)  # Check token validity
            models["diarization_pipeline"] = Pipeline.from_pretrained(
                checkpoint_path=DIARIZATION_MODEL,
                use_auth_token=HF_TOKEN,
            )
            models["diarization_pipeline"].to(device)
            logger.info("Diarization pipeline loaded.")
        except Exception as e:
            logger.error(f"Failed to load diarization pipeline: {e}")
            raise
    else:
        models["diarization_pipeline"] = None
        logger.info("Diarization model not specified, diarization will be skipped.")


# Load models once when the script starts
try:
    load_models()
except Exception as e:
    logger.critical(f"Failed to load models. Please check your HF_TOKEN and model names. Exiting: {e}")
    sys.exit(1)

# --- Diarization Utility Functions (adapted from the original `model-server/app/utils/diarization_utils.py`) ---
def preprocess_audio_for_diarization(sampling_rate_in: int, audio_array_in: np.ndarray) -> tuple[torch.Tensor, int]:
    """
    Preprocesses audio for the diarization pipeline.
    Resamples to 16kHz and returns a single-channel float32 torch tensor.
    """
    if audio_array_in is None or audio_array_in.size == 0:
        raise ValueError("Audio array is empty for diarization preprocessing.")
    # Convert to float32 if not already (pyannote expects float32)
    if audio_array_in.dtype != np.float32:
        audio_array_in = audio_array_in.astype(np.float32)
    # If stereo, take one channel (pyannote expects single-channel audio)
    if len(audio_array_in.shape) > 1:
        audio_array_in = audio_array_in[:, 0]  # Take the first channel
    # Resample to 16kHz if necessary, as pyannote models are typically trained on 16kHz audio.
    if sampling_rate_in != 16000:
        audio_array_in = F.resample(
            torch.from_numpy(audio_array_in), sampling_rate_in, 16000
        ).numpy()
        sampling_rate_in = 16000  # Update SR to reflect resampling
    # The diarization model expects a float32 torch tensor of shape `(channels, seq_len)`
    diarizer_inputs = torch.from_numpy(audio_array_in).float()
    diarizer_inputs = diarizer_inputs.unsqueeze(0)  # Add channel dimension -> (1, seq_len)
    return diarizer_inputs, sampling_rate_in
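
# Illustrative example (hypothetical input: a 2-second stereo clip recorded at 44.1 kHz):
#
#   fake_audio = np.random.randn(88200, 2).astype(np.float32)
#   waveform, sr = preprocess_audio_for_diarization(44100, fake_audio)
#   # waveform.shape == torch.Size([1, 32000]), sr == 16000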

def diarize_audio(diarizer_inputs: torch.Tensor, diarization_pipeline: Pipeline, parameters: InferenceConfig) -> list:
    """
    Performs diarization using the pyannote pipeline and combines consecutive speaker segments.
    """
    # Run the diarization pipeline
    diarization = diarization_pipeline(
        {"waveform": diarizer_inputs, "sample_rate": 16000},  # Always pass 16kHz audio to the diarizer
        num_speakers=parameters.num_speakers,
        min_speakers=parameters.min_speakers,
        max_speakers=parameters.max_speakers,
    )
    raw_segments = []
    # pyannote.audio returns segments as `Segment(start=X, end=Y)`
    for segment, _, label in diarization.itertracks(yield_label=True):
        raw_segments.append(
            {
                "segment": {"start": segment.start, "end": segment.end},
                "label": label,
            }
        )

    # Combine consecutive segments from the same speaker
    combined_segments = []
    if not raw_segments:
        return combined_segments

    # Initialize with the first segment
    current_speaker_segment = {
        "speaker": raw_segments[0]["label"],
        "segment": {"start": raw_segments[0]["segment"]["start"], "end": raw_segments[0]["segment"]["end"]},
    }
    for i in range(1, len(raw_segments)):
        next_segment = raw_segments[i]
        if next_segment["label"] != current_speaker_segment["speaker"]:
            # Speaker changed: add the accumulated segment for the previous speaker
            combined_segments.append(current_speaker_segment)
            # Start a new segment accumulation with the current speaker
            current_speaker_segment = {
                "speaker": next_segment["label"],
                "segment": {"start": next_segment["segment"]["start"], "end": next_segment["segment"]["end"]},
            }
        else:
            # Same speaker: extend the end time of the current accumulated segment
            current_speaker_segment["segment"]["end"] = next_segment["segment"]["end"]
    # Add the last accumulated segment after the loop finishes
    combined_segments.append(current_speaker_segment)
    return combined_segments
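
# Merging behaviour, illustrated on hand-written segments (hypothetical labels and times):
#   raw:      SPEAKER_00 [0.0, 2.0], SPEAKER_00 [2.0, 4.5], SPEAKER_01 [4.5, 6.0]
#   combined: SPEAKER_00 [0.0, 4.5], SPEAKER_01 [4.5, 6.0]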

def post_process_segments_and_transcripts(combined_diarization_segments: list, asr_transcript_chunks: list) -> list:
    """
    Aligns combined diarization segments with ASR transcript chunks.
    This logic closely follows `post_process_segments_and_transcripts` from the original
    `model-server/app/utils/diarization_utils.py`, which uses `argmin` for alignment and
    list slicing for chunk consumption.
    """
    if not asr_transcript_chunks:
        return []

    # Get the end timestamp of each ASR chunk.
    # Use sys.float_info.max in place of None so `argmin` still works.
    asr_end_timestamps = np.array(
        [chunk["timestamp"][1] if chunk["timestamp"][1] is not None else sys.float_info.max for chunk in asr_transcript_chunks]
    )
    # Create mutable copies to slice from
    current_asr_chunks = list(asr_transcript_chunks)
    current_asr_end_timestamps = asr_end_timestamps.copy()

    final_segmented_transcript = []
    for diar_segment in combined_diarization_segments:
        if not current_asr_chunks:
            break  # No more ASR chunks to process
        diar_start = diar_segment["segment"]["start"]
        diar_end = diar_segment["segment"]["end"]
        speaker = diar_segment["speaker"]

        # Find the index in `current_asr_end_timestamps` whose value is closest to `diar_end`.
        # This `upto_idx_relative` determines how many ASR chunks from `current_asr_chunks`
        # are associated with the current `diar_segment`.
        upto_idx_relative = np.argmin(np.abs(current_asr_end_timestamps - diar_end))

        # Select the ASR chunks up to and including `upto_idx_relative`.
        chunks_for_this_diar_segment = current_asr_chunks[:upto_idx_relative + 1]
        if not chunks_for_this_diar_segment:
            continue  # No ASR chunks found for this diarization segment, skip

        # Combine the text from the selected ASR chunks.
        combined_text = "".join([chunk["text"] for chunk in chunks_for_this_diar_segment]).strip()

        # Determine the start and end timestamps for the combined ASR text: the min start and
        # max end of the involved ASR chunks. Fall back to the diarization boundaries when a
        # chunk timestamp is None (Whisper may leave the final chunk's end timestamp unset).
        asr_min_start = min(
            (chunk["timestamp"][0] for chunk in chunks_for_this_diar_segment if chunk["timestamp"][0] is not None),
            default=diar_start,
        )
        asr_max_end = max(
            (chunk["timestamp"][1] for chunk in chunks_for_this_diar_segment if chunk["timestamp"][1] is not None),
            default=diar_end,
        )
        # Clamp the final timestamps to the diarization segment's boundaries so the output
        # segment never extends beyond what the diarizer indicated.
        final_segment_start = max(diar_start, asr_min_start)
        final_segment_end = min(diar_end, asr_max_end)

        final_segmented_transcript.append(
            {
                "speaker": speaker,
                "text": combined_text,
                "timestamp": (final_segment_start, final_segment_end),
            }
        )
        # Remove the processed ASR chunks before the next iteration.
        current_asr_chunks = current_asr_chunks[upto_idx_relative + 1:]
        current_asr_end_timestamps = current_asr_end_timestamps[upto_idx_relative + 1:]
    return final_segmented_transcript
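
# Worked example of the argmin alignment (hypothetical numbers):
#   current diarization segment ends at diar_end = 5.2
#   remaining ASR chunk end times   = [2.0, 5.0, 9.5]
#   argmin(|[2.0, 5.0, 9.5] - 5.2|) = 1, so the first two ASR chunks are attributed to
#   this speaker and the remaining chunk is carried over to the next diarization segment.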

def diarize_and_align_transcript(diarization_pipeline: Pipeline, original_sampling_rate: int,
                                 audio_numpy_array: np.ndarray, parameters: InferenceConfig, asr_outputs: dict) -> list:
    """
    Orchestrates the entire diarization and transcript alignment process.
    """
    # 1. Preprocess audio for the diarization model (resample to 16kHz, ensure mono, convert to torch.Tensor)
    diarizer_input_tensor, _processed_sampling_rate = preprocess_audio_for_diarization(
        original_sampling_rate, audio_numpy_array
    )
    # 2. Perform diarization to get speaker segments.
    # preprocess_audio_for_diarization guarantees 16kHz audio, which matches the
    # sample rate that diarize_audio passes to the pyannote pipeline.
    combined_diarization_segments = diarize_audio(
        diarizer_input_tensor,
        diarization_pipeline,
        parameters
    )
    # 3. Align diarization segments with ASR transcript chunks
    aligned_transcript = post_process_segments_and_transcripts(
        combined_diarization_segments, asr_outputs["chunks"]
    )
    return aligned_transcript
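
# The returned list contains one entry per speaker turn, e.g. (hypothetical values):
#   [{"speaker": "SPEAKER_00", "text": "Hello there.", "timestamp": (0.0, 1.8)}, ...]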

# --- Main Prediction Function for Gradio Interface ---
def predict_audio(
    audio_file_tuple: tuple[int, np.ndarray],
    batch_size: int,
    chunk_length_s: int,
    language: str,
    num_speakers: Optional[int],
    min_speakers: Optional[int],
    max_speakers: Optional[int]
) -> tuple[str, str, str]:
    """
    Gradio-compatible function to perform ASR and optionally speaker diarization.

    Args:
        audio_file_tuple: A tuple (sampling_rate, numpy_array) from Gradio's gr.Audio input.
        batch_size: Batch size for ASR inference.
        chunk_length_s: Chunk length for ASR inference in seconds.
        language: Language for ASR (e.g., "English", "Auto-detect").
        num_speakers: Expected number of speakers for diarization (optional).
        min_speakers: Minimum number of speakers for diarization (optional).
        max_speakers: Maximum number of speakers for diarization (optional).

    Returns:
        A tuple containing:
        - formatted_diarized_text: A string with the diarized transcript.
        - full_transcript_text: A string with the full ASR transcript.
        - status_message: A message indicating success or failure.
    """
    if audio_file_tuple is None:
        return "", "", "Please upload an audio file."
    sampling_rate, audio_numpy_array = audio_file_tuple
    if audio_numpy_array is None or audio_numpy_array.size == 0:
        return "", "", "Audio file is empty. Please upload a valid audio file."

    # Gradio's "numpy" audio type typically yields int16 samples; convert to float32
    # (scaled to [-1, 1] for integer input) as expected by the transformers pipeline.
    if np.issubdtype(audio_numpy_array.dtype, np.integer):
        audio_numpy_array = audio_numpy_array.astype(np.float32) / np.iinfo(audio_numpy_array.dtype).max
    elif audio_numpy_array.dtype != np.float32:
        audio_numpy_array = audio_numpy_array.astype(np.float32)
    # If stereo, convert to mono for consistent processing (take the first channel)
    if len(audio_numpy_array.shape) > 1:
        audio_numpy_array = audio_numpy_array[:, 0]

    # Create an InferenceConfig object from Gradio inputs for internal validation and use.
    try:
        parameters = InferenceConfig(
            batch_size=batch_size,
            chunk_length_s=chunk_length_s,
            language=language if language != "Auto-detect" else None,  # Convert "Auto-detect" to None for the model
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )
    except Exception as e:
        logger.error(f"Error validating parameters: {e}")
        return "", "", f"Error validating input parameters: {e}"

    logger.info(f"Inference parameters: {parameters.model_dump_json()}")
    logger.info(f"Audio sampling rate: {sampling_rate} Hz, Audio shape: {audio_numpy_array.shape}")

    asr_pipeline = models.get("asr_pipeline")
    diarization_pipeline = models.get("diarization_pipeline")
    if not asr_pipeline:
        return "", "", "ASR model not loaded. Please restart the application."

    # Prepare ASR generation arguments
    generate_kwargs = {
        "task": parameters.task,
        "language": parameters.language,
        "assistant_model": None  # Speculative decoding is disabled
    }

    asr_outputs = None
    try:
        logger.info("Starting ASR inference...")
        # The transformers ASR pipeline takes raw audio plus its sampling rate as a dict
        # ({"raw": np.ndarray, "sampling_rate": int}) so it can resample internally.
        asr_outputs = asr_pipeline(
            {"raw": audio_numpy_array, "sampling_rate": sampling_rate},
            chunk_length_s=parameters.chunk_length_s,
            batch_size=parameters.batch_size,
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logger.info("ASR inference completed.")
    except Exception as e:
        logger.error(f"ASR inference error: {str(e)}")
        return "", "", f"ASR inference error: {str(e)}"

    final_transcript_data = []
    status_message = ""
    if diarization_pipeline:
        try:
            logger.info("Starting Diarization inference and alignment...")
            final_transcript_data = diarize_and_align_transcript(
                diarization_pipeline, sampling_rate, audio_numpy_array, parameters, asr_outputs
            )
            status_message = "Diarization and ASR successful!"
            logger.info("Diarization and alignment completed.")
        except Exception as e:
            logger.error(f"Diarization inference error: {str(e)}")
            # If diarization fails, still provide the full ASR transcript
            final_transcript_data = []  # Clear any partial diarization
            status_message = f"Diarization failed: {str(e)}. Displaying full ASR transcript only."
    else:
        logger.info("Diarization pipeline not loaded, skipping diarization and returning raw ASR chunks.")
        # If no diarization, format ASR chunks as if they came from a single generic "Speaker"
        for chunk in asr_outputs["chunks"]:
            final_transcript_data.append({
                "speaker": "Speaker",  # Generic label
                "text": chunk["text"],
                "timestamp": chunk["timestamp"]
            })
        status_message = "Diarization not enabled. Displaying full ASR transcript by chunk."

    # Format the output for Gradio display
    formatted_diarized_text_output = []
    for entry in final_transcript_data:
        start_time = f"{entry['timestamp'][0]:.2f}" if entry['timestamp'][0] is not None else "0.00"
        end_time = f"{entry['timestamp'][1]:.2f}" if entry['timestamp'][1] is not None else "End"
        formatted_diarized_text_output.append(
            f"[{start_time} - {end_time}] {entry['speaker']}: {entry['text'].strip()}"
        )
    full_asr_text_output = asr_outputs["text"] if asr_outputs else "No ASR transcript generated."

    return (
        "\n".join(formatted_diarized_text_output),
        full_asr_text_output,
        status_message
    )
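
# Programmatic usage sketch (not executed; the 3-second 440 Hz test tone below is only a
# stand-in for real speech, so the resulting transcript would be empty or meaningless):
#
#   sr = 16000
#   tone = 0.1 * np.sin(2 * np.pi * 440 * np.arange(sr * 3) / sr).astype(np.float32)
#   diarized, full_text, status = predict_audio((sr, tone), 24, 30, "Auto-detect", None, None, None)
#   print(status)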

# --- Gradio Interface Definition ---
# List of languages supported by OpenAI Whisper models
WHISPER_LANGUAGES = [
    "Auto-detect", "English", "Chinese", "German", "Spanish", "Russian", "Korean", "French", "Japanese", "Portuguese",
    "Turkish", "Polish", "Catalan", "Dutch", "Arabic", "Swedish", "Italian", "Indonesian", "Hindi", "Finnish",
    "Vietnamese", "Hebrew", "Ukrainian", "Greek", "Malay", "Czech", "Romanian", "Danish", "Hungarian", "Tamil",
    "Norwegian", "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian", "Latin", "Maori", "Malayalam", "Afrikaans",
    "Welsh", "Belarusian", "Gujarati", "Kannada", "Armenian", "Azerbaijani", "Serbian", "Slovenian", "Estonian",
    "Burmese", "Galician", "Mongolian", "Lao", "Kazakh", "Georgian", "Amharic", "Nepali", "Bosnian", "Luxembourgish",
    "Pashto", "Tagalog", "Malagasy", "Albanian", "Sindhi", "Kurdish", "Somali", "Telugu", "Tajik", "Swahili",
    "Kashmiri"
]

demo = gr.Interface(
    fn=predict_audio,
    inputs=[
        gr.Audio(type="numpy", label="Upload Audio File (WAV, MP3, FLAC, etc.)"),
        gr.Slider(minimum=1, maximum=32, value=24, step=1, label="ASR Batch Size"),
        gr.Slider(minimum=1, maximum=60, value=30, step=1, label="ASR Chunk Length (seconds)"),
        gr.Dropdown(WHISPER_LANGUAGES, value="Auto-detect", label="ASR Language"),
        gr.Number(label="Diarization: Number of Speakers (optional)", value=None, precision=0, info="Expected total number of speakers."),
        gr.Number(label="Diarization: Min Speakers (optional)", value=None, precision=0, info="Minimum number of speakers to detect."),
        gr.Number(label="Diarization: Max Speakers (optional)", value=None, precision=0, info="Maximum number of speakers to detect.")
    ],
    outputs=[
        gr.Textbox(label="Diarized Transcript", lines=10, interactive=False),
        gr.Textbox(label="Full ASR Transcript", lines=5, interactive=False),
        gr.Textbox(label="Status Message", lines=1, interactive=False)
    ],
    title="Whisper ASR with Pyannote Speaker Diarization",
    description=(
        "Upload an audio file to get a transcript with speaker diarization. "
        "This demo uses `openai/whisper-small` for ASR and `pyannote/speaker-diarization-3.1` for diarization. "
        "A Hugging Face token with access to `pyannote/speaker-diarization-3.1` is required. "
        "Please set it as an `HF_TOKEN` environment variable before launching (see script comments)."
        "<br><b>Note:</b> For long audio files or heavy concurrent usage, consider using a GPU and a larger model such as `whisper-large-v3`."
    ),
    allow_flagging="never",  # Disable Gradio's flagging feature
    # The example audio path assumes you are running from the cloned repository root.
    # If not, download a small WAV file (e.g., from Common Voice) and update this path.
    examples=[
        [os.path.join(os.path.dirname(__file__), "model-server", "app", "tests", "polyai-minds14-0.wav"), 24, 30, "Auto-detect", None, None, None]
    ]
)

if __name__ == "__main__":
    demo.launch()
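    # For a temporary public link, this could be launched with `demo.launch(share=True)` instead.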