import gc
import torch
import librosa
import numpy as np
import gradio as gr
from transformers import (AutoProcessor, AutoModelForCTC,
                          AutoModelForTokenClassification, AutoTokenizer)
from speechbrain.inference.VAD import VAD
# Check for CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load Voice Activity Detection (VAD) model
vad_model = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty", savedir="vad_model")
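# NOTE: the VAD model is loaded but not used anywhere below; it could be
# wired in to skip silent chunks before running ASR.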
# Free Python and GPU memory between requests
def clean_up_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# Load Wav2Vec2 ASR model
asr_model_name = "facebook/wav2vec2-large-960h"
processor = AutoProcessor.from_pretrained(asr_model_name)
w2v2_model = AutoModelForCTC.from_pretrained(asr_model_name).to(device)
w2v2_model.eval()
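# wav2vec2-large-960h emits uppercase text with no punctuation, hence the
# punctuation-restoration model loaded next.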
# Load token-classification model for punctuation restoration
recap_model_name = "kredor/punctuate-all"
recap_tokenizer = AutoTokenizer.from_pretrained(recap_model_name)
recap_model = AutoModelForTokenClassification.from_pretrained(recap_model_name).to(device)
recap_model.eval()
# Add punctuation to a raw (unpunctuated) transcript
def recap_sentence(string):
    tokens = recap_tokenizer(string, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        predictions = recap_model(**tokens).logits
    predicted_ids = torch.argmax(predictions, dim=-1)[0].tolist()
    # The logits index into the classifier's label set, not the tokenizer
    # vocabulary, so map them through id2label rather than convert_ids_to_tokens,
    # and align each whitespace-delimited word with its first sub-token.
    word_ids = tokens.word_ids(0)
    words = string.split()
    labels = [None] * len(words)
    for token_idx, word_idx in enumerate(word_ids):
        if word_idx is not None and word_idx < len(words) and labels[word_idx] is None:
            labels[word_idx] = recap_model.config.id2label[predicted_ids[token_idx]]
    # "0" is this model's label for "no punctuation after this word"
    punctuated_text = [w + (l if l and l != "0" else "") for w, l in zip(words, labels)]
    return " ".join(punctuated_text)
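# Example (actual output depends on the model's predictions):
#   recap_sentence("how are you doing today") -> "how are you doing today?"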
# Chunk-based streaming transcription: decode fixed-size windows in sequence
def transcribe_audio_stream(audio_file, chunk_size=2.0):
    audio, sr = librosa.load(audio_file, sr=16000)
    duration = librosa.get_duration(y=audio, sr=sr)
    transcriptions = []
    for start in np.arange(0, duration, chunk_size):
        end = min(start + chunk_size, duration)
        chunk = audio[int(start * sr):int(end * sr)]
        input_values = processor(chunk, return_tensors="pt", sampling_rate=16000).input_values.to(w2v2_model.device)
        with torch.no_grad():
            logits = w2v2_model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        transcriptions.append(transcription)
    return " ".join(transcriptions)
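# NOTE: hard 2-second boundaries can split words mid-utterance; overlapping
# windows or VAD-based segmentation (see vad_model above) would reduce this.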
# Handle both live microphone audio and file uploads
def return_prediction_w2v2(file_or_mic):
    if not file_or_mic:
        # No audio yet: return no file instead of a path that does not exist
        return "", None
    # Transcribe the recording or uploaded file
    transcription = transcribe_audio_stream(file_or_mic)
    # Add punctuation
    recap_result = recap_sentence(transcription)
    # Save the result to a file for download
    download_path = "transcription.txt"
    with open(download_path, "w") as f:
        f.write(recap_result)
    clean_up_memory()
    return recap_result, download_path
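# NOTE: transcription.txt is a single shared path, so concurrent requests
# would overwrite each other's output; a per-request temp file would be safer.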
# Gradio interface for live microphone input
mic_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs=[gr.Textbox(label="Real-Time Transcription"), gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=True
)
# Gradio interface for uploaded audio files
file_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources=["upload"], type="filepath"),
    outputs=[gr.Textbox(label="File Transcription"), gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=False
)
# Combine both interfaces into a tabbed Gradio app
with gr.Blocks() as transcriber_app:
    gr.Markdown("<h2>CCI Real-Time Sermon Transcription</h2>")
    gr.TabbedInterface([mic_transcribe, file_transcribe],
                       ["Real-Time (Microphone)", "Upload Audio"])
# Run the Gradio app
if __name__ == "__main__":
    transcriber_app.launch()
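# Passing share=True to launch() would also expose a temporary public URL.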