import spaces
from pydub import AudioSegment
import os
import torchaudio
import torch
import re
from functools import lru_cache
from itertools import groupby
import whisper_timestamped as whisper_ts
from typing import Dict
from faster_whisper import WhisperModel

device = 0 if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32

MODEL_PATH_V2 = "langtech-veu/whisper-timestamped-cs"
MODEL_PATH_V2_FAST = "langtech-veu/faster-whisper-timestamped-cs"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Sample rate every waveform is converted to before transcription.
TARGET_SAMPLE_RATE = 16000

# Mono files written by split_stereo_channels() and consumed by generate().
CHANNEL0_TEMP_PATH = "temp_mono_speaker1.wav"  # stereo channel index 0
CHANNEL1_TEMP_PATH = "temp_mono_speaker2.wav"  # stereo channel index 1

print("[INFO] CUDA available:", torch.cuda.is_available())


def clean_text(input_text):
    """Replace punctuation/symbol characters with spaces, collapse whitespace, and lowercase."""
    remove_chars = {'.', ',', ';', ':', '¿', '?', '«', '»', '-', '¡', '!', '@', '*',
                    '{', '}', '[', ']', '=', '/', '\\', '&', '#', '…'}
    output_text = ''.join(' ' if char in remove_chars else char for char in input_text)
    return ' '.join(output_text.split()).lower()


def split_stereo_channels(audio_path):
    """Split a 2-channel .wav/.mp3 file into two mono WAV files.

    Channel index 0 is written to ``temp_mono_speaker1.wav`` and channel
    index 1 to ``temp_mono_speaker2.wav`` in the current working directory.

    Raises:
        ValueError: if the extension is not .wav/.mp3 or the audio does not
            have exactly 2 channels.
    """
    ext = os.path.splitext(audio_path)[1].lower()
    if ext == ".wav":
        audio = AudioSegment.from_wav(audio_path)
    elif ext == ".mp3":
        audio = AudioSegment.from_file(audio_path, format="mp3")
    else:
        raise ValueError(f"Unsupported file format: {audio_path}")

    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} does not have 2 channels.")

    # pydub returns the left channel at index 0 and the right at index 1 (the
    # previous comments said the opposite). File names are unchanged.
    channels[0].export(CHANNEL0_TEMP_PATH, format="wav")
    channels[1].export(CHANNEL1_TEMP_PATH, format="wav")


def format_audio(audio_path):
    """Load an audio file as a mono waveform at TARGET_SAMPLE_RATE.

    Returns:
        (waveform, 16000): the waveform is squeezed to 1-D for mono input.
    """
    input_audio, sample_rate = torchaudio.load(audio_path)
    # Down-mix any multi-channel input to mono (previously only exactly 2).
    if input_audio.shape[0] > 1:
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    # Resampling to the same rate is a no-op, so skip building the transform.
    if sample_rate != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE)
        input_audio = resampler(input_audio)
    return input_audio.squeeze(), TARGET_SAMPLE_RATE


_TOKEN_RE = re.compile(r'\b\w+\'?\w*\b[.,!?]?')
_SHORT_UNIT_REPEAT_RE = re.compile(r"(\w{1,3})(\1{2,})")


def post_process_transcription(transcription, max_repeats=2):
    """Strip stutter-like repeats from a transcription.

    Each word token (optionally followed by one of ``.,!?``) is kept, except:
    - any 1-3 character unit repeated 3+ times in a row inside a token is
      removed entirely (including its first occurrence), and
    - a token identical to the previous one is kept at most ``max_repeats``
      times in a row.

    Returns:
        The cleaned tokens joined by single spaces.
    """
    tokens = _TOKEN_RE.findall(transcription)
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in tokens:
        reduced_token = _SHORT_UNIT_REPEAT_RE.sub("", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    # Tokens emptied by the repeat regex leave runs of spaces; collapse them.
    return re.sub(r'\s+', ' ', cleaned_transcription).strip()


def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
    """Merge consecutive lines tagged ``[SPEAKER_NN]`` that share the same speaker.

    Only two-digit ``[SPEAKER_NN]`` tags are recognised; any text before the
    first tag is dropped. Returns one ``[SPEAKER_NN] text`` line per speaker turn.
    """
    segments = re.split(r'(\[SPEAKER_\d{2}\])', transcription_text)
    merged_transcription = ''
    current_speaker = None
    current_segment = []
    # re.split with a capturing group alternates [text, tag, text, tag, text, ...].
    for i in range(1, len(segments) - 1, 2):
        speaker_tag = segments[i]
        text = segments[i + 1].strip()
        speaker = re.search(r'\d{2}', speaker_tag).group()
        if speaker == current_speaker:
            current_segment.append(text)
        else:
            if current_speaker is not None:
                merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
            current_speaker = speaker
            current_segment = [text]
    if current_speaker is not None:
        merged_transcription += f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}\n'
    return merged_transcription.strip()


