Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import torchaudio
|
4 |
+
import numpy as np
|
5 |
+
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
6 |
+
from datetime import timedelta
|
7 |
+
import os
|
8 |
+
import shutil
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
# Load Silero VAD
|
12 |
+
vad_model, utils = torch.hub.load(
|
13 |
+
repo_or_dir='snakers4/silero-vad', model='silero_vad', trust_repo=True
|
14 |
+
)
|
15 |
+
(get_speech_ts, _, _, _, _) = utils
|
16 |
+
|
17 |
+
# Load Wav2Vec2 model
|
18 |
+
model_name = "ganga4364/Garchen_Rinpoche-wav2vec2-Checkpoint-19000"
|
19 |
+
model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
20 |
+
processor = Wav2Vec2Processor.from_pretrained(model_name)
|
21 |
+
model.eval()
|
22 |
+
|
23 |
+
SAMPLE_RATE = 16000
|
24 |
+
|
25 |
+
def format_timestamp(seconds, format_type="srt"):
|
26 |
+
"""Convert seconds to SRT or WebVTT timestamp format"""
|
27 |
+
td = timedelta(seconds=seconds)
|
28 |
+
hours = td.seconds // 3600
|
29 |
+
minutes = (td.seconds % 3600) // 60
|
30 |
+
seconds = td.seconds % 60
|
31 |
+
milliseconds = round(td.microseconds / 1000)
|
32 |
+
|
33 |
+
if format_type == "srt":
|
34 |
+
return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
|
35 |
+
else: # webvtt
|
36 |
+
return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
|
37 |
+
|
38 |
+
def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
|
39 |
+
"""Create SRT or WebVTT subtitle file"""
|
40 |
+
with open(output_path, 'w', encoding='utf-8') as f:
|
41 |
+
if format_type == "vtt":
|
42 |
+
f.write("WEBVTT\n\n")
|
43 |
+
|
44 |
+
for i, (start_time, end_time, text) in enumerate(timestamps_with_text, 1):
|
45 |
+
if format_type == "srt":
|
46 |
+
f.write(f"{i}\n")
|
47 |
+
f.write(f"{format_timestamp(start_time/SAMPLE_RATE)} --> {format_timestamp(end_time/SAMPLE_RATE)}\n")
|
48 |
+
f.write(f"{text}\n\n")
|
49 |
+
else:
|
50 |
+
f.write(f"{format_timestamp(start_time/SAMPLE_RATE, 'vtt')} --> {format_timestamp(end_time/SAMPLE_RATE, 'vtt')}\n")
|
51 |
+
f.write(f"{text}\n\n")
|
52 |
+
|
53 |
+
def create_preview_html(audio_path, vtt_path):
|
54 |
+
"""Create an HTML preview file for audio with subtitles"""
|
55 |
+
static_dir = Path("static")
|
56 |
+
static_dir.mkdir(exist_ok=True)
|
57 |
+
|
58 |
+
# Copy files to static directory with friendly names
|
59 |
+
audio_filename = Path(audio_path).name
|
60 |
+
vtt_filename = Path(vtt_path).name
|
61 |
+
new_audio_path = static_dir / audio_filename
|
62 |
+
new_vtt_path = static_dir / vtt_filename
|
63 |
+
|
64 |
+
shutil.copy2(audio_path, new_audio_path)
|
65 |
+
shutil.copy2(vtt_path, new_vtt_path)
|
66 |
+
|
67 |
+
# Read HTML template
|
68 |
+
with open("templates/player.html", "r") as f:
|
69 |
+
html_content = f.read()
|
70 |
+
|
71 |
+
# Replace placeholders
|
72 |
+
html_content = html_content.replace("{{ audio_path }}", f"static/{audio_filename}")
|
73 |
+
html_content = html_content.replace("{{ vtt_path }}", f"static/{vtt_filename}")
|
74 |
+
|
75 |
+
# Save preview HTML
|
76 |
+
preview_path = static_dir / "preview.html"
|
77 |
+
with open(preview_path, "w") as f:
|
78 |
+
f.write(html_content)
|
79 |
+
|
80 |
+
return str(preview_path)
|
81 |
+
|
82 |
+
def transcribe_with_vad(audio_path):
|
83 |
+
# Load and resample audio to 16kHz mono
|
84 |
+
wav, sr = torchaudio.load(audio_path)
|
85 |
+
if sr != SAMPLE_RATE:
|
86 |
+
wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
|
87 |
+
wav = wav.mean(dim=0) # convert to mono
|
88 |
+
wav_np = wav.numpy()
|
89 |
+
|
90 |
+
# Get speech timestamps using Silero VAD
|
91 |
+
speech_timestamps = get_speech_ts(wav_np, vad_model, sampling_rate=SAMPLE_RATE)
|
92 |
+
if not speech_timestamps:
|
93 |
+
return "No speech detected.", None, None, None
|
94 |
+
|
95 |
+
timestamps_with_text = []
|
96 |
+
transcriptions = []
|
97 |
+
|
98 |
+
for ts in speech_timestamps:
|
99 |
+
start, end = ts['start'], ts['end']
|
100 |
+
segment = wav[start:end]
|
101 |
+
if segment.dim() > 1:
|
102 |
+
segment = segment.squeeze()
|
103 |
+
|
104 |
+
inputs = processor(segment, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
|
105 |
+
with torch.no_grad():
|
106 |
+
logits = model(**inputs).logits
|
107 |
+
predicted_ids = torch.argmax(logits, dim=-1)
|
108 |
+
transcription = processor.decode(predicted_ids[0])
|
109 |
+
transcriptions.append(transcription)
|
110 |
+
timestamps_with_text.append((start, end, transcription))
|
111 |
+
|
112 |
+
# Generate subtitle files
|
113 |
+
base_path = os.path.splitext(audio_path)[0]
|
114 |
+
srt_path = f"{base_path}.srt"
|
115 |
+
vtt_path = f"{base_path}.vtt"
|
116 |
+
|
117 |
+
create_subtitle_file(timestamps_with_text, srt_path, "srt")
|
118 |
+
create_subtitle_file(timestamps_with_text, vtt_path, "vtt")
|
119 |
+
|
120 |
+
# Create preview HTML
|
121 |
+
preview_html = create_preview_html(audio_path, vtt_path)
|
122 |
+
|
123 |
+
return " ".join(transcriptions), srt_path, vtt_path, preview_html
|
124 |
+
|
125 |
+
# Gradio Interface
|
126 |
+
demo = gr.Interface(
|
127 |
+
fn=transcribe_with_vad,
|
128 |
+
inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload or Record"),
|
129 |
+
outputs=[
|
130 |
+
gr.Textbox(label="Transcription"),
|
131 |
+
gr.File(label="SRT Subtitle File"),
|
132 |
+
gr.File(label="WebVTT Subtitle File"),
|
133 |
+
gr.HTML(label="Preview Player")
|
134 |
+
],
|
135 |
+
title="Smart Speech-to-Text with VAD and Subtitles",
|
136 |
+
description="Transcribe long audio using ganga4364/Garchen_Rinpoche-wav2vec2-Checkpoint-19000 and Silero VAD. Generates SRT and WebVTT subtitle files."
|
137 |
+
)
|
138 |
+
|
139 |
+
if __name__ == "__main__":
|
140 |
+
demo.launch(share=True)
|