import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import os
import shutil
from pathlib import Path

# Load Silero VAD (voice activity detection) from torch.hub
vad_model, utils = torch.hub.load(
    repo_or_dir='snakers4/silero-vad', model='silero_vad', trust_repo=True
)
# utils = (get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks);
# only the timestamp helper is used here
(get_speech_ts, _, _, _, _) = utils
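
# The VAD can be tuned via keyword arguments on get_speech_ts; the names below
# follow the silero-vad README (an assumption about the pinned version), e.g.:
#   get_speech_ts(wav_np, vad_model, sampling_rate=SAMPLE_RATE,
#                 threshold=0.5, min_silence_duration_ms=100)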

# Load Wav2Vec2 model
model_name = "ganga4364/Garchen_Rinpoche-wav2vec2-Checkpoint-19000"
model = Wav2Vec2ForCTC.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name)
model.eval()
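
# Inference here runs on CPU. A minimal GPU sketch (assumes a CUDA build of
# torch; the inputs would also need .to(device) before each forward pass):
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model.to(device)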

SAMPLE_RATE = 16000

def format_timestamp(seconds, format_type="srt"):
    """Convert seconds to an SRT or WebVTT timestamp string."""
    # Work in integer milliseconds so rounding can never yield ",1000", and
    # avoid timedelta.seconds, which silently wraps past 24 hours.
    total_ms = round(seconds * 1000)
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, milliseconds = divmod(rem, 1_000)

    if format_type == "srt":
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"
    # WebVTT separates milliseconds with a dot instead of a comma
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"

def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
    """Write an SRT or WebVTT file from (start, end, text) triples in samples."""
    with open(output_path, 'w', encoding='utf-8') as f:
        if format_type == "vtt":
            f.write("WEBVTT\n\n")

        for i, (start_time, end_time, text) in enumerate(timestamps_with_text, 1):
            if format_type == "srt":
                f.write(f"{i}\n")  # SRT cues are numbered; WebVTT cue ids are optional
            start = format_timestamp(start_time / SAMPLE_RATE, format_type)
            end = format_timestamp(end_time / SAMPLE_RATE, format_type)
            f.write(f"{start} --> {end}\n{text}\n\n")

def create_preview_html(audio_path, vtt_path):
    """Create an HTML preview file for audio with subtitles"""
    static_dir = Path("static")
    static_dir.mkdir(exist_ok=True)
    
    # Copy files to static directory with friendly names
    audio_filename = Path(audio_path).name
    vtt_filename = Path(vtt_path).name
    new_audio_path = static_dir / audio_filename
    new_vtt_path = static_dir / vtt_filename
    
    shutil.copy2(audio_path, new_audio_path)
    shutil.copy2(vtt_path, new_vtt_path)
    
    # Read HTML template
    with open("player.html", "r") as f:
        html_content = f.read()
    
    # Replace placeholders
    html_content = html_content.replace("{{ audio_path }}", f"static/{audio_filename}")
    html_content = html_content.replace("{{ vtt_path }}", f"static/{vtt_filename}")
    
    # Save preview HTML
    preview_path = static_dir / "preview.html"
    with open(preview_path, "w", encoding="utf-8") as f:
        f.write(html_content)
    
    return str(preview_path)
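
# player.html must ship alongside this script and contain the {{ audio_path }}
# and {{ vtt_path }} placeholders. A minimal template sketch (an assumption;
# the real template lives in the repo):
#   <audio controls>
#     <source src="{{ audio_path }}">
#     <track kind="subtitles" src="{{ vtt_path }}" default>
#   </audio>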

def transcribe_with_vad(audio_path):
    # Load and resample audio to 16kHz mono
    wav, sr = torchaudio.load(audio_path)
    if sr != SAMPLE_RATE:
        wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
    wav = wav.mean(dim=0)  # convert to mono
    wav_np = wav.numpy()

    # Get speech timestamps from Silero VAD; each entry is a dict of
    # {'start': ..., 'end': ...} measured in samples, not seconds
    speech_timestamps = get_speech_ts(wav_np, vad_model, sampling_rate=SAMPLE_RATE)
    if not speech_timestamps:
        return "No speech detected.", None, None, None

    timestamps_with_text = []
    transcriptions = []
    
    for ts in speech_timestamps:
        start, end = ts['start'], ts['end']
        segment = wav[start:end]  # 1-D slice; wav was already collapsed to mono

        # The feature extractor expects raw float audio, so hand it a NumPy array
        inputs = processor(segment.numpy(), sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        # Greedy CTC decoding: most likely token per frame, then collapse repeats/blanks
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.decode(predicted_ids[0])
        transcriptions.append(transcription)
        timestamps_with_text.append((start, end, transcription))

    # Generate subtitle files
    base_path = os.path.splitext(audio_path)[0]
    srt_path = f"{base_path}.srt"
    vtt_path = f"{base_path}.vtt"
    
    create_subtitle_file(timestamps_with_text, srt_path, "srt")
    create_subtitle_file(timestamps_with_text, vtt_path, "vtt")
    
    # Create preview HTML
    preview_html = create_preview_html(audio_path, vtt_path)

    return " ".join(transcriptions), srt_path, vtt_path, preview_html

# Gradio Interface
demo = gr.Interface(
    fn=transcribe_with_vad,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload or Record"),
    outputs=[
        gr.Textbox(label="Transcription"),
        gr.File(label="SRT Subtitle File"),
        gr.File(label="WebVTT Subtitle File"),
        gr.File(label="Preview Player (HTML)")  # the function returns a file path, so gr.File rather than gr.HTML
    ],
    title="Smart Speech-to-Text with VAD and Subtitles",
    description="Transcribe long audio using ganga4364/Garchen_Rinpoche-wav2vec2-Checkpoint-19000 and Silero VAD. Generates SRT and WebVTT subtitle files."
)

if __name__ == "__main__":
    demo.launch(share=True)