# Page header captured together with the code -- Spaces status: Sleeping
import gc
import os
import tempfile

import gradio as gr
import librosa
import numpy as np
import torch
from speechbrain.inference.VAD import VAD
from transformers import (
    AutoModelForCTC,
    AutoModelForTokenClassification,
    AutoProcessor,
    AutoTokenizer,
)
# Select the compute device: CUDA when available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the SpeechBrain voice activity detection (VAD) model; `savedir` is the
# local directory handed to SpeechBrain for the fetched model files.
# NOTE(review): `vad_model` is not referenced anywhere else in the visible
# code -- confirm it is still needed before paying its load cost at startup.
vad_model = VAD.from_hparams(
    source="speechbrain/vad-crdnn-libriparty",
    savedir="vad_model",
)
# Helper that releases unused Python and GPU memory
def clean_up_memory() -> None:
    """Release memory that is no longer referenced.

    Runs a garbage-collection pass and, when CUDA is available, asks PyTorch
    to empty its CUDA memory cache.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# Load the Wav2Vec2 ASR model and its processor
asr_model_name = "facebook/wav2vec2-large-960h"
processor = AutoProcessor.from_pretrained(asr_model_name)
w2v2_model = AutoModelForCTC.from_pretrained(asr_model_name).to(device)
w2v2_model.eval()  # inference only

# Load the token-classification model used for punctuation restoration.
recap_model_name = "kredor/punctuate-all"
recap_tokenizer = AutoTokenizer.from_pretrained(recap_model_name)
recap_model = AutoModelForTokenClassification.from_pretrained(recap_model_name).to(device)
recap_model.eval()  # inference only
# Function that adds punctuation to a transcript
def recap_sentence(string):
    """Restore punctuation in an unpunctuated transcript.

    The text is split on whitespace and fed to the punctuation model in
    windows of words. For every word, the label predicted for its last
    sub-word token is looked up in the model's ``id2label`` table and, unless
    it is the "no punctuation" label, appended to the word.

    Args:
        string: Transcript text without punctuation.

    Returns:
        The same words, in the same order, joined by single spaces, with
        predicted punctuation marks attached. An empty or whitespace-only
        input yields ``""``. No word is ever dropped: a word the model could
        not score (because its window was truncated) is kept unpunctuated.
    """
    words = string.split()
    if not words:
        return ""

    # Label the model uses for "no punctuation after this word".
    # NOTE(review): "0" matches the label set published for
    # kredor/punctuate-all ("0", ".", ",", "?", "-", ":") -- confirm if the
    # model is swapped.
    no_punctuation = "0"
    id2label = recap_model.config.id2label

    # Window size in words. Chosen to keep typical windows below the
    # tokenizer's truncation limit; truncation=True remains as a safety net.
    # Words at window edges are scored with less context.
    words_per_window = 100

    punctuated = []
    for window_start in range(0, len(words), words_per_window):
        window = words[window_start:window_start + words_per_window]
        encoded = recap_tokenizer(
            window,
            is_split_into_words=True,
            return_tensors="pt",
            truncation=True,
        )
        # Maps each token position to the index of its source word
        # (None for special tokens). Read before moving tensors to the device.
        word_ids = encoded.word_ids(batch_index=0)

        with torch.no_grad():
            logits = recap_model(**encoded.to(device)).logits
        label_ids = logits.argmax(dim=-1)[0].tolist()

        # Later tokens overwrite earlier ones, so each word ends up with the
        # prediction made on its final sub-word token.
        label_for_word = {}
        for word_index, label_id in zip(word_ids, label_ids):
            if word_index is not None:
                label_for_word[word_index] = label_id

        for index, word in enumerate(window):
            label = id2label.get(label_for_word.get(index), no_punctuation)
            punctuated.append(word if label == no_punctuation else word + label)

    return " ".join(punctuated)
# Function for chunk-based transcription of an audio file
def transcribe_audio_stream(audio_file, chunk_size=2.0):
    """Transcribe an audio file chunk by chunk with the Wav2Vec2 CTC model.

    The file is loaded with librosa at 16 kHz, cut into consecutive windows
    of ``chunk_size`` seconds, and each window is decoded greedily (argmax
    over the CTC logits). Chunk boundaries are computed in whole samples so
    they cannot drift with floating-point error.

    Args:
        audio_file: Path (or anything ``librosa.load`` accepts) of the audio.
        chunk_size: Window length in seconds; must be positive.

    Returns:
        The non-empty chunk transcriptions joined by single spaces; ``""``
        when the audio is empty or nothing was recognised.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive number of seconds")

    audio, sr = librosa.load(audio_file, sr=16000)

    samples_per_chunk = max(1, int(round(chunk_size * sr)))
    # The wav2vec2 convolutional feature encoder needs about 25 ms of audio
    # (400 samples at 16 kHz) to emit a single frame; shorter inputs make the
    # convolutions fail. A leftover tail that short cannot hold a word, so it
    # is skipped instead of crashing the whole transcription.
    min_samples = int(0.025 * sr)

    transcriptions = []
    for start in range(0, len(audio), samples_per_chunk):
        chunk = audio[start:start + samples_per_chunk]
        if len(chunk) < min_samples:
            continue

        input_values = processor(
            chunk, return_tensors="pt", sampling_rate=sr
        ).input_values.to(w2v2_model.device)
        with torch.no_grad():
            logits = w2v2_model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)

        text = processor.batch_decode(predicted_ids)[0].strip()
        if text:  # silent chunks decode to "", which would add stray spaces
            transcriptions.append(text)

    return " ".join(transcriptions)
# Handle both live microphone audio and file uploads
def return_prediction_w2v2(file_or_mic):
    """Transcribe and punctuate one audio input for the Gradio interfaces.

    Args:
        file_or_mic: Path of the recorded or uploaded audio file, or a falsy
            value when no audio was provided.

    Returns:
        A ``(text, path)`` tuple: the punctuated transcript and the path of a
        UTF-8 text file containing it. With no input the result is
        ``("", None)`` so the file output is left empty rather than pointed
        at a file that does not exist.
    """
    if not file_or_mic:
        return "", None

    try:
        transcription = transcribe_audio_stream(file_or_mic)
        recap_result = recap_sentence(transcription)

        # A unique file per call keeps concurrent requests from overwriting
        # each other's transcript. delete=False because the file has to
        # outlive this function for the download component to read it.
        with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            prefix="transcription_",
            suffix=".txt",
            delete=False,
        ) as f:
            f.write(recap_result)
            download_path = f.name
    finally:
        # Release memory even when transcription or punctuation fails.
        clean_up_memory()

    return recap_result, download_path
# Gradio interfaces
mic_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources="microphone", type="filepath"),
    outputs=[
        gr.Textbox(label="Real-Time Transcription"),
        gr.File(label="Download Transcript"),
    ],
    allow_flagging="never",
    live=True,  # re-run automatically when the input changes
)

file_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources="upload", type="filepath"),
    outputs=[
        gr.Textbox(label="File Transcription"),
        gr.File(label="Download Transcript"),
    ],
    allow_flagging="never",
    live=False,
)

# Combine both interfaces into one tabbed Gradio app.
with gr.Blocks() as transcriber_app:
    gr.Markdown("<h2>CCI Real-Time Sermon Transcription</h2>")
    gr.TabbedInterface(
        [mic_transcribe, file_transcribe],
        ["Real-Time (Microphone)", "Upload Audio"],
    )

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    transcriber_app.launch()