import os
import re
import time

# NOTE(review): third-party imports are kept in their original relative order
# (torch before faster_whisper); given the cuDNN HACKs below, native library
# load order may matter — confirm before reordering.
from pydub import AudioSegment
import torchaudio
import torch
import whisper_timestamped as whisper_ts
from faster_whisper import WhisperModel

from settings import (
    DEBUG_MODE,
    MODEL_PATH_V2_FAST,
    MODEL_PATH_V2,
    LEFT_CHANNEL_TEMP_PATH,
    RIGHT_CHANNEL_TEMP_PATH,
    FAKE_AUDIO_PATH,
    RESAMPLING_FREQ,
)

# A word (optionally containing one apostrophe) plus at most one trailing
# punctuation mark.
_TOKEN_RE = re.compile(r'\b\w+\'?\w*\b[.,!?]?')
# A 1-3 word-character chunk followed by two or more copies of itself
# (e.g. "hahaha"); the whole run is deleted from a token.
_STUTTER_RE = re.compile(r"(\w{1,3})(\1{2,})")
_WHITESPACE_RE = re.compile(r'\s+')
# Capturing group so re.split keeps the tags in its output.
_SPEAKER_TAG_RE = re.compile(r'(\[SPEAKER_\d{2}\])')


def _debug(message):
    """Print *message* only when DEBUG_MODE is enabled."""
    if DEBUG_MODE:
        print(message)


def get_settings():
    """Return the ``(device, compute_type)`` pair used to load the model.

    CUDA is currently forced on (see HACK); when it is off the CPU is used
    with ``int8`` compute.
    """
    _debug("Entering get_settings function...")
    # HACK hardcoding this to try
    is_cuda_available = True  # torch.cuda.is_available()
    if is_cuda_available:
        device = "cuda"
        compute_type = "float16"
    else:
        device = "cpu"
        compute_type = "int8"
    _debug(f"is_cuda_available: {is_cuda_available}")
    _debug(f"device: {device}")
    _debug(f"compute_type: {compute_type}")
    _debug("Exited get_settings function.")
    return device, compute_type


def load_model(use_v2_fast, device, compute_type):
    """Load the transcription model and the workaround ("fake") model.

    Args:
        use_v2_fast: if true, load a faster_whisper ``WhisperModel`` from
            MODEL_PATH_V2_FAST; otherwise load a whisper_timestamped model
            from MODEL_PATH_V2.
        device: device string passed to the loader ("cuda" / "cpu").
        compute_type: compute type for faster_whisper; not yet passed to the
            whisper_timestamped loader (see TODO).

    Returns:
        ``(model, fake_model)``. ``fake_model`` is always a whisper_timestamped
        model; ``transcribe_channels`` runs it once as a workaround.
    """
    _debug("Entering load_model function...")
    _debug(f"use_v2_fast: {use_v2_fast}")
    if use_v2_fast:
        _debug(f"Loading {MODEL_PATH_V2_FAST} using {device} with {compute_type}...")
        model = WhisperModel(
            MODEL_PATH_V2_FAST,
            device=device,
            compute_type=compute_type,
        )
        # HACK we need to do this for strange reasons.
        # If we don't do this, we get:
        # Could not load library libcudnn_ops_infer.so.8.
        # Error: libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory
        fake_model = whisper_ts.load_model(MODEL_PATH_V2, device=device)
    else:
        _debug(f"Loading {MODEL_PATH_V2} using {device} with {compute_type}...")
        # TODO add compute_type to load model
        model = whisper_ts.load_model(
            MODEL_PATH_V2,
            device=device,
        )
        # The workaround model would be the same checkpoint on the same device
        # loaded by the same loader, so reuse it instead of loading a copy.
        fake_model = model
    _debug("Exiting load_model function...")
    return model, fake_model


def split_input_stereo_channels(audio_path):
    """Split a stereo ``.wav``/``.mp3`` file into two mono WAV temp files.

    Channel 0 is written to RIGHT_CHANNEL_TEMP_PATH and channel 1 to
    LEFT_CHANNEL_TEMP_PATH.

    Raises:
        ValueError: if the extension is neither ``.wav`` nor ``.mp3``, or the
            audio does not have exactly two channels.
    """
    _debug("Entering split_input_stereo_channels function...")
    ext = os.path.splitext(audio_path)[1].lower()
    if ext == ".wav":
        audio = AudioSegment.from_wav(audio_path)
    elif ext == ".mp3":
        audio = AudioSegment.from_file(audio_path, format="mp3")
    else:
        raise ValueError(f"Unsupported file format for: {audio_path}")
    channels = audio.split_to_mono()
    if len(channels) != 2:
        raise ValueError(f"Audio {audio_path} has {len(channels)} channels (instead of 2).")
    # NOTE(review): in interleaved WAV data channel 0 is conventionally the
    # left channel; confirm this right/left mapping is intentional.
    channels[0].export(RIGHT_CHANNEL_TEMP_PATH, format="wav")  # Right
    channels[1].export(LEFT_CHANNEL_TEMP_PATH, format="wav")  # Left
    _debug("Exited split_input_stereo_channels function.")


