import gc
import os

import gradio as gr
import librosa
import numpy as np
import torch
from speechbrain.inference.VAD import VAD
from transformers import (AutoModelForCTC, AutoModelForTokenClassification,
                          AutoProcessor, AutoTokenizer)

# 🔧 Check for CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 🛠 Load Voice Activity Detection (VAD) model
# NOTE: the VAD model is loaded but not yet wired into the pipeline;
# see the sketch after transcribe_audio_stream() for one way to use it.
vad_model = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty",
                             savedir="vad_model")


# 🔍 Free Python-side and CUDA-side memory between requests
def clean_up_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# 🎙 Load Wav2Vec2 ASR model
asr_model_name = "facebook/wav2vec2-large-960h"
processor = AutoProcessor.from_pretrained(asr_model_name)
w2v2_model = AutoModelForCTC.from_pretrained(asr_model_name).to(device)
w2v2_model.eval()

# ✍ Load model for punctuation restoration
recap_model_name = "kredor/punctuate-all"
recap_tokenizer = AutoTokenizer.from_pretrained(recap_model_name)
recap_model = AutoModelForTokenClassification.from_pretrained(recap_model_name).to(device)
recap_model.eval()


# 📌 Add punctuation to a raw transcript. The token-classification model
# predicts one punctuation label per *subword*, so we tokenize the pre-split
# words, map each subword back to its word via word_ids(), and append the
# predicted mark (this model uses the label "0" for "no punctuation").
def recap_sentence(string):
    words = string.split()
    if not words:
        return string
    tokens = recap_tokenizer(words, is_split_into_words=True,
                             return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        predictions = recap_model(**tokens).logits
    predicted_ids = torch.argmax(predictions, dim=-1)[0].tolist()

    punctuated = list(words)
    for token_idx, word_idx in enumerate(tokens.word_ids(batch_index=0)):
        if word_idx is None:  # special tokens have no corresponding word
            continue
        label = recap_model.config.id2label[predicted_ids[token_idx]]
        if label != "0":
            punctuated[word_idx] = words[word_idx] + label
    return " ".join(punctuated)


# 🎧 Chunk-based streaming transcription: decode fixed-size windows so a long
# recording never has to fit on the GPU in one pass.
def transcribe_audio_stream(audio_file, chunk_size=2.0):
    audio, sr = librosa.load(audio_file, sr=16000)
    duration = librosa.get_duration(y=audio, sr=sr)

    transcriptions = []
    for start in np.arange(0, duration, chunk_size):
        end = min(start + chunk_size, duration)
        chunk = audio[int(start * sr):int(end * sr)]
        input_values = processor(chunk, return_tensors="pt",
                                 sampling_rate=16000).input_values.to(device)
        with torch.no_grad():
            logits = w2v2_model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcriptions.append(processor.batch_decode(predicted_ids)[0])
    return " ".join(transcriptions)
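
# ---------------------------------------------------------------------------
# Sketch (not called anywhere in the app): an alternative to fixed-size
# chunking that uses the VAD model loaded above to transcribe only detected
# speech regions. It assumes SpeechBrain's VAD.get_speech_segments(), which
# returns a tensor of [start, end] boundaries in seconds for 16 kHz mono audio.
# ---------------------------------------------------------------------------
def transcribe_speech_segments(audio_file):
    audio, sr = librosa.load(audio_file, sr=16000)
    boundaries = vad_model.get_speech_segments(audio_file)
    transcriptions = []
    for start, end in boundaries.tolist():
        chunk = audio[int(start * sr):int(end * sr)]
        if len(chunk) == 0:
            continue
        input_values = processor(chunk, return_tensors="pt",
                                 sampling_rate=16000).input_values.to(device)
        with torch.no_grad():
            logits = w2v2_model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcriptions.append(processor.batch_decode(predicted_ids)[0])
    return " ".join(transcriptions)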

# 🎙 Handle both live microphone audio and file uploads
def return_prediction_w2v2(file_or_mic):
    if not file_or_mic:
        return "", None  # nothing recorded or uploaded yet

    # Transcribe the audio, then restore punctuation
    transcription = transcribe_audio_stream(file_or_mic)
    recap_result = recap_sentence(transcription)

    # Save the result so the user can download it
    download_path = "transcription.txt"
    with open(download_path, "w") as f:
        f.write(recap_result)

    clean_up_memory()
    return recap_result, download_path


# 🖥 Gradio interfaces
mic_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs=[gr.Textbox(label="Real-Time Transcription"),
             gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=True,
)

file_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources=["upload"], type="filepath"),
    outputs=[gr.Textbox(label="File Transcription"),
             gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=False,
)
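
# ---------------------------------------------------------------------------
# Sketch (illustrative, not used by the app): batch-transcribe a folder of
# recordings through the same pipeline. The folder name and the .wav filter
# are assumptions made for the example.
# ---------------------------------------------------------------------------
def transcribe_folder(folder="sermons"):
    results = {}
    for name in sorted(os.listdir(folder)):
        if name.lower().endswith(".wav"):
            text, _ = return_prediction_w2v2(os.path.join(folder, name))
            results[name] = text
    return results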
") gr.TabbedInterface([mic_transcribe, file_transcribe], ["Real-Time (Microphone)", "Upload Audio"]) # 🚀 Run the Gradio app if __name__ == "__main__": transcriber_app.launch()