stt_demo / app.py
ganga4364's picture
Update app.py
482d6e9 verified
raw
history blame
7.93 kB
#!/usr/bin/env python3
import html
import logging
import os
import shutil
from datetime import timedelta
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# Constants and Configuration
# Sample rate (Hz) the audio is resampled to before VAD and transcription;
# also used to convert VAD sample indices into seconds for the subtitles.
SAMPLE_RATE = 16000
# Model identifier passed to both Wav2Vec2ForCTC and Wav2Vec2Processor.
MODEL_NAME = "openpecha/general_stt_base_model"
# Markdown heading rendered as the first element of the page.
title = "# Tibetan Speech-to-Text with Subtitles"
# Markdown blurb rendered after the main controls.
description = """
This application transcribes Tibetan audio files and generates subtitles using:
- Wav2Vec2 model fine-tuned on Garchen Rinpoche's teachings
- Silero VAD for voice activity detection
- Generates both SRT and WebVTT subtitle formats
"""
# Custom CSS handed to gr.Blocks. The .result / .result_item* classes are the
# ones emitted by build_html_output.
# NOTE(review): the .player-container rules are not referenced by the markup
# visible in this file (the preview player uses class "audio-player").
css = """
.result {display:flex;flex-direction:column}
.result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
.result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
.result_item_error {background-color:#ff7070;color:white;align-self:start}
.player-container {margin: 20px 0;}
.player-container audio {width: 100%;}
"""
# Initialize models
def init_models():
    """Load the voice-activity detector and the speech recognition model.

    Returns:
        A 4-tuple ``(vad_model, get_speech_ts, model, processor)``: the Silero
        VAD model, the first helper from the utilities returned alongside it
        (used by this app to obtain speech timestamps), the Wav2Vec2 CTC model
        switched to eval mode, and its matching processor.
    """
    # Silero VAD is fetched through torch.hub together with its helper tuple.
    silero_model, silero_utils = torch.hub.load(
        "snakers4/silero-vad",
        "silero_vad",
        trust_repo=True,
    )

    # Wav2Vec2 weights and processor both come from MODEL_NAME.
    asr_model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
    asr_processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
    asr_model.eval()

    return silero_model, silero_utils[0], asr_model, asr_processor
# Initialize models globally: this is a module-level statement, so the models
# are loaded once when the module is imported and the resulting objects are
# the globals that process_audio reads on every request.
vad_model, get_speech_ts, model, processor = init_models()
def format_timestamp(seconds, format_type="srt"):
    """Convert a time offset in seconds to an SRT or WebVTT timestamp.

    Args:
        seconds: Offset from the start of the audio, in seconds (int or
            float). Negative values are clamped to zero.
        format_type: ``"srt"`` selects the comma millisecond separator
            (``HH:MM:SS,mmm``); any other value selects the WebVTT dot
            separator (``HH:MM:SS.mmm``).

    Returns:
        The formatted timestamp string. The hour field grows past two digits
        instead of wrapping for offsets of 100 hours or more.
    """
    # Round once on the total so a fractional part of >= 999.5 ms carries into
    # the seconds field instead of being printed as a four-digit "1000".
    total_ms = max(0, int(round(float(seconds) * 1000)))
    total_seconds, milliseconds = divmod(total_ms, 1000)
    # Derive hours from the full second count so offsets of a day or more are
    # not silently wrapped back to 00:00:00.
    hours, remainder = divmod(total_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    separator = "," if format_type == "srt" else "."
    return f"{hours:02d}:{minutes:02d}:{secs:02d}{separator}{milliseconds:03d}"
def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
    """Write subtitle cues to ``output_path`` in SRT or WebVTT layout.

    Args:
        timestamps_with_text: Iterable of ``(start, end, text)`` triples where
            ``start`` and ``end`` are sample offsets; they are divided by
            ``SAMPLE_RATE`` to obtain seconds.
        output_path: Destination file, written as UTF-8 (overwritten if it
            already exists).
        format_type: ``"srt"`` writes numbered cues with SRT timestamps. Any
            other value writes un-numbered cues with WebVTT timestamps, and
            the ``WEBVTT`` header is emitted only when the value is ``"vtt"``.
    """
    numbered = format_type == "srt"
    with open(output_path, "w", encoding="utf-8") as out:
        if format_type == "vtt":
            out.write("WEBVTT\n\n")

        for index, cue in enumerate(timestamps_with_text, start=1):
            first_sample, last_sample, caption = cue
            if numbered:
                begin = format_timestamp(first_sample / SAMPLE_RATE)
                finish = format_timestamp(last_sample / SAMPLE_RATE)
                out.write(f"{index}\n")
            else:
                begin = format_timestamp(first_sample / SAMPLE_RATE, "vtt")
                finish = format_timestamp(last_sample / SAMPLE_RATE, "vtt")
            out.write(f"{begin} --> {finish}\n")
            out.write(f"{caption}\n\n")
def build_html_output(s: str, style: str = "result_item_success"):
    """Wrap a status message in the app's result-box markup.

    Args:
        s: Message placed inside the inner ``div``; inserted as-is.
        style: CSS class added next to ``result_item`` on the inner ``div``.

    Returns:
        The HTML snippet as a string, with a leading and trailing newline.
    """
    markup_lines = (
        "",
        "<div class='result'>",
        f"<div class='result_item {style}'>",
        f"{s}",
        "</div>",
        "</div>",
        "",
    )
    return "\n".join(markup_lines)
def create_preview_player(audio_path, vtt_path):
    """Build an HTML audio player that shows the generated subtitles.

    Args:
        audio_path: Path of the audio file, exposed to the page as
            ``file=<path>``.
        vtt_path: Path of the WebVTT subtitle file, exposed the same way and
            attached as the default subtitle track.

    Returns:
        An HTML snippet containing an ``<audio>`` element with one
        ``<source>`` and one ``<track>`` child.
    """
    # The paths originate from uploaded file names, so escape them before
    # placing them inside double-quoted attributes; otherwise a quote, "&" or
    # "<" in a name would break out of the attribute or corrupt the markup.
    audio_url = html.escape(f"file={audio_path}", quote=True)
    vtt_url = html.escape(f"file={vtt_path}", quote=True)
    html_content = f"""
<div class="audio-player">
<audio controls style="width: 100%;">
<source src="{audio_url}" type="audio/wav">
<track kind="subtitles" src="{vtt_url}" default>
Your browser does not support the audio element.
</audio>
</div>
"""
    return html_content
def _error_result(message: str):
    """Build the 5-tuple returned to the UI when no subtitles were produced."""
    return (
        build_html_output(message, "result_item_error"),
        None,
        None,
        "",
        "",
    )


def _transcribe_segment(segment):
    """Run the Wav2Vec2 CTC model on one waveform segment and decode its text."""
    if segment.dim() > 1:
        segment = segment.squeeze()
    inputs = processor(segment, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])


