import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datetime import timedelta
import os
import shutil
from pathlib import Path

# Load Silero VAD from torch.hub; it detects speech regions in long audio
vad_model, utils = torch.hub.load(
    repo_or_dir='snakers4/silero-vad', model='silero_vad', trust_repo=True
)
(get_speech_ts, _, _, _, _) = utils
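# The remaining `utils` entries are save_audio, read_audio, VADIterator, and
# collect_chunks; only get_speech_timestamps is needed here.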

# Load the fine-tuned wav2vec2 CTC checkpoint and its processor
model_name = "ganga4364/Garchen_Rinpoche-wav2vec2-Checkpoint-19000"
model = Wav2Vec2ForCTC.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name)
model.eval()  # inference only: disables dropout

SAMPLE_RATE = 16000
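# wav2vec2 models are trained on 16 kHz mono audio, so every input is
# resampled and downmixed to this rate before inference.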

def format_timestamp(seconds, format_type="srt"):
    """Convert seconds to an SRT or WebVTT timestamp string."""
    td = timedelta(seconds=seconds)
    # Work in integer milliseconds: round(td.microseconds / 1000) could yield
    # 1000 and produce an invalid timestamp, and td.seconds drops whole days
    # for inputs over 24 hours.
    total_ms = round(td.total_seconds() * 1000)
    hours, rest = divmod(total_ms, 3_600_000)
    minutes, rest = divmod(rest, 60_000)
    secs, milliseconds = divmod(rest, 1000)

    if format_type == "srt":
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"
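
# Example: format_timestamp(3.5)        -> "00:00:03,500" (SRT uses a comma)
#          format_timestamp(3.5, "vtt") -> "00:00:03.500" (WebVTT uses a dot)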

def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
    """Write (start_sample, end_sample, text) triples as an SRT or WebVTT file."""
    with open(output_path, 'w', encoding='utf-8') as f:
        if format_type == "vtt":
            f.write("WEBVTT\n\n")  # mandatory WebVTT header line

        for i, (start_time, end_time, text) in enumerate(timestamps_with_text, 1):
            if format_type == "srt":
                f.write(f"{i}\n")  # SRT cues are numbered; WebVTT cues need not be
                f.write(f"{format_timestamp(start_time/SAMPLE_RATE)} --> {format_timestamp(end_time/SAMPLE_RATE)}\n")
                f.write(f"{text}\n\n")
            else:
                f.write(f"{format_timestamp(start_time/SAMPLE_RATE, 'vtt')} --> {format_timestamp(end_time/SAMPLE_RATE, 'vtt')}\n")
                f.write(f"{text}\n\n")
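
# A single SRT cue written by this function looks like:
#
#   1
#   00:00:00,000 --> 00:00:02,500
#   <transcribed text>
#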

def create_preview_html(audio_path, vtt_path):
    """Create an HTML preview with an audio player and a subtitle track."""
    # Copy the media into ./static so Gradio is allowed to serve it
    # (see allowed_paths in demo.launch below); Gradio exposes the copied
    # files via its /file= route.
    static_dir = Path("static")
    static_dir.mkdir(exist_ok=True)

    audio_filename = Path(audio_path).name
    vtt_filename = Path(vtt_path).name
    new_audio_path = static_dir / audio_filename
    new_vtt_path = static_dir / vtt_filename

    shutil.copy2(audio_path, new_audio_path)
    shutil.copy2(vtt_path, new_vtt_path)

    # Caveat: most browsers only render <track> subtitles on <video> elements;
    # on <audio> the cues are exposed to scripts but not drawn on screen.
    html_content = f"""
    <div class="player-container">
        <h3>Audio Player with Subtitles</h3>
        <audio controls style="width: 100%; margin: 10px 0;">
            <source src="/file={new_audio_path}" type="audio/wav">
            <track label="Tibetan" kind="subtitles" srclang="bo" src="/file={new_vtt_path}" default>
            Your browser does not support the audio element.
        </audio>
    </div>
    """

    return html_content
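
# The returned string is rendered verbatim by the gr.HTML output component
# declared in the Interface below.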

def transcribe_with_vad(audio_path):
    """Segment audio with Silero VAD, transcribe each segment, and build subtitles."""
    wav, sr = torchaudio.load(audio_path)
    if sr != SAMPLE_RATE:
        wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
    wav = wav.mean(dim=0)  # downmix to mono
    wav_np = wav.numpy()

    # Timestamps are returned in samples: [{'start': ..., 'end': ...}, ...]
    speech_timestamps = get_speech_ts(wav_np, vad_model, sampling_rate=SAMPLE_RATE)
    if not speech_timestamps:
        return "No speech detected.", None, None, None

    timestamps_with_text = []
    transcriptions = []

    for ts in speech_timestamps:
        start, end = ts['start'], ts['end']
        segment = wav[start:end]
        if segment.dim() > 1:
            segment = segment.squeeze()

        # The processor expects a 1-D float array; decode greedily over CTC logits.
        inputs = processor(segment.numpy(), sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.decode(predicted_ids[0])
        transcriptions.append(transcription)
        timestamps_with_text.append((start, end, transcription))

    base_path = os.path.splitext(audio_path)[0]
    srt_path = f"{base_path}.srt"
    vtt_path = f"{base_path}.vtt"

    create_subtitle_file(timestamps_with_text, srt_path, "srt")
    create_subtitle_file(timestamps_with_text, vtt_path, "vtt")

    preview_html = create_preview_html(audio_path, vtt_path)

    return " ".join(transcriptions), srt_path, vtt_path, preview_html
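
# Direct (non-UI) usage, assuming a local recording at "sample.wav"
# (hypothetical path):
#
#   text, srt_path, vtt_path, _ = transcribe_with_vad("sample.wav")
#   print(text)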

demo = gr.Interface(
    fn=transcribe_with_vad,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload or Record"),
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.File(label="SRT Subtitle File"),
        gr.File(label="WebVTT Subtitle File"),
        gr.HTML(label="Preview Player")
    ],
    title="Smart Speech-to-Text with VAD and Subtitles",
    description="Transcribe long audio using ganga4364/Garchen_Rinpoche-wav2vec2-Checkpoint-19000 and Silero VAD. Generates SRT and WebVTT subtitle files."
)

if __name__ == "__main__":
    # allowed_paths (not the removed file_directories argument) lets Gradio
    # serve the copied media from ./static; share=True creates a public link.
    demo.launch(share=True, allowed_paths=["static"])