#!/usr/bin/env python3
"""Gradio app: transcribe Tibetan audio with Wav2Vec2 and export SRT/WebVTT subtitles."""
import html
import logging
import os
import shutil
from datetime import timedelta
from pathlib import Path
from urllib.parse import quote

import gradio as gr
import numpy as np
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Constants and Configuration
# Audio is resampled to this rate before being fed to both Silero VAD and the
# Wav2Vec2 processor (see process_audio).
SAMPLE_RATE = 16000
MODEL_NAME = "openpecha/general_stt_base_model"

title = "# Tibetan Speech-to-Text with Subtitles"
description = """
This application transcribes Tibetan audio files and generates subtitles using:
- Wav2Vec2 model fine-tuned on Garchen Rinpoche's teachings
- Silero VAD for voice activity detection
- Generates both SRT and WebVTT subtitle formats
"""

css = """
.result {display:flex;flex-direction:column}
.result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
.result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
.result_item_error {background-color:#ff7070;color:white;align-self:start}
.player-container {margin: 20px 0;}
.player-container audio {width: 100%;}
"""


# Initialize models
def init_models():
    """Load the Silero VAD model and the Wav2Vec2 CTC model/processor.

    Returns:
        (vad_model, get_speech_ts, model, processor)
    """
    # Load Silero VAD from torch.hub (network access on first run).
    vad_model, utils = torch.hub.load(
        repo_or_dir="snakers4/silero-vad",
        model="silero_vad",
        trust_repo=True,
    )
    # First helper in Silero's utils tuple; presumably get_speech_timestamps.
    get_speech_ts = utils[0]

    # Load Wav2Vec2 model
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
    model.eval()  # inference only
    return vad_model, get_speech_ts, model, processor


# Initialize models globally (runs at import time).
vad_model, get_speech_ts, model, processor = init_models()


def format_timestamp(seconds, format_type="srt"):
    """Convert a duration in seconds to an SRT or WebVTT timestamp.

    Args:
        seconds: Non-negative time offset in seconds (negative values clamp to 0).
        format_type: "srt" uses a comma before the milliseconds
            (``HH:MM:SS,mmm``); any other value uses a dot (``HH:MM:SS.mmm``).

    Returns:
        The formatted timestamp. Hours are not wrapped at 24.
    """
    # Round once on the total so the millisecond field can never reach 1000
    # (rounding only the sub-second part would produce e.g. "00:00:01,1000").
    total_ms = max(0, round(seconds * 1000))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    separator = "," if format_type == "srt" else "."
    return f"{hours:02d}:{minutes:02d}:{secs:02d}{separator}{millis:03d}"


def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
    """Write an SRT or WebVTT subtitle file.

    Args:
        timestamps_with_text: Iterable of (start_sample, end_sample, text),
            with sample offsets at SAMPLE_RATE.
        output_path: Destination file path (written as UTF-8).
        format_type: "srt" writes numbered SRT cues; "vtt" writes a WEBVTT
            header and unnumbered cues.
    """
    timestamp_format = "srt" if format_type == "srt" else "vtt"
    with open(output_path, "w", encoding="utf-8") as f:
        if format_type == "vtt":
            f.write("WEBVTT\n\n")
        for i, (start_time, end_time, text) in enumerate(timestamps_with_text, 1):
            if format_type == "srt":
                f.write(f"{i}\n")  # SRT cues carry a 1-based sequence number
            start = format_timestamp(start_time / SAMPLE_RATE, timestamp_format)
            end = format_timestamp(end_time / SAMPLE_RATE, timestamp_format)
            f.write(f"{start} --> {end}\n")
            f.write(f"{text}\n\n")


def build_html_output(s: str, style: str = "result_item_success") -> str:
    """Wrap a status message in the styled result container defined in ``css``.

    Args:
        s: Plain-text message; it is HTML-escaped and newlines become ``<br>``.
        style: CSS class for the message box, e.g. "result_item_success" or
            "result_item_error".
    """
    # Escape so exception messages containing "<" or "&" render literally.
    body = html.escape(s).replace("\n", "<br>")
    return (
        '<div class="result">'
        f'<div class="result_item {html.escape(style, quote=True)}">{body}</div>'
        "</div>"
    )


