import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datetime import timedelta
import os
import logging


SAMPLE_RATE = 16000
MODEL_NAME = "openpecha/general_stt_base_model"

title = "# Tibetan Speech-to-Text with Subtitles"

description = """
This application transcribes Tibetan audio files and generates subtitles using:
- a Wav2Vec2 model fine-tuned on Garchen Rinpoche's teachings
- Silero VAD for voice activity detection
- SRT and WebVTT subtitle output
"""

css = """
.result {display:flex;flex-direction:column}
.result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
.result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
.result_item_error {background-color:#ff7070;color:white;align-self:start}
.player-container {margin: 20px 0;}
.player-container audio {width: 100%;}
"""


def init_models():
    """Load the Silero VAD model and the Wav2Vec2 model and processor."""
    vad_model, utils = torch.hub.load(
        repo_or_dir='snakers4/silero-vad', model='silero_vad', trust_repo=True
    )
    # The first entry in silero-vad's utils tuple is get_speech_timestamps.
    get_speech_ts = utils[0]

    model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
    model.eval()

    return vad_model, get_speech_ts, model, processor


# Load everything once at import time so each request reuses the same models.
vad_model, get_speech_ts, model, processor = init_models()


def format_timestamp(seconds, format_type="srt"):
    """Convert seconds to an SRT or WebVTT timestamp string."""
    td = timedelta(seconds=seconds)
    total_seconds = int(td.total_seconds())
    hours = total_seconds // 3600
    minutes = (total_seconds % 3600) // 60
    secs = total_seconds % 60
    # Floor the microseconds so the milliseconds field never rounds up to 1000.
    milliseconds = td.microseconds // 1000

    if format_type == "srt":
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"
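# For example, format_timestamp(3.5) returns "00:00:03,500" and
# format_timestamp(3.5, "vtt") returns "00:00:03.500".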
|
|
|
def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
    """Write an SRT or WebVTT subtitle file from (start, end, text) tuples."""
    with open(output_path, 'w', encoding='utf-8') as f:
        if format_type == "vtt":
            f.write("WEBVTT\n\n")

        for i, (start_time, end_time, text) in enumerate(timestamps_with_text, 1):
            # start_time and end_time are sample indices from the VAD, so divide
            # by SAMPLE_RATE to convert them to seconds.
            if format_type == "srt":
                f.write(f"{i}\n")
                f.write(f"{format_timestamp(start_time/SAMPLE_RATE)} --> {format_timestamp(end_time/SAMPLE_RATE)}\n")
                f.write(f"{text}\n\n")
            else:
                f.write(f"{format_timestamp(start_time/SAMPLE_RATE, 'vtt')} --> {format_timestamp(end_time/SAMPLE_RATE, 'vtt')}\n")
                f.write(f"{text}\n\n")
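# Each SRT cue is written as an index line, a "start --> end" timestamp line and
# the text; WebVTT omits the index and uses "." instead of "," for milliseconds.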
|
|
|
def build_html_output(s: str, style: str = "result_item_success"):
    # The result/result_item classes are styled by the css block defined above.
    return f"""
    <div class='result'>
        <div class='result_item {style}'>
            {s}
        </div>
    </div>
    """
|
|
|
def create_preview_player(audio_path, vtt_path):
    """Build an HTML5 audio player with the WebVTT subtitles attached as a track."""
    # Gradio serves local files through its "file=" URL scheme; depending on the
    # Gradio version, these paths may also need to be whitelisted, e.g. via
    # launch(allowed_paths=[...]).
    audio_url = f"file={audio_path}"
    vtt_url = f"file={vtt_path}"

    html_content = f"""
    <div class="audio-player">
        <audio controls style="width: 100%;">
            <source src="{audio_url}" type="audio/wav">
            <track kind="subtitles" src="{vtt_url}" default>
            Your browser does not support the audio element.
        </audio>
    </div>
    """

    return html_content
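# Note: browser support for rendering subtitle tracks on a bare <audio> element
# varies; if subtitles do not appear, a <video> element with the same sources is
# a common workaround.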
|
|
|
def process_audio(audio_path: str):
    if audio_path is None or audio_path == "":
        return (
            build_html_output(
                "Please upload an audio file first",
                "result_item_error",
            ),
            None,
            None,
            "",
            "",
        )

    logging.info(f"Processing audio file: {audio_path}")

    try:
        # Load the audio, resample to 16 kHz if necessary, and downmix to mono.
        wav, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE:
            wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
        wav = wav.mean(dim=0)

        # Silero VAD returns a list of dicts with 'start'/'end' sample indices.
        speech_timestamps = get_speech_ts(wav, vad_model, sampling_rate=SAMPLE_RATE)
        if not speech_timestamps:
            return (
                build_html_output("No speech detected", "result_item_error"),
                None,
                None,
                "",
                "",
            )

        timestamps_with_text = []
        transcriptions = []

        # Transcribe each detected speech segment independently.
        for ts in speech_timestamps:
            start, end = ts['start'], ts['end']
            segment = wav[start:end]

            inputs = processor(segment, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
            with torch.no_grad():
                logits = model(**inputs).logits
            # Greedy CTC decoding: take the most likely token at every frame.
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.decode(predicted_ids[0])
            transcriptions.append(transcription)
            timestamps_with_text.append((start, end, transcription))
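        # Segments are transcribed one at a time; batching several segments per
        # processor call would likely be faster on GPU, at the cost of padding.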
        # Write the subtitle files next to the input audio file.
        base_path = os.path.splitext(audio_path)[0]
        srt_path = f"{base_path}.srt"
        vtt_path = f"{base_path}.vtt"

        create_subtitle_file(timestamps_with_text, srt_path, "srt")
        create_subtitle_file(timestamps_with_text, vtt_path, "vtt")

        preview_html = create_preview_player(audio_path, vtt_path)
        all_text = " ".join(transcriptions)

|
|
|
        return (
            build_html_output(
                "Transcription completed! You can now:\n1. Download the SRT/VTT files\n2. Play the audio with subtitles below",
                "result_item_success",
            ),
            srt_path,
            vtt_path,
            preview_html,
            all_text,
        )
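        # These five values map onto info_output, srt_output, vtt_output,
        # preview_output and text_output in the Blocks UI below.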
|
    except Exception as e:
        logging.error(f"Error processing audio: {str(e)}")
        return (
            build_html_output(
                f"Error processing audio: {str(e)}",
                "result_item_error",
            ),
            None,
            None,
            "",
            "",
        )
|
|
|
demo = gr.Blocks(css=css)

with demo:
    gr.Markdown(title)

    with gr.Tabs():
        with gr.TabItem("Upload Audio"):
            audio_input = gr.Audio(
                sources=["upload"],
                type="filepath",
                label="Upload audio file",
            )
            process_button = gr.Button("Generate Subtitles")

            with gr.Column():
                info_output = gr.HTML(label="Status")
                srt_output = gr.File(label="SRT Subtitle File")
                vtt_output = gr.File(label="WebVTT Subtitle File")
                preview_output = gr.HTML(label="Preview Player")
                text_output = gr.Textbox(
                    label="Full Transcription",
                    placeholder="Transcribed text will appear here...",
                    lines=5,
                )

    process_button.click(
        process_audio,
        inputs=[audio_input],
        outputs=[
            info_output,
            srt_output,
            vtt_output,
            preview_output,
            text_output,
        ],
    )

    gr.Markdown(description)
|
|
|
if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    # share=True additionally creates a temporary public Gradio link.
    demo.launch(share=True)
|
|