def format_audio(audio_path):
    """Load *audio_path* as a mono waveform resampled to RESAMPLING_FREQ.

    Returns:
        ``(waveform, RESAMPLING_FREQ)`` where ``waveform`` has been squeezed
        (1-D for mono input).
    """
    _debug("Entering format_audio function...")
    # torchaudio.load returns (channels, frames) by default.
    input_audio, sample_rate = torchaudio.load(audio_path)
    if input_audio.shape[0] > 1:
        # Downmix any multi-channel input so squeeze() below yields 1-D.
        input_audio = torch.mean(input_audio, dim=0, keepdim=True)
    resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=RESAMPLING_FREQ)
    input_audio = resampler(input_audio).squeeze()
    _debug("Exited format_audio function.")
    return input_audio, RESAMPLING_FREQ


def process_waveforms():
    """Load both channel temp files and return them as float16 NumPy arrays.

    Returns:
        ``(left_waveform, right_waveform)``.
    """
    _debug("Entering process_waveforms function...")
    left_waveform, _ = format_audio(LEFT_CHANNEL_TEMP_PATH)
    right_waveform, _ = format_audio(RIGHT_CHANNEL_TEMP_PATH)
    # TODO should this be equal to compute_type?
    left_waveform = left_waveform.numpy().astype("float16")
    right_waveform = right_waveform.numpy().astype("float16")
    _debug("Exited process_waveforms function.")
    return left_waveform, right_waveform


def transcribe_audio_no_fast_model(model, audio_path):
    """Transcribe with whisper_timestamped and collect word-level details.

    Args:
        model: a whisper_timestamped model.
        audio_path: whatever ``whisper_ts.transcribe`` accepts as audio; the
            caller in this module passes a waveform array.

    Returns:
        Dict with ``audio_path``, stripped ``text``, raw ``segments``, a flat
        ``words`` list (word/start/end/confidence), ``duration`` and
        ``success=True``.
    """
    _debug("Entering transcribe_audio_no_fast_model function...")
    result = whisper_ts.transcribe(
        model,
        audio_path,
        beam_size=5,
        best_of=5,
        temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
        vad=False,
        detect_disfluencies=True,
    )
    words = [
        {
            'word': word.get('word', '').strip(),
            'start': word.get('start', 0),
            'end': word.get('end', 0),
            'confidence': word.get('confidence', 0),
        }
        for segment in result.get('segments', [])
        for word in segment.get('words', [])
    ]
    # Previously this message came after `return` and was never printed.
    _debug("Exited transcribe_audio_no_fast_model function.")
    return {
        'audio_path': audio_path,
        'text': result['text'].strip(),
        'segments': result.get('segments', []),
        'words': words,
        'duration': result.get('duration', 0),
        'success': True,
    }


def transcribe_channels(left_waveform, right_waveform, model, use_v2_fast, fake_model):
    """Transcribe both channel waveforms with *model*.

    Returns:
        ``(left_result, right_result)``: lists of faster_whisper segments when
        ``use_v2_fast``, otherwise dicts from
        ``transcribe_audio_no_fast_model``.
    """
    _debug("Entering transcribe_channels function...")
    # HACK we need to do this for strange reasons.
    # If we don't do this, we get:
    # Could not load library libcudnn_ops_infer.so.8.
    # Error: libcudnn_ops_infer.so.8: cannot open shared object file: No such file or directory
    whisper_ts.transcribe(
        fake_model,
        FAKE_AUDIO_PATH,
        beam_size=1,
    )
    _debug("Preparing to transcribe...")
    if use_v2_fast:
        left_segments, _ = model.transcribe(left_waveform, beam_size=5, task="transcribe")
        right_segments, _ = model.transcribe(right_waveform, beam_size=5, task="transcribe")
        # Materialize the lazily produced segments so transcription happens here.
        left_result = list(left_segments)
        right_result = list(right_segments)
    else:
        left_result = transcribe_audio_no_fast_model(model, left_waveform)
        right_result = transcribe_audio_no_fast_model(model, right_waveform)
    _debug("Exited transcribe_channels function.")
    return left_result, right_result


# TODO refactor and rename this function
def post_process_transcription(transcription, max_repeats=2):
    """Clean a raw transcription string.

    Tokenizes with ``_TOKEN_RE`` (characters outside tokens are dropped) and
    deletes stuttered runs matched by ``_STUTTER_RE`` from each token. It keeps
    at most ``max_repeats`` consecutive copies of the same cleaned token and
    returns the tokens joined by single spaces.

    >>> post_process_transcription("yes yes yes no")
    'yes yes no'
    """
    cleaned_tokens = []
    repetition_count = 0
    previous_token = None
    for token in _TOKEN_RE.findall(transcription):
        reduced_token = _STUTTER_RE.sub("", token)
        if reduced_token == previous_token:
            repetition_count += 1
            if repetition_count <= max_repeats:
                cleaned_tokens.append(reduced_token)
        else:
            repetition_count = 1
            cleaned_tokens.append(reduced_token)
        previous_token = reduced_token
    cleaned_transcription = " ".join(cleaned_tokens)
    # Tokens reduced to "" would leave double spaces; collapse them.
    return _WHITESPACE_RE.sub(' ', cleaned_transcription).strip()


