File size: 3,863 Bytes
283bd52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import gc
import torch
import librosa
import numpy as np
import gradio as gr
from transformers import (AutoProcessor, AutoModelForCTC, 
                          AutoModelForTokenClassification, AutoTokenizer)
from speechbrain.inference.VAD import VAD

# πŸ”§ Check for CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# πŸ›  Load Voice Activity Detection (VAD) model
vad_model = VAD.from_hparams(source="speechbrain/vad-crdnn-libriparty", savedir="vad_model")

# πŸ” Function to clean up memory
def clean_up_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# πŸŽ™ Load Wav2Vec2 ASR model
asr_model_name = "facebook/wav2vec2-large-960h"
processor = AutoProcessor.from_pretrained(asr_model_name)
w2v2_model = AutoModelForCTC.from_pretrained(asr_model_name).to(device)
w2v2_model.eval()

# ✍ Load model for punctuation restoration
recap_model_name = "kredor/punctuate-all"
recap_tokenizer = AutoTokenizer.from_pretrained(recap_model_name)
recap_model = AutoModelForTokenClassification.from_pretrained(recap_model_name).to(device)
recap_model.eval()

# πŸ“Œ Function to add punctuation
def recap_sentence(string):
    tokens = recap_tokenizer(string, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        predictions = recap_model(**tokens).logits

    predicted_ids = torch.argmax(predictions, dim=-1)[0]
    words = string.split()
    punctuated_text = []

    for word, pred in zip(words, predicted_ids):
        punctuated_text.append(word + recap_tokenizer.convert_ids_to_tokens([pred.item()])[0])

    return " ".join(punctuated_text)

# 🎧 Function for chunk-based streaming transcription
def transcribe_audio_stream(audio_file, chunk_size=2.0):
    """Transcribe *audio_file* by decoding fixed-length windows with Wav2Vec2.

    The file is loaded at 16 kHz, cut into ``chunk_size``-second windows,
    each window is CTC-decoded independently, and the per-window texts are
    joined with single spaces.

    Args:
        audio_file: path to an audio file librosa can read.
        chunk_size: window length in seconds (default 2.0).

    Returns:
        The concatenated transcription string.
    """
    waveform, rate = librosa.load(audio_file, sr=16000)
    total_seconds = librosa.get_duration(y=waveform, sr=rate)
    pieces = []

    for begin in np.arange(0, total_seconds, chunk_size):
        stop = min(begin + chunk_size, total_seconds)
        segment = waveform[int(begin * rate):int(stop * rate)]

        features = processor(segment, return_tensors="pt", sampling_rate=16000)
        inputs = features.input_values.to(w2v2_model.device)

        with torch.no_grad():
            chunk_logits = w2v2_model(inputs).logits

        best_ids = torch.argmax(chunk_logits, dim=-1)
        pieces.append(processor.batch_decode(best_ids)[0])

    return " ".join(pieces)

# πŸŽ™ Handle both live audio & file uploads
def return_prediction_w2v2(file_or_mic):
    if not file_or_mic:
        return "", "empty.txt"

    # Transcribe file
    transcription = transcribe_audio_stream(file_or_mic)

    # Add punctuation
    recap_result = recap_sentence(transcription)

    # Save result to file
    download_path = "transcription.txt"
    with open(download_path, "w") as f:
        f.write(recap_result)

    clean_up_memory()
    return recap_result, download_path

# πŸ–₯ Gradio Interface
mic_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources="microphone", type="filepath"),
    outputs=[gr.Textbox(label="Real-Time Transcription"), gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=True
)

file_transcribe = gr.Interface(
    fn=return_prediction_w2v2,
    inputs=gr.Audio(sources="upload", type="filepath"),
    outputs=[gr.Textbox(label="File Transcription"), gr.File(label="Download Transcript")],
    allow_flagging="never",
    live=False
)

# πŸŽ› Combine into a Gradio app
with gr.Blocks() as transcriber_app:
    gr.Markdown("<h2>CCI Real-Time Sermon Transcription</h2>")
    gr.TabbedInterface([mic_transcribe, file_transcribe], 
                       ["Real-Time (Microphone)", "Upload Audio"])

# πŸš€ Run the Gradio app
if __name__ == "__main__":
    transcriber_app.launch()