def process_audio(audio_path: str):
    """Transcribe an audio file and generate SRT and WebVTT subtitles for it.

    The audio is loaded, resampled to ``SAMPLE_RATE`` if needed, averaged down
    to a single channel, split into speech segments by the VAD, and each
    segment is transcribed separately. The subtitle files are written next to
    the input file, using its path with the extension replaced by ``.srt`` and
    ``.vtt``.

    Args:
        audio_path: Path to the audio file; ``None`` or an empty string yields
            an error status instead of raising.

    Returns:
        A 5-tuple ``(status_html, srt_path, vtt_path, preview_html,
        full_text)`` in the order expected by the click handler's outputs. On
        any failure the two paths are ``None`` and the last two items are
        empty strings.
    """
    if not audio_path:
        return _error_result("Please upload an audio file first")

    logging.info("Processing audio file: %s", audio_path)
    # Top-level boundary for the UI callback: any failure is logged and
    # reported in the status box rather than propagated.
    try:
        # Load and resample audio to 16kHz mono
        wav, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE:
            wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
        wav = wav.mean(dim=0)  # convert to mono
        wav_np = wav.numpy()

        # Get speech timestamps using Silero VAD
        speech_timestamps = get_speech_ts(wav_np, vad_model, sampling_rate=SAMPLE_RATE)
        if not speech_timestamps:
            return _error_result("No speech detected")

        timestamps_with_text = []
        transcriptions = []
        for ts in speech_timestamps:
            # start/end are used as sample indices into the mono waveform and
            # are converted to seconds later by create_subtitle_file.
            start, end = ts['start'], ts['end']
            transcription = _transcribe_segment(wav[start:end])
            transcriptions.append(transcription)
            timestamps_with_text.append((start, end, transcription))

        # Generate subtitle files
        base_path = os.path.splitext(audio_path)[0]
        srt_path = f"{base_path}.srt"
        vtt_path = f"{base_path}.vtt"
        create_subtitle_file(timestamps_with_text, srt_path, "srt")
        create_subtitle_file(timestamps_with_text, vtt_path, "vtt")

        # Create preview player with the file paths
        preview_html = create_preview_player(audio_path, vtt_path)
        all_text = " ".join(transcriptions)
        return (
            build_html_output(
                "Transcription completed! You can now:\n1. Download the SRT/VTT files\n2. Play the audio with subtitles below",
                "result_item_success"
            ),
            srt_path,
            vtt_path,
            preview_html,
            all_text,
        )
    except Exception as e:
        # logging.exception keeps the traceback, which a plain error() call drops.
        logging.exception("Error processing audio: %s", e)
        # The exception text may contain markup-significant characters (e.g.
        # from file names), so escape it before embedding it in HTML.
        return _error_result(f"Error processing audio: {html.escape(str(e))}")
# Build the Gradio interface; `css` supplies the custom result-box styles.
# NOTE(review): the container nesting below follows the conventional layout
# (inputs and outputs inside the single tab) -- confirm against the deployed app.
demo = gr.Blocks(css=css)
with demo:
    gr.Markdown(title)
    with gr.Tabs():
        with gr.TabItem("Upload Audio"):
            # type="filepath": process_audio is written to receive a path string.
            audio_input = gr.Audio(
                sources=["upload"],
                type="filepath",
                label="Upload audio file",
            )
            process_button = gr.Button("Generate Subtitles")
            with gr.Column():
                info_output = gr.HTML(label="Status")
                srt_output = gr.File(label="SRT Subtitle File")
                vtt_output = gr.File(label="WebVTT Subtitle File")
                preview_output = gr.HTML(label="Preview Player")
                text_output = gr.Textbox(
                    label="Full Transcription",
                    placeholder="Transcribed text will appear here...",
                    lines=5
                )
    # The outputs list must stay in the same order as the 5-tuple returned by
    # process_audio: status HTML, SRT file, VTT file, preview HTML, full text.
    process_button.click(
        process_audio,
        inputs=[audio_input],
        outputs=[
            info_output,
            srt_output,
            vtt_output,
            preview_output,
            text_output,
        ],
    )
    gr.Markdown(description)
if __name__ == "__main__":
    # Log lines carry a timestamp, the level, and the emitting file and line.
    formatter = (
        "%(asctime)s %(levelname)s "
        "[%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(level=logging.INFO, format=formatter)
    # Start the app only when run as a script, with Gradio's share option on.
    demo.launch(share=True)