# TODO not used right now, decide to use it or not
def post_merge_consecutive_segments_from_text(transcription_text: str) -> str:
    """Merge consecutive ``[SPEAKER_NN]`` segments spoken by the same speaker.

    Text before the first tag is discarded. Each output line has the form
    ``[SPEAKER_NN] <space-joined texts>``.
    """
    # re.split with a capturing group yields [prefix, tag, text, tag, text, ...].
    parts = _SPEAKER_TAG_RE.split(transcription_text)
    lines = []
    current_speaker = None
    current_segment = []
    for i in range(1, len(parts) - 1, 2):
        speaker = re.search(r'\d{2}', parts[i]).group()
        text = parts[i + 1].strip()
        if speaker == current_speaker:
            current_segment.append(text)
            continue
        if current_speaker is not None:
            lines.append(f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}')
        current_speaker = speaker
        current_segment = [text]
    if current_speaker is not None:
        lines.append(f'[SPEAKER_{current_speaker}] {" ".join(current_segment)}')
    return "\n".join(lines).strip()


def get_segments(result, speaker_label, use_v2_fast):
    """Convert one channel's transcription result into labelled segments.

    Args:
        result: when ``use_v2_fast``, an iterable of faster_whisper segments
            (objects with ``start``/``end``/``text``); otherwise a dict whose
            ``"segments"`` list holds dicts with ``start``/``end``/``text``.
        speaker_label: label attached to every segment.
        use_v2_fast: selects which result shape to read.

    Returns:
        List of ``(start, end, speaker_label, cleaned_text)`` tuples; segments
        with empty text are skipped.
    """
    _debug("Entering get_segments function...")
    if use_v2_fast:
        final_segments = [
            (seg.start, seg.end, speaker_label, post_process_transcription(seg.text.strip()))
            for seg in result
            if seg.text
        ]
    else:
        final_segments = [
            (
                seg.get("start", 0.0),
                seg.get("end", 0.0),
                speaker_label,
                post_process_transcription(seg.get("text", "").strip()),
            )
            for seg in result.get("segments", [])
            if seg.get("text")
        ]
    _debug("Exited get_segments function.")
    return final_segments


def post_process_transcripts(left_result, right_result, use_v2_fast):
    """Merge both channels into a single time-ordered transcript string.

    Left-channel segments are labelled "Speaker 1" and right-channel ones
    "Speaker 2". Segments are sorted by start time, with a ``None`` start
    sorting last. Each line is ``[<speaker>]: <text>``.
    """
    _debug("Entering post_process_transcripts function...")
    left_segs = get_segments(left_result, "Speaker 1", use_v2_fast)
    right_segs = get_segments(right_result, "Speaker 2", use_v2_fast)
    merged_transcript = sorted(
        left_segs + right_segs,
        key=lambda seg: float(seg[0]) if seg[0] is not None else float("inf"),
    )
    clean_output = "\n".join(
        f"[{speaker}]: {text}" for _start, _end, speaker, text in merged_transcript
    ).strip()
    _debug("Exited post_process_transcripts function.")
    return clean_output


def cleanup_temp_files(*file_paths):
    """Delete each given path that is non-empty and currently exists."""
    _debug("Entered cleanup_temp_files function...")
    _debug(f"File paths to remove: {file_paths}")
    for path in file_paths:
        if path and os.path.exists(path):
            _debug(f"Removing path: {path}")
            os.remove(path)
    _debug("Exited cleanup_temp_files function.")


def generate(audio_path, use_v2_fast):
    """Transcribe a two-channel recording into a speaker-labelled transcript.

    Args:
        audio_path: path to a stereo ``.wav`` or ``.mp3`` file.
        use_v2_fast: use the faster_whisper model instead of
            whisper_timestamped.

    Returns:
        The merged transcript string from ``post_process_transcripts``.

    Raises:
        ValueError: propagated from ``split_input_stereo_channels`` for an
            unsupported format or a non-stereo file.
    """
    _debug("Entering generate function...")
    start = time.perf_counter()
    device, compute_type = get_settings()
    model, fake_model = load_model(use_v2_fast, device, compute_type)
    try:
        split_input_stereo_channels(audio_path)
        left_waveform, right_waveform = process_waveforms()
        left_result, right_result = transcribe_channels(
            left_waveform, right_waveform, model, use_v2_fast, fake_model
        )
        output = post_process_transcripts(left_result, right_result, use_v2_fast)
    finally:
        # Remove the per-channel temp files even when a step above fails.
        cleanup_temp_files(LEFT_CHANNEL_TEMP_PATH, RIGHT_CHANNEL_TEMP_PATH)
    elapsed_secs = time.perf_counter() - start
    _debug(f"elapsed_secs: {elapsed_secs}")
    _debug("Exited generate function.")
    return output