def create_preview_player(audio_path, vtt_path):
    """Return an HTML media player for *audio_path* with *vtt_path* as captions.

    A ``<video>`` element is used rather than ``<audio>`` because browsers do
    not display ``<track>`` captions for audio elements. A video element can
    still play an audio-only source.
    """
    # Create an HTML preview with audio player and subtitles.
    # NOTE(review): "file=<path>" is resolved relative to the page and relies on
    # Gradio serving these local paths. Confirm the route for the Gradio version
    # in use and that the files fall within Gradio's allowed paths.
    audio_url = f"file={quote(str(audio_path))}"
    vtt_url = f"file={quote(str(vtt_path))}"
    return (
        '<div class="player-container">'
        f'<video controls preload="metadata" src="{audio_url}" '
        'style="width:100%;height:100px;">'
        f'<track kind="subtitles" src="{vtt_url}" srclang="bo" label="Tibetan" default>'
        "Your browser does not support HTML5 media playback."
        "</video>"
        "</div>"
    )


def _error_result(message):
    """Build the 5-output tuple shown when processing fails or cannot start."""
    return (build_html_output(message, "result_item_error"), None, None, "", "")


def _transcribe_segment(segment):
    """Greedy-decode one mono SAMPLE_RATE segment (1-D numpy array) to text."""
    inputs = processor(
        segment, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])


def process_audio(audio_path: str):
    """Transcribe an audio file and produce subtitle files plus a preview.

    Args:
        audio_path: Path to the uploaded audio file (may be None/empty).

    Returns:
        A 5-tuple matching the UI outputs:
        (status_html, srt_path | None, vtt_path | None, preview_html, full_text).
    """
    if not audio_path:
        return _error_result("Please upload an audio file first")

    logging.info("Processing audio file: %s", audio_path)

    try:
        # Load, resample to SAMPLE_RATE, then average channels to mono.
        wav, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE:
            wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
        wav = wav.mean(dim=0)  # (channels, samples) -> (samples,)

        # Silero VAD returns a list of {"start": sample, "end": sample} dicts.
        speech_timestamps = get_speech_ts(wav, vad_model, sampling_rate=SAMPLE_RATE)
        if not speech_timestamps:
            return _error_result("No speech detected")

        timestamps_with_text = []
        transcriptions = []
        for ts in speech_timestamps:
            start, end = ts["start"], ts["end"]
            text = _transcribe_segment(wav[start:end].numpy()).strip()
            if not text:
                # Skip empty cues: a blank SRT cue body is malformed.
                continue
            transcriptions.append(text)
            timestamps_with_text.append((start, end, text))

        # Write subtitle files next to the uploaded audio.
        base_path = os.path.splitext(audio_path)[0]
        srt_path = f"{base_path}.srt"
        vtt_path = f"{base_path}.vtt"
        create_subtitle_file(timestamps_with_text, srt_path, "srt")
        create_subtitle_file(timestamps_with_text, vtt_path, "vtt")

        preview_html = create_preview_player(audio_path, vtt_path)
        all_text = " ".join(transcriptions)

        return (
            build_html_output(
                "Transcription completed! You can now:\n1. Download the SRT/VTT files\n2. Play the audio with subtitles below",
                "result_item_success",
            ),
            srt_path,
            vtt_path,
            preview_html,
            all_text,
        )
    except Exception as e:
        # Top-level UI boundary: log the traceback and surface the message.
        logging.exception("Error processing audio")
        return _error_result(f"Error processing audio: {str(e)}")


demo = gr.Blocks(css=css)

with demo:
    gr.Markdown(title)
    with gr.Tabs():
        with gr.TabItem("Upload Audio"):
            audio_input = gr.Audio(
                sources=["upload"],
                type="filepath",
                label="Upload audio file",
            )
            process_button = gr.Button("Generate Subtitles")

            with gr.Column():
                info_output = gr.HTML(label="Status")
                srt_output = gr.File(label="SRT Subtitle File")
                vtt_output = gr.File(label="WebVTT Subtitle File")
                preview_output = gr.HTML(label="Preview Player")
                text_output = gr.Textbox(
                    label="Full Transcription",
                    placeholder="Transcribed text will appear here...",
                    lines=5,
                )

            process_button.click(
                process_audio,
                inputs=[audio_input],
                outputs=[
                    info_output,
                    srt_output,
                    vtt_output,
                    preview_output,
                    text_output,
                ],
            )

    gr.Markdown(description)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    demo.launch(share=True)