Update app.py
app.py CHANGED
@@ -13,6 +13,8 @@ import logging
 
 # Constants and Configuration
 SAMPLE_RATE = 16000
+CHUNK_SECONDS = 30  # Split audio into 30-second chunks
+CHUNK_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS
 MODEL_NAME = "openpecha/general_stt_base_model"
 
 title = "# Tibetan Speech-to-Text with Subtitles"
@@ -20,7 +22,7 @@ title = "# Tibetan Speech-to-Text with Subtitles"
 description = """
 This application transcribes Tibetan audio files and generates subtitles using:
 - Wav2Vec2 model fine-tuned on Garchen Rinpoche's teachings
-
+- 30-second fixed chunking for long audio processing
 - Generates both SRT and WebVTT subtitle formats
 """
 
@@ -33,23 +35,17 @@ css = """
 .player-container audio {width: 100%;}
 """
 
-# Initialize
-def
-    # Load Silero VAD
-    vad_model, utils = torch.hub.load(
-        repo_or_dir='snakers4/silero-vad', model='silero_vad', trust_repo=True
-    )
-    get_speech_ts = utils[0]
-
+# Initialize model
+def init_model():
     # Load Wav2Vec2 model
     model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
     processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
     model.eval()
 
-    return
+    return model, processor
 
-# Initialize
-
+# Initialize model globally
+model, processor = init_model()
 
 def format_timestamp(seconds, format_type="srt"):
     """Convert seconds to SRT or WebVTT timestamp format"""
@@ -73,10 +69,10 @@ def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
         for i, (start_time, end_time, text) in enumerate(timestamps_with_text, 1):
             if format_type == "srt":
                 f.write(f"{i}\n")
-                f.write(f"{format_timestamp(start_time
+                f.write(f"{format_timestamp(start_time)} --> {format_timestamp(end_time)}\n")
                 f.write(f"{text}\n\n")
             else:
-                f.write(f"{format_timestamp(start_time
+                f.write(f"{format_timestamp(start_time, 'vtt')} --> {format_timestamp(end_time, 'vtt')}\n")
                 f.write(f"{text}\n\n")
 
 def build_html_output(s: str, style: str = "result_item_success"):
@@ -127,35 +123,46 @@ def process_audio(audio_path: str):
     if sr != SAMPLE_RATE:
         wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
     wav = wav.mean(dim=0)  # convert to mono
-    wav_np = wav.numpy()
-
-    # Get speech timestamps using Silero VAD
-    speech_timestamps = get_speech_ts(wav_np, vad_model, sampling_rate=SAMPLE_RATE)
-    if not speech_timestamps:
-        return (
-            build_html_output("No speech detected", "result_item_error"),
-            None,
-            None,
-            "",
-            "",
-        )
 
+    # Split audio into 30-second chunks
+    audio_length = wav.shape[0]
     timestamps_with_text = []
     transcriptions = []
 
-    for
-
-
-
-
-
-
+    for start_sample in range(0, audio_length, CHUNK_SAMPLES):
+        end_sample = min(start_sample + CHUNK_SAMPLES, audio_length)
+
+        # Convert sample positions to seconds
+        start_time = start_sample / SAMPLE_RATE
+        end_time = end_sample / SAMPLE_RATE
+
+        # Extract chunk
+        chunk = wav[start_sample:end_sample]
+
+        # Skip processing if chunk is too short (less than 0.5 seconds)
+        if chunk.shape[0] < 0.5 * SAMPLE_RATE:
+            continue
+
+        # Process chunk through model
+        inputs = processor(chunk, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
         with torch.no_grad():
             logits = model(**inputs).logits
         predicted_ids = torch.argmax(logits, dim=-1)
         transcription = processor.decode(predicted_ids[0])
-
-
+
+        # Skip empty transcriptions
+        if transcription.strip():
+            transcriptions.append(transcription)
+            timestamps_with_text.append((start_time, end_time, transcription))
+
+    if not timestamps_with_text:
+        return (
+            build_html_output("No speech detected or recognized", "result_item_error"),
+            None,
+            None,
+            "",
+            "",
+        )
 
     # Generate subtitle files
     base_path = os.path.splitext(audio_path)[0]
@@ -238,4 +245,4 @@ with demo:
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
-    demo.launch(share=True)
+    demo.launch(share=True)
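
For context, the fixed-chunking flow introduced in process_audio can be exercised end to end with a minimal standalone sketch like the one below. It is illustrative only: the audio path is a placeholder, and it assumes torch, torchaudio and transformers are installed.

import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

SAMPLE_RATE = 16000
CHUNK_SAMPLES = SAMPLE_RATE * 30  # fixed 30-second chunks, as in this commit

model = Wav2Vec2ForCTC.from_pretrained("openpecha/general_stt_base_model")
processor = Wav2Vec2Processor.from_pretrained("openpecha/general_stt_base_model")
model.eval()

wav, sr = torchaudio.load("example.wav")  # placeholder path
if sr != SAMPLE_RATE:
    wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
wav = wav.mean(dim=0)  # mono, shape (num_samples,)

segments = []  # (start_sec, end_sec, text)
for start in range(0, wav.shape[0], CHUNK_SAMPLES):
    end = min(start + CHUNK_SAMPLES, wav.shape[0])
    chunk = wav[start:end]
    if chunk.shape[0] < 0.5 * SAMPLE_RATE:  # skip tails shorter than 0.5 s
        continue
    inputs = processor(chunk.numpy(), sampling_rate=SAMPLE_RATE,
                       return_tensors="pt", padding=True)
    with torch.no_grad():
        predicted_ids = torch.argmax(model(**inputs).logits, dim=-1)
    text = processor.decode(predicted_ids[0])
    if text.strip():
        segments.append((start / SAMPLE_RATE, end / SAMPLE_RATE, text))

Each tuple in segments maps directly onto one SRT or WebVTT cue via create_subtitle_file.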
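The subtitle writer calls format_timestamp, whose body is unchanged by this commit and not shown here. A plausible implementation consistent with the " --> " cue lines written above would look like the sketch below; this is an assumption, not the file's actual code. SRT separates milliseconds with a comma, WebVTT with a dot.

def format_timestamp(seconds, format_type="srt"):
    """Convert seconds to an SRT (HH:MM:SS,mmm) or WebVTT (HH:MM:SS.mmm) timestamp."""
    ms = int(round(seconds * 1000))
    hours, ms = divmod(ms, 3_600_000)
    minutes, ms = divmod(ms, 60_000)
    secs, ms = divmod(ms, 1_000)
    sep = "," if format_type == "srt" else "."
    return f"{hours:02d}:{minutes:02d}:{secs:02d}{sep}{ms:03d}"

# format_timestamp(30)        -> "00:00:30,000"
# format_timestamp(30, "vtt") -> "00:00:30.000"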