import whisperx
import torch
import numpy as np
from scipy.signal import resample
from pyannote.audio import Pipeline
import os
from dotenv import load_dotenv

load_dotenv()

import logging
import time
from difflib import SequenceMatcher
import spaces

hf_token = os.getenv("HF_TOKEN")

CHUNK_LENGTH = 5  # seconds per chunk
OVERLAP = 2       # seconds of overlap between consecutive chunks

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for models
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "float32"
whisper_model = None
diarization_pipeline = None
def load_models(model_size="small"):
    global whisper_model, diarization_pipeline

    # Load Whisper model
    whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)

    # Try to initialize the diarization pipeline.
    # Note: pyannote/speaker-diarization is a gated model, so HF_TOKEN must belong
    # to an account that has accepted its usage conditions on the Hub.
    try:
        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization", use_auth_token=hf_token
        )
        diarization_pipeline = diarization_pipeline.to(torch.device(device))
    except Exception as e:
        logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
        diarization_pipeline = None
def preprocess_audio(audio, chunk_size=CHUNK_LENGTH * 16000, overlap=OVERLAP * 16000):
    # Split the 16 kHz waveform into fixed-size, overlapping chunks;
    # the final chunk is zero-padded to the full chunk size.
    chunks = []
    for i in range(0, len(audio), chunk_size - overlap):
        chunk = audio[i:i + chunk_size]
        if len(chunk) < chunk_size:
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks
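# Worked example (assuming 16 kHz input): with CHUNK_LENGTH = 5 and OVERLAP = 2 the
# hop between chunks is 3 s, so a 12-second clip yields chunks starting at
# 0 s, 3 s, 6 s and 9 s, with the last chunk zero-padded to the full 5 s.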
def process_audio(audio_file, translate=False, model_size="small"):
    global whisper_model, diarization_pipeline

    if whisper_model is None:
        load_models(model_size)

    start_time = time.time()
    try:
        audio = whisperx.load_audio(audio_file)

        # Perform diarization over the full recording if the pipeline is available
        diarization_result = None
        if diarization_pipeline is not None:
            try:
                diarization_result = diarization_pipeline(
                    {"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000}
                )
            except Exception as e:
                logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")

        chunks = preprocess_audio(audio)

        language_segments = []
        final_segments = []
        overlap_duration = OVERLAP  # seconds, matching the preprocessing overlap

        for i, chunk in enumerate(chunks):
            chunk_proc_start = time.time()
            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
            chunk_end_time = chunk_start_time + CHUNK_LENGTH
            logger.info(f"Processing chunk {i + 1}/{len(chunks)}")

            lang = whisper_model.detect_language(chunk)
            result_transcribe = whisper_model.transcribe(chunk, language=lang)
            if translate:
                result_translate = whisper_model.transcribe(chunk, task="translate")

            for j, t_seg in enumerate(result_transcribe["segments"]):
                segment_start = chunk_start_time + t_seg["start"]
                segment_end = chunk_start_time + t_seg["end"]

                # Skip segments that fall entirely inside the overlap with the previous chunk
                if i > 0 and segment_end <= chunk_start_time + overlap_duration:
                    logger.info(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue
                # Skip segments that fall entirely inside the overlap with the next chunk
                if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
                    logger.info(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue

                # Assign the speaker whose diarization turns overlap this segment most often
                speaker = "Unknown"
                if diarization_result is not None:
                    speakers = []
                    for turn, track, spk in diarization_result.itertracks(yield_label=True):
                        if turn.start <= segment_end and turn.end >= segment_start:
                            speakers.append(spk)
                    speaker = max(set(speakers), key=speakers.count) if speakers else "Unknown"

                segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": speaker,
                    "text": t_seg["text"],
                }
                # Guard the index: translated segments may not align one-to-one
                if translate and j < len(result_translate["segments"]):
                    segment["translated"] = result_translate["segments"][j]["text"]

                final_segments.append(segment)

            language_segments.append({
                "language": lang,
                "start": chunk_start_time,
                "end": chunk_start_time + CHUNK_LENGTH,
            })
            logger.info(f"Chunk {i + 1} processed in {time.time() - chunk_proc_start:.2f} seconds")

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise
# The merge_nearby_segments and print_results functions remain unchanged
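# The two functions referenced above are not included in this listing. The sketch below is
# a minimal, hypothetical stand-in for merge_nearby_segments, assuming its job is to join
# consecutive segments that share a speaker and language and are separated by at most a
# small gap; the real implementation (and print_results) may differ.
def merge_nearby_segments(segments, max_gap=0.5):
    merged = []
    for seg in segments:
        prev = merged[-1] if merged else None
        if (prev is not None
                and seg["speaker"] == prev["speaker"]
                and seg["language"] == prev["language"]
                and seg["start"] - prev["end"] <= max_gap):
            # Extend the previous segment instead of starting a new one
            prev["end"] = seg["end"]
            prev["text"] = (prev["text"] + " " + seg["text"]).strip()
            if "translated" in prev and "translated" in seg:
                prev["translated"] = (prev["translated"] + " " + seg["translated"]).strip()
        else:
            merged.append(dict(seg))
    return merged

# Example invocation (assumes a local audio file path; illustrative only):
# language_segments, segments = process_audio("sample.wav", translate=False, model_size="small")
# for seg in segments:
#     print(f'{seg["start"]:.2f}-{seg["end"]:.2f} [{seg["speaker"]}] ({seg["language"]}): {seg["text"]}')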