import os
import gc
import torch
import librosa
import numpy as np
import gradio as gr
from transformers import (AutoProcessor, AutoModelForCTC,
                          AutoModelForTokenClassification, AutoTokenizer)
from speechbrain.inference.VAD import VAD
# Check for CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Voice Activity Detection (VAD) model
# NOTE: vad_model is never used below; the commented sketch that follows
# shows one way it could replace the fixed-size chunking used later
vad_model = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty", savedir="vad_model")
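# A minimal sketch (not wired into the app) of VAD-driven segmentation:
# SpeechBrain's VAD exposes get_speech_segments(), which returns a tensor of
# [start, end] speech boundaries in seconds for an audio file.
#
#   boundaries = vad_model.get_speech_segments("audio.wav")
#   for start, end in boundaries.tolist():
#       segment = audio[int(start * sr):int(end * sr)]  # slice a loaded 16 kHz waveform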
# Free Python and, if available, CUDA memory between requests
def clean_up_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Load the Wav2Vec2 ASR model
asr_model_name = "facebook/wav2vec2-large-960h"
processor = AutoProcessor.from_pretrained(asr_model_name)
w2v2_model = AutoModelForCTC.from_pretrained(asr_model_name).to(device)
w2v2_model.eval()

# Load the punctuation-restoration model
recap_model_name = "kredor/punctuate-all"
recap_tokenizer = AutoTokenizer.from_pretrained(recap_model_name)
recap_model = AutoModelForTokenClassification.from_pretrained(recap_model_name).to(device)
recap_model.eval()

# Restore punctuation: predict one punctuation label per word and append it
def recap_sentence(string):
    words = string.split()
    # Tokenize pre-split words so sub-tokens can be mapped back to whole words
    tokens = recap_tokenizer(words, is_split_into_words=True,
                             return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        predictions = recap_model(**tokens).logits
    predicted_ids = torch.argmax(predictions, dim=-1)[0]
    punctuated_text = []
    seen = set()
    for idx, word_id in enumerate(tokens.word_ids()):
        # Take the prediction of each word's first sub-token only
        if word_id is None or word_id in seen:
            continue
        seen.add(word_id)
        label = recap_model.config.id2label[predicted_ids[idx].item()]
        # Assumption: this model uses "0" as its "no punctuation" label
        punctuated_text.append(words[word_id] + ("" if label == "0" else label))
    return " ".join(punctuated_text)

# Chunk-based streaming transcription: decode fixed-size windows and join them
def transcribe_audio_stream(audio_file, chunk_size=2.0):
    audio, sr = librosa.load(audio_file, sr=16000)
    duration = librosa.get_duration(y=audio, sr=sr)
    transcriptions = []
    # NOTE: fixed-size chunks can cut words at the boundaries; a larger
    # chunk_size or the VAD-based segmentation sketched above reduces this
    for start in np.arange(0, duration, chunk_size):
        end = min(start + chunk_size, duration)
        chunk = audio[int(start * sr):int(end * sr)]
        input_values = processor(chunk, return_tensors="pt",
                                 sampling_rate=16000).input_values.to(w2v2_model.device)
        with torch.no_grad():
            logits = w2v2_model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
        transcriptions.append(transcription)
    return " ".join(transcriptions)

# Handle both live microphone audio and file uploads
def return_prediction_w2v2(file_or_mic):
    if not file_or_mic:
        # Return None for the file output; pointing gr.File at a
        # nonexistent "empty.txt" would raise an error
        return "", None
    # Transcribe the audio in chunks
    transcription = transcribe_audio_stream(file_or_mic)
    # Restore punctuation
    recap_result = recap_sentence(transcription)
    # Save the result so it can be downloaded
    download_path = "transcription.txt"
    with open(download_path, "w") as f:
        f.write(recap_result)
    clean_up_memory()
    return recap_result, download_path

# Gradio interface for live microphone input
mic_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs=[gr.Textbox(label="Real-Time Transcription"), gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=True  # re-runs the full transcription whenever the recording changes
)
file_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources=["upload"], type="filepath"),
    outputs=[gr.Textbox(label="File Transcription"), gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=False
)

# Combine both interfaces into a single tabbed Gradio app
with gr.Blocks() as transcriber_app:
    gr.Markdown("<h2>CCI Real-Time Sermon Transcription</h2>")
    gr.TabbedInterface([mic_transcribe, file_transcribe],
                       ["Real-Time (Microphone)", "Upload Audio"])

# Run the Gradio app
if __name__ == "__main__":
    transcriber_app.launch()