import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from datetime import timedelta
import os
import shutil
from pathlib import Path
# Load Silero VAD
vad_model, utils = torch.hub.load(
repo_or_dir='snakers4/silero-vad', model='silero_vad', trust_repo=True
)
(get_speech_ts, _, _, _, _) = utils
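# torch.hub returns the VAD model plus a utils tuple; the first helper
# (get_speech_timestamps) locates voiced regions in a waveform, the rest are unused here.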
# Load Wav2Vec2 model
model_name = "ganga4364/Garchen_Rinpoche-wav2vec2-Checkpoint-19000"
model = Wav2Vec2ForCTC.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name)
model.eval()
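# 16 kHz mono is the rate both Silero VAD and the Wav2Vec2 checkpoint expect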
SAMPLE_RATE = 16000
def format_timestamp(seconds, format_type="srt"):
"""Convert seconds to SRT or WebVTT timestamp format"""
td = timedelta(seconds=seconds)
    total_seconds = int(td.total_seconds())
    hours = total_seconds // 3600
    minutes = (total_seconds % 3600) // 60
    secs = total_seconds % 60
    # Floor-divide the microseconds so the value can never round up to 1000
    milliseconds = td.microseconds // 1000
    if format_type == "srt":
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
    else:  # WebVTT uses '.' instead of ',' before the milliseconds
        return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"
def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
"""Create SRT or WebVTT subtitle file"""
with open(output_path, 'w', encoding='utf-8') as f:
if format_type == "vtt":
f.write("WEBVTT\n\n")
for i, (start_time, end_time, text) in enumerate(timestamps_with_text, 1):
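        # start_time/end_time are sample offsets from Silero VAD; divide by SAMPLE_RATE to get seconds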
if format_type == "srt":
f.write(f"{i}\n")
f.write(f"{format_timestamp(start_time/SAMPLE_RATE)} --> {format_timestamp(end_time/SAMPLE_RATE)}\n")
f.write(f"{text}\n\n")
else:
f.write(f"{format_timestamp(start_time/SAMPLE_RATE, 'vtt')} --> {format_timestamp(end_time/SAMPLE_RATE, 'vtt')}\n")
f.write(f"{text}\n\n")
def create_preview_html(audio_path, vtt_path):
"""Create an HTML preview file for audio with subtitles"""
static_dir = Path("static")
static_dir.mkdir(exist_ok=True)
# Copy files to static directory with friendly names
audio_filename = Path(audio_path).name
vtt_filename = Path(vtt_path).name
new_audio_path = static_dir / audio_filename
new_vtt_path = static_dir / vtt_filename
shutil.copy2(audio_path, new_audio_path)
shutil.copy2(vtt_path, new_vtt_path)
# Read HTML template
with open("player.html", "r") as f:
html_content = f.read()
# Replace placeholders
html_content = html_content.replace("{{ audio_path }}", f"static/{audio_filename}")
html_content = html_content.replace("{{ vtt_path }}", f"static/{vtt_filename}")
# Save preview HTML
preview_path = static_dir / "preview.html"
with open(preview_path, "w") as f:
f.write(html_content)
return str(preview_path)
def transcribe_with_vad(audio_path):
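    """Segment speech with Silero VAD, transcribe each segment, and build subtitle files."""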
# Load and resample audio to 16kHz mono
wav, sr = torchaudio.load(audio_path)
if sr != SAMPLE_RATE:
wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
wav = wav.mean(dim=0) # convert to mono
wav_np = wav.numpy()
# Get speech timestamps using Silero VAD
speech_timestamps = get_speech_ts(wav_np, vad_model, sampling_rate=SAMPLE_RATE)
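    # speech_timestamps is a list of dicts of sample offsets, e.g. {'start': 12345, 'end': 67890}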
if not speech_timestamps:
return "No speech detected.", None, None, None
timestamps_with_text = []
transcriptions = []
for ts in speech_timestamps:
start, end = ts['start'], ts['end']
segment = wav[start:end]
if segment.dim() > 1:
segment = segment.squeeze()
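        # The processor normalizes the raw waveform and packs it into model-ready input_values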
inputs = processor(segment, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(**inputs).logits
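        # Greedy CTC decoding: pick the most likely token per frame, then collapse repeats and blanks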
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0])
transcriptions.append(transcription)
timestamps_with_text.append((start, end, transcription))
# Generate subtitle files
base_path = os.path.splitext(audio_path)[0]
srt_path = f"{base_path}.srt"
vtt_path = f"{base_path}.vtt"
create_subtitle_file(timestamps_with_text, srt_path, "srt")
create_subtitle_file(timestamps_with_text, vtt_path, "vtt")
# Create preview HTML
    preview_path = create_preview_html(audio_path, vtt_path)
    return " ".join(transcriptions), srt_path, vtt_path, preview_path
# Gradio Interface
demo = gr.Interface(
fn=transcribe_with_vad,
inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload or Record"),
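    # type="filepath" hands the function a path to the recorded/uploaded audio on disk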
outputs=[
gr.Textbox(label="Transcription"),
gr.File(label="SRT Subtitle File"),
gr.File(label="WebVTT Subtitle File"),
gr.HTML(label="Preview Player")
],
title="Smart Speech-to-Text with VAD and Subtitles",
description="Transcribe long audio using ganga4364/Garchen_Rinpoche-wav2vec2-Checkpoint-19000 and Silero VAD. Generates SRT and WebVTT subtitle files."
)
if __name__ == "__main__":
demo.launch(share=True)