def cleanup_temp_files(*file_paths):
    """Delete every given path that is truthy and currently exists."""
    for path in file_paths:
        if path and os.path.exists(path):
            os.remove(path)


try:
    faster_model = WhisperModel(
        MODEL_PATH_V2_FAST,
        device="cuda" if torch.cuda.is_available() else "cpu",
        compute_type="float16" if torch.cuda.is_available() else "int8"
    )
except RuntimeError as e:
    print(f"[WARNING] Failed to load model on GPU: {e}")
    faster_model = WhisperModel(
        MODEL_PATH_V2_FAST,
        device="cpu",
        compute_type="int8"
    )


@lru_cache(maxsize=2)
def load_whisper_model(model_path: str):
    """Load a whisper_timestamped model on CUDA if available, else CPU.

    Memoised per ``model_path`` so repeated generate() calls do not reload the
    weights. NOTE(review): the cached model is shared between calls.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return whisper_ts.load_model(model_path, device=device)


def transcribe_audio(model, audio_path: str) -> Dict:
    """Transcribe audio with whisper_timestamped, never raising.

    ``audio_path`` is forwarded unchanged to ``whisper_ts.transcribe``;
    generate() passes a waveform tensor here rather than a path.

    Returns:
        On success: dict with 'audio_path', 'text', 'segments', 'words'
        (word/start/end/confidence), 'duration' and 'success': True.
        On any exception: dict with 'audio_path', 'error' and 'success': False.
    """
    try:
        result = whisper_ts.transcribe(
            model,
            audio_path,
            beam_size=5,
            best_of=5,
            temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
            vad=False,
            detect_disfluencies=True,
        )
        words = []
        for segment in result.get('segments', []):
            for word in segment.get('words', []):
                words.append({
                    'word': word.get('word', '').strip(),
                    'start': word.get('start', 0),
                    'end': word.get('end', 0),
                    'confidence': word.get('confidence', 0)
                })
        return {
            'audio_path': audio_path,
            'text': result['text'].strip(),
            'segments': result.get('segments', []),
            'words': words,
            'duration': result.get('duration', 0),
            'success': True
        }
    except Exception as e:
        return {
            'audio_path': audio_path,
            'error': str(e),
            'success': False
        }


# NOTE(review): DiarizationPipeline, whisperx, pipeline and MODEL_PATH_1 are not
# imported or defined in this file, so building these pipelines unconditionally
# raised NameError and the module could not be imported. They are only used by
# diarization()/asr(), which generate() does not call, so they are optional here.
try:
    diarization_pipeline = DiarizationPipeline.from_pretrained("./pyannote/config.yaml")
    align_model, metadata = whisperx.load_align_model(language_code="en", device=DEVICE)
    asr_pipe = pipeline(
        task="automatic-speech-recognition",
        model=MODEL_PATH_1,
        chunk_length_s=30,
        device=DEVICE,
        return_timestamps=True)
except NameError as e:
    print(f"[WARNING] Diarization/ASR pipelines unavailable: {e}")
    diarization_pipeline = None
    align_model = metadata = None
    asr_pipe = None


def diarization(audio_path):
    """Run the diarization pipeline and return the tuples from ``itertracks(yield_label=True)``.

    Raises:
        RuntimeError: if the diarization pipeline could not be initialised.
    """
    if diarization_pipeline is None:
        raise RuntimeError("Diarization pipeline is not available.")
    diarization_result = diarization_pipeline(audio_path)
    diarized_segments = list(diarization_result.itertracks(yield_label=True))
    print('diarized_segments', diarized_segments)
    return diarized_segments


def asr(audio_path):
    """Run the ASR pipeline and convert its chunks to timestamped segments.

    NOTE(review): hf_chunks_to_whisperx_segments and assign_timestamps are not
    defined in this file; this function fails unless they are provided elsewhere.

    Raises:
        RuntimeError: if the ASR pipeline could not be initialised.
    """
    if asr_pipe is None:
        raise RuntimeError("ASR pipeline is not available.")
    print(f"[DEBUG] Starting ASR on audio: {audio_path}")
    asr_result = asr_pipe(audio_path, return_timestamps=True)
    print(f"[DEBUG] Raw ASR result: {asr_result}")
    asr_segments = hf_chunks_to_whisperx_segments(asr_result['chunks'])
    asr_segments = assign_timestamps(asr_segments, audio_path)
    return asr_segments


def _fast_channel_segments(channel_path, speaker_label):
    """Transcribe one mono file with faster-whisper into (start, end, speaker, text) tuples."""
    waveform, _ = format_audio(channel_path)
    segments, _ = faster_model.transcribe(
        waveform.numpy().astype("float32"), beam_size=5, task="transcribe"
    )
    return [
        (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
        for seg in segments
        if seg.text
    ]


def _timestamped_channel_segments(model, channel_path, speaker_label):
    """Transcribe one mono file with whisper_timestamped into (start, end, speaker, text) tuples."""
    waveform, _ = format_audio(channel_path)
    result = transcribe_audio(model, waveform)
    if not result.get("success"):
        # transcribe_audio swallows exceptions; surface them instead of
        # silently returning an empty transcript for this speaker.
        print(f"[WARNING] Transcription failed for {speaker_label}: {result.get('error')}")
    return [
        (seg.get("start", 0.0), seg.get("end", 0.0), speaker_label,
         post_process_transcription(seg.get("text", "").strip()))
        for seg in result.get("segments", [])
        if seg.get("text")
    ]


def _format_transcript(segments, merge_consecutive=False):
    """Render (start, end, speaker, text) tuples as ``[speaker]: text`` lines.

    With ``merge_consecutive``, adjacent tuples from the same speaker are joined
    into one line (texts separated by a space).
    """
    if merge_consecutive:
        lines = [
            f"[{speaker}]: {' '.join(seg[3] for seg in group)}"
            for speaker, group in groupby(segments, key=lambda seg: seg[2])
        ]
    else:
        lines = [f"[{speaker}]: {text}" for _, _, speaker, text in segments]
    return "\n".join(lines).strip()


def generate(audio_path, use_v2_fast):
    """Transcribe a 2-channel recording into a speaker-labelled transcript.

    Each channel is transcribed separately and the segments are interleaved by
    start time (ties keep "Speaker 1" first).

    Args:
        audio_path: path to a 2-channel .wav or .mp3 file.
        use_v2_fast: if True, use the faster-whisper model and merge consecutive
            segments of the same speaker; otherwise use whisper_timestamped and
            emit one line per segment.

    Returns:
        Lines of the form ``[Speaker N]: text`` separated by newlines.
    """
    # Label mapping kept from the original: channel index 1 -> "Speaker 1",
    # channel index 0 -> "Speaker 2". NOTE(review): confirm this is intended.
    channels = ((CHANNEL1_TEMP_PATH, "Speaker 1"), (CHANNEL0_TEMP_PATH, "Speaker 2"))
    try:
        # Both branches need the per-channel files; the fast branch previously
        # skipped this and read missing or stale files.
        split_stereo_channels(audio_path)
        if use_v2_fast:
            segments = [seg for path, label in channels
                        for seg in _fast_channel_segments(path, label)]
        else:
            model = load_whisper_model(MODEL_PATH_V2)
            segments = [seg for path, label in channels
                        for seg in _timestamped_channel_segments(model, path, label)]
        merged_transcript = sorted(
            segments,
            key=lambda x: float(x[0]) if x[0] is not None else float("inf")
        )
        # The fast branch previously called post_merge_consecutive_segments_from_text,
        # which only understands "[SPEAKER_NN]" tags and therefore returned ""
        # for these "[Speaker N]:" lines; merge on the tuples instead.
        return _format_transcript(merged_transcript, merge_consecutive=use_v2_fast)
    finally:
        cleanup_temp_files(CHANNEL0_TEMP_PATH, CHANNEL1_TEMP_PATH)