import os
import gc
import torch
import librosa
import numpy as np
import gradio as gr
from transformers import (AutoProcessor, AutoModelForCTC,
                          AutoModelForTokenClassification, AutoTokenizer)
from speechbrain.inference.VAD import VAD
# 🔧 Check for CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 🛠 Load Voice Activity Detection (VAD) model
vad_model = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty", savedir="vad_model")
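# NOTE: vad_model is loaded but not yet wired into the pipeline below;
# see the is_speech_chunk sketch after transcribe_audio_stream for one
# possible way to use it to skip silent chunks.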
# 🔁 Function to clean up memory
def clean_up_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# 🎙 Load Wav2Vec2 ASR model
asr_model_name = "facebook/wav2vec2-large-960h"
processor = AutoProcessor.from_pretrained(asr_model_name)
w2v2_model = AutoModelForCTC.from_pretrained(asr_model_name).to(device)
w2v2_model.eval()
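# NOTE: this checkpoint outputs uppercase English text without punctuation,
# which is why the punctuation-restoration pass below is needed.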
# ✍ Load model for punctuation restoration
recap_model_name = "kredor/punctuate-all"
recap_tokenizer = AutoTokenizer.from_pretrained(recap_model_name)
recap_model = AutoModelForTokenClassification.from_pretrained(recap_model_name).to(device)
recap_model.eval()
# 📌 Function to add punctuation
def recap_sentence(string):
    # Predict a punctuation label for every subword token
    tokens = recap_tokenizer(string, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        predictions = recap_model(**tokens).logits
    predicted_ids = torch.argmax(predictions, dim=-1)[0]
    # Label ids must go through the model config's id2label (not the tokenizer
    # vocabulary); in this model the "0" label means "no punctuation"
    id2label = recap_model.config.id2label
    # Align subword-level predictions back to whitespace-separated words
    word_ids = tokens.word_ids(0)
    words = string.split()
    word_to_label = {}
    for token_idx, word_idx in enumerate(word_ids):
        if word_idx is not None:
            word_to_label[word_idx] = id2label[predicted_ids[token_idx].item()]
    punctuated_text = []
    for i, word in enumerate(words):
        label = word_to_label.get(i, "0")
        punctuated_text.append(word + (label if label != "0" else ""))
    return " ".join(punctuated_text)
# 🎧 Function for chunk-based streaming transcription
def transcribe_audio_stream(audio_file, chunk_size=2.0):
    audio, sr = librosa.load(audio_file, sr=16000)
    duration = librosa.get_duration(y=audio, sr=sr)
    transcriptions = []
    # Hard cuts every chunk_size seconds can split words at chunk boundaries;
    # VAD-based cut points (see the note near vad_model above) would be cleaner
    for start in np.arange(0, duration, chunk_size):
        end = min(start + chunk_size, duration)
        chunk = audio[int(start * sr):int(end * sr)]
        input_values = processor(chunk, return_tensors="pt", sampling_rate=16000).input_values.to(w2v2_model.device)
        with torch.no_grad():
            logits = w2v2_model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        transcriptions.append(transcription)
    return " ".join(transcriptions)
# 🎙 Handle both live audio & file uploads
def return_prediction_w2v2(file_or_mic):
    if not file_or_mic:
        return "", None  # no audio yet: empty text, no file to offer
    # Transcribe file
    transcription = transcribe_audio_stream(file_or_mic)
    # Add punctuation
    recap_result = recap_sentence(transcription)
    # Save result to file
    download_path = "transcription.txt"
    with open(download_path, "w") as f:
        f.write(recap_result)
    clean_up_memory()
    return recap_result, download_path
# 🖥 Gradio Interface
mic_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs=[gr.Textbox(label="Real-Time Transcription"), gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=True
)
file_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources=["upload"], type="filepath"),
    outputs=[gr.Textbox(label="File Transcription"), gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=False
)
# 🎛 Combine into a Gradio app
with gr.Blocks() as transcriber_app:
    gr.Markdown("<h2>CCI Real-Time Sermon Transcription</h2>")
    gr.TabbedInterface([mic_transcribe, file_transcribe],
                       ["Real-Time (Microphone)", "Upload Audio"])
# 🚀 Run the Gradio app
if __name__ == "__main__":
    transcriber_app.launch()
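    # For a temporary public link (useful on a remote machine), Gradio also
    # supports: transcriber_app.launch(share=True)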