#!/usr/bin/env python3
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datetime import timedelta
import os
import logging

# Constants and Configuration
SAMPLE_RATE = 16000  # Silero VAD and the Wav2Vec2 model both expect 16 kHz mono audio
MODEL_NAME = "openpecha/general_stt_base_model"
title = "# Tibetan Speech-to-Text with Subtitles"
description = """
This application transcribes Tibetan audio files and generates subtitles using:
- a Wav2Vec2 model fine-tuned on Garchen Rinpoche's teachings
- Silero VAD for voice activity detection
- SRT and WebVTT subtitle output
"""
css = """
.result {display:flex;flex-direction:column}
.result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
.result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
.result_item_error {background-color:#ff7070;color:white;align-self:start}
.player-container {margin: 20px 0;}
.player-container audio {width: 100%;}
"""

# Initialize models
def init_models():
    # Load Silero VAD
    vad_model, utils = torch.hub.load(
        repo_or_dir='snakers4/silero-vad', model='silero_vad', trust_repo=True
    )
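    # torch.hub caches the repo after the first download (by default under
    # ~/.cache/torch/hub), so later launches can skip the network fetch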
    get_speech_ts = utils[0]  # first element of the utils tuple is get_speech_timestamps
    # Load Wav2Vec2 model and processor
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
    model.eval()  # inference mode (disables dropout)
    return vad_model, get_speech_ts, model, processor

# Initialize models globally
vad_model, get_speech_ts, model, processor = init_models()
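# (loaded once at import time, so every Gradio request reuses the same instances)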

def format_timestamp(seconds, format_type="srt"):
    """Convert seconds to an SRT or WebVTT timestamp string."""
    td = timedelta(seconds=seconds)
    total_seconds = int(td.total_seconds())
    hours = total_seconds // 3600
    minutes = (total_seconds % 3600) // 60
    secs = total_seconds % 60
    # Truncate instead of rounding so the value can never overflow to 1000 ms
    milliseconds = td.microseconds // 1000
    if format_type == "srt":
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
    else:  # WebVTT uses '.' rather than ',' before the milliseconds
        return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"

def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
    """Write an SRT or WebVTT subtitle file; start/end times are in samples."""
    with open(output_path, 'w', encoding='utf-8') as f:
        if format_type == "vtt":
            f.write("WEBVTT\n\n")
        for i, (start_time, end_time, text) in enumerate(timestamps_with_text, 1):
            if format_type == "srt":
                # SRT cues are numbered; cue numbers are optional in WebVTT, so we skip them there
                f.write(f"{i}\n")
            f.write(f"{format_timestamp(start_time / SAMPLE_RATE, format_type)} --> {format_timestamp(end_time / SAMPLE_RATE, format_type)}\n")
            f.write(f"{text}\n\n")

def build_html_output(s: str, style: str = "result_item_success"):
    return f"""
    <div class='result'>
        <div class='result_item {style}'>
            {s}
        </div>
    </div>
    """

def create_preview_player(audio_path, vtt_path):
    """Build an HTML audio player with the WebVTT subtitle track attached."""
    # Convert local file paths to URLs that Gradio can serve
    audio_url = f"file={audio_path}"
    vtt_url = f"file={vtt_path}"
    html_content = f"""
    <div class="audio-player">
        <audio controls style="width: 100%;">
            <source src="{audio_url}" type="audio/wav">
            <track kind="subtitles" src="{vtt_url}" srclang="bo" label="Tibetan" default>
            Your browser does not support the audio element.
        </audio>
    </div>
    """
    return html_content
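
# Note: recent Gradio versions only serve `file=` paths that are explicitly allowed;
# if the preview loads without audio, pass the relevant directory via
# demo.launch(allowed_paths=[...]).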

def process_audio(audio_path: str):
    if audio_path is None or audio_path == "":
        return (
            build_html_output(
                "Please upload an audio file first",
                "result_item_error",
            ),
            None,
            None,
            "",
            "",
        )
logging.info(f"Processing audio file: {audio_path}")
try:
# Load and resample audio to 16kHz mono
wav, sr = torchaudio.load(audio_path)
if sr != SAMPLE_RATE:
wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
wav = wav.mean(dim=0) # convert to mono
wav_np = wav.numpy()
        # Get speech timestamps using Silero VAD
        speech_timestamps = get_speech_ts(wav_np, vad_model, sampling_rate=SAMPLE_RATE)
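        # speech_timestamps is a list of dicts like {'start': s0, 'end': s1},
        # where s0/s1 are offsets in samples (not seconds) at 16 kHz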
        if not speech_timestamps:
            return (
                build_html_output("No speech detected", "result_item_error"),
                None,
                None,
                "",
                "",
            )
        timestamps_with_text = []
        transcriptions = []
        for ts in speech_timestamps:
            start, end = ts['start'], ts['end']
            segment = wav[start:end]
            if segment.dim() > 1:
                segment = segment.squeeze()
            # The feature extractor expects a 1-D array of floats
            inputs = processor(segment.numpy(), sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
            with torch.no_grad():
                logits = model(**inputs).logits
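            # Greedy CTC decoding: keep the most likely token at each frame;
            # processor.decode() then collapses repeats and strips blank tokens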
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.decode(predicted_ids[0])
            transcriptions.append(transcription)
            timestamps_with_text.append((start, end, transcription))
        # Generate subtitle files next to the uploaded audio
        base_path = os.path.splitext(audio_path)[0]
        srt_path = f"{base_path}.srt"
        vtt_path = f"{base_path}.vtt"
        create_subtitle_file(timestamps_with_text, srt_path, "srt")
        create_subtitle_file(timestamps_with_text, vtt_path, "vtt")
        # Create the preview player and the plain-text transcription
        preview_html = create_preview_player(audio_path, vtt_path)
        all_text = " ".join(transcriptions)
        return (
            build_html_output(
                "Transcription completed! You can now:\n1. Download the SRT/VTT files\n2. Play the audio with subtitles below",
                "result_item_success"
            ),
            srt_path,
            vtt_path,
            preview_html,
            all_text,
        )
    except Exception as e:
        logging.error(f"Error processing audio: {e}")
        return (
            build_html_output(
                f"Error processing audio: {e}",
                "result_item_error"
            ),
            None,
            None,
            "",
            "",
        )

demo = gr.Blocks(css=css)
with demo:
    gr.Markdown(title)
    with gr.Tabs():
        with gr.TabItem("Upload Audio"):
            audio_input = gr.Audio(
                sources=["upload"],
                type="filepath",
                label="Upload audio file",
            )
            process_button = gr.Button("Generate Subtitles")
            with gr.Column():
                info_output = gr.HTML(label="Status")
                srt_output = gr.File(label="SRT Subtitle File")
                vtt_output = gr.File(label="WebVTT Subtitle File")
                preview_output = gr.HTML(label="Preview Player")
                text_output = gr.Textbox(
                    label="Full Transcription",
                    placeholder="Transcribed text will appear here...",
                    lines=5,
                )
            process_button.click(
                process_audio,
                inputs=[audio_input],
                outputs=[
                    info_output,
                    srt_output,
                    vtt_output,
                    preview_output,
                    text_output,
                ],
            )
    gr.Markdown(description)

if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    # share=True publishes a temporary public gradio.live link; omit it to serve locally only
    demo.launch(share=True)