ganga4364 commited on
Commit
f78a75f
·
verified ·
1 Parent(s): 0bc59b4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import numpy as np
5
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
6
+ from datetime import timedelta
7
+ import os
8
+ import shutil
9
+ from pathlib import Path
10
+
11
+ # Load Silero VAD
12
+ vad_model, utils = torch.hub.load(
13
+ repo_or_dir='snakers4/silero-vad', model='silero_vad', trust_repo=True
14
+ )
15
+ (get_speech_ts, _, _, _, _) = utils
16
+
17
+ # Load Wav2Vec2 model
18
+ model_name = "ganga4364/Garchen_Rinpoche-wav2vec2-Checkpoint-19000"
19
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
20
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
21
+ model.eval()
22
+
23
+ SAMPLE_RATE = 16000
24
+
25
+ def format_timestamp(seconds, format_type="srt"):
26
+ """Convert seconds to SRT or WebVTT timestamp format"""
27
+ td = timedelta(seconds=seconds)
28
+ hours = td.seconds // 3600
29
+ minutes = (td.seconds % 3600) // 60
30
+ seconds = td.seconds % 60
31
+ milliseconds = round(td.microseconds / 1000)
32
+
33
+ if format_type == "srt":
34
+ return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"
35
+ else: # webvtt
36
+ return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}"
37
+
38
+ def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
39
+ """Create SRT or WebVTT subtitle file"""
40
+ with open(output_path, 'w', encoding='utf-8') as f:
41
+ if format_type == "vtt":
42
+ f.write("WEBVTT\n\n")
43
+
44
+ for i, (start_time, end_time, text) in enumerate(timestamps_with_text, 1):
45
+ if format_type == "srt":
46
+ f.write(f"{i}\n")
47
+ f.write(f"{format_timestamp(start_time/SAMPLE_RATE)} --> {format_timestamp(end_time/SAMPLE_RATE)}\n")
48
+ f.write(f"{text}\n\n")
49
+ else:
50
+ f.write(f"{format_timestamp(start_time/SAMPLE_RATE, 'vtt')} --> {format_timestamp(end_time/SAMPLE_RATE, 'vtt')}\n")
51
+ f.write(f"{text}\n\n")
52
+
53
+ def create_preview_html(audio_path, vtt_path):
54
+ """Create an HTML preview file for audio with subtitles"""
55
+ static_dir = Path("static")
56
+ static_dir.mkdir(exist_ok=True)
57
+
58
+ # Copy files to static directory with friendly names
59
+ audio_filename = Path(audio_path).name
60
+ vtt_filename = Path(vtt_path).name
61
+ new_audio_path = static_dir / audio_filename
62
+ new_vtt_path = static_dir / vtt_filename
63
+
64
+ shutil.copy2(audio_path, new_audio_path)
65
+ shutil.copy2(vtt_path, new_vtt_path)
66
+
67
+ # Read HTML template
68
+ with open("templates/player.html", "r") as f:
69
+ html_content = f.read()
70
+
71
+ # Replace placeholders
72
+ html_content = html_content.replace("{{ audio_path }}", f"static/{audio_filename}")
73
+ html_content = html_content.replace("{{ vtt_path }}", f"static/{vtt_filename}")
74
+
75
+ # Save preview HTML
76
+ preview_path = static_dir / "preview.html"
77
+ with open(preview_path, "w") as f:
78
+ f.write(html_content)
79
+
80
+ return str(preview_path)
81
+
82
+ def transcribe_with_vad(audio_path):
83
+ # Load and resample audio to 16kHz mono
84
+ wav, sr = torchaudio.load(audio_path)
85
+ if sr != SAMPLE_RATE:
86
+ wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
87
+ wav = wav.mean(dim=0) # convert to mono
88
+ wav_np = wav.numpy()
89
+
90
+ # Get speech timestamps using Silero VAD
91
+ speech_timestamps = get_speech_ts(wav_np, vad_model, sampling_rate=SAMPLE_RATE)
92
+ if not speech_timestamps:
93
+ return "No speech detected.", None, None, None
94
+
95
+ timestamps_with_text = []
96
+ transcriptions = []
97
+
98
+ for ts in speech_timestamps:
99
+ start, end = ts['start'], ts['end']
100
+ segment = wav[start:end]
101
+ if segment.dim() > 1:
102
+ segment = segment.squeeze()
103
+
104
+ inputs = processor(segment, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
105
+ with torch.no_grad():
106
+ logits = model(**inputs).logits
107
+ predicted_ids = torch.argmax(logits, dim=-1)
108
+ transcription = processor.decode(predicted_ids[0])
109
+ transcriptions.append(transcription)
110
+ timestamps_with_text.append((start, end, transcription))
111
+
112
+ # Generate subtitle files
113
+ base_path = os.path.splitext(audio_path)[0]
114
+ srt_path = f"{base_path}.srt"
115
+ vtt_path = f"{base_path}.vtt"
116
+
117
+ create_subtitle_file(timestamps_with_text, srt_path, "srt")
118
+ create_subtitle_file(timestamps_with_text, vtt_path, "vtt")
119
+
120
+ # Create preview HTML
121
+ preview_html = create_preview_html(audio_path, vtt_path)
122
+
123
+ return " ".join(transcriptions), srt_path, vtt_path, preview_html
124
+
125
+ # Gradio Interface
126
+ demo = gr.Interface(
127
+ fn=transcribe_with_vad,
128
+ inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Upload or Record"),
129
+ outputs=[
130
+ gr.Textbox(label="Transcription"),
131
+ gr.File(label="SRT Subtitle File"),
132
+ gr.File(label="WebVTT Subtitle File"),
133
+ gr.HTML(label="Preview Player")
134
+ ],
135
+ title="Smart Speech-to-Text with VAD and Subtitles",
136
+ description="Transcribe long audio using ganga4364/Garchen_Rinpoche-wav2vec2-Checkpoint-19000 and Silero VAD. Generates SRT and WebVTT subtitle files."
137
+ )
138
+
139
+ if __name__ == "__main__":
140
+ demo.launch(share=True)