# stt_demo/app.py
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import os
import shutil
from pathlib import Path
# Load Silero VAD
vad_model, utils = torch.hub.load(
repo_or_dir='snakers4/silero-vad', model='silero_vad', trust_repo=True
)
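# utils = (get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks)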
(get_speech_ts, _, _, _, _) = utils
# Load the fine-tuned Wav2Vec2 CTC model and its processor (feature extractor + tokenizer)
model_name = "ganga4364/Garchen_Rinpoche-wav2vec2-Checkpoint-19000"
model = Wav2Vec2ForCTC.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name)
model.eval()  # inference only: disables dropout
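# Both Silero VAD and the Wav2Vec2 checkpoint expect 16 kHz mono audio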
SAMPLE_RATE = 16000
def format_timestamp(seconds, format_type="srt"):
    """Convert seconds to an SRT or WebVTT timestamp string."""
    # Work in integer milliseconds so rounding can never produce a ",1000" field
    total_ms = int(round(seconds * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, milliseconds = divmod(rem, 1_000)
    if format_type == "srt":
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
    # WebVTT is identical except the millisecond separator is a dot
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"
def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
    """Write an SRT or WebVTT subtitle file from (start, end, text) tuples in samples."""
    with open(output_path, 'w', encoding='utf-8') as f:
        if format_type == "vtt":
            f.write("WEBVTT\n\n")  # WebVTT files must begin with this header
        for i, (start_time, end_time, text) in enumerate(timestamps_with_text, 1):
            if format_type == "srt":
                f.write(f"{i}\n")  # SRT requires a sequential cue index; WebVTT does not
            start_ts = format_timestamp(start_time / SAMPLE_RATE, format_type)
            end_ts = format_timestamp(end_time / SAMPLE_RATE, format_type)
            f.write(f"{start_ts} --> {end_ts}\n{text}\n\n")
def create_preview_html(audio_path, vtt_path):
"""Create an HTML preview with audio player and subtitles"""
static_dir = Path("static")
static_dir.mkdir(exist_ok=True)
# Copy files to static directory with friendly names
audio_filename = Path(audio_path).name
vtt_filename = Path(vtt_path).name
new_audio_path = static_dir / audio_filename
new_vtt_path = static_dir / vtt_filename
shutil.copy2(audio_path, new_audio_path)
shutil.copy2(vtt_path, new_vtt_path)
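    # The copies matter: launch() below whitelists ./static so Gradio's
    # file-serving route can deliver the audio and subtitles to the browser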
    # Build the player HTML directly; Gradio 4 serves whitelisted files via its /file= route.
    # The checkpoint transcribes Tibetan (per the model name), so the track is labeled accordingly.
    html_content = f"""
    <div class="player-container">
        <h3>Audio Player with Subtitles</h3>
        <audio controls style="width: 100%; margin: 10px 0;">
            <source src="/file={new_audio_path}" type="audio/wav">
            <track label="Tibetan" kind="subtitles" srclang="bo" src="/file={new_vtt_path}" default>
            Your browser does not support the audio element.
        </audio>
    </div>
    """
    # Note: many browsers only render <track> subtitles on <video>, not bare <audio>
return html_content
def transcribe_with_vad(audio_path):
# Load and resample audio to 16kHz mono
wav, sr = torchaudio.load(audio_path)
if sr != SAMPLE_RATE:
wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
wav = wav.mean(dim=0) # convert to mono
    wav_np = wav.numpy()  # Silero VAD accepts a 1-D numpy array (or tensor)
# Get speech timestamps using Silero VAD
speech_timestamps = get_speech_ts(wav_np, vad_model, sampling_rate=SAMPLE_RATE)
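    # Each timestamp is {'start': ..., 'end': ...}, measured in samples at 16 kHz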
if not speech_timestamps:
return "No speech detected.", None, None, None
timestamps_with_text = []
transcriptions = []
for ts in speech_timestamps:
start, end = ts['start'], ts['end']
segment = wav[start:end]
if segment.dim() > 1:
segment = segment.squeeze()
inputs = processor(segment, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)  # greedy (argmax) CTC decoding per frame
        transcription = processor.decode(predicted_ids[0])  # collapses repeats and strips blanks
transcriptions.append(transcription)
timestamps_with_text.append((start, end, transcription))
# Generate subtitle files
base_path = os.path.splitext(audio_path)[0]
srt_path = f"{base_path}.srt"
vtt_path = f"{base_path}.vtt"
create_subtitle_file(timestamps_with_text, srt_path, "srt")
create_subtitle_file(timestamps_with_text, vtt_path, "vtt")
# Create preview HTML
preview_html = create_preview_html(audio_path, vtt_path)
return " ".join(transcriptions), srt_path, vtt_path, preview_html
# Gradio Interface
demo = gr.Interface(
fn=transcribe_with_vad,
inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload or Record"),
outputs=[
gr.Textbox(label="Transcription"),
gr.File(label="SRT Subtitle File"),
gr.File(label="WebVTT Subtitle File"),
gr.HTML(label="Preview Player")
],
title="Smart Speech-to-Text with VAD and Subtitles",
    description="Transcribes long audio with ganga4364/Garchen_Rinpoche-wav2vec2-Checkpoint-19000, using Silero VAD to split speech into segments, and generates downloadable SRT and WebVTT subtitle files."
)
if __name__ == "__main__":
    # allowed_paths is the Gradio 4 name for the older file_directories argument;
    # it whitelists ./static so the preview player's files can be served
    demo.launch(share=True, allowed_paths=["static"])