ganga4364 committed
Commit 2a3adfa · verified · 1 Parent(s): 482d6e9

Update app.py

Files changed (1):
1. app.py (+43 -36)
app.py CHANGED
@@ -13,6 +13,8 @@ import logging
 
 # Constants and Configuration
 SAMPLE_RATE = 16000
+CHUNK_SECONDS = 30  # Split audio into 30-second chunks
+CHUNK_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS
 MODEL_NAME = "openpecha/general_stt_base_model"
 
 title = "# Tibetan Speech-to-Text with Subtitles"
@@ -20,7 +22,7 @@ title = "# Tibetan Speech-to-Text with Subtitles"
 description = """
 This application transcribes Tibetan audio files and generates subtitles using:
 - Wav2Vec2 model fine-tuned on Garchen Rinpoche's teachings
-- Silero VAD for voice activity detection
+- 30-second fixed chunking for long audio processing
 - Generates both SRT and WebVTT subtitle formats
 """
 
@@ -33,23 +35,17 @@ css = """
 .player-container audio {width: 100%;}
 """
 
-# Initialize models
-def init_models():
-    # Load Silero VAD
-    vad_model, utils = torch.hub.load(
-        repo_or_dir='snakers4/silero-vad', model='silero_vad', trust_repo=True
-    )
-    get_speech_ts = utils[0]
-
+# Initialize model
+def init_model():
     # Load Wav2Vec2 model
     model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
     processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
     model.eval()
 
-    return vad_model, get_speech_ts, model, processor
+    return model, processor
 
-# Initialize models globally
-vad_model, get_speech_ts, model, processor = init_models()
+# Initialize model globally
+model, processor = init_model()
 
 def format_timestamp(seconds, format_type="srt"):
     """Convert seconds to SRT or WebVTT timestamp format"""
@@ -73,10 +69,10 @@ def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"):
         for i, (start_time, end_time, text) in enumerate(timestamps_with_text, 1):
             if format_type == "srt":
                 f.write(f"{i}\n")
-                f.write(f"{format_timestamp(start_time/SAMPLE_RATE)} --> {format_timestamp(end_time/SAMPLE_RATE)}\n")
+                f.write(f"{format_timestamp(start_time)} --> {format_timestamp(end_time)}\n")
                 f.write(f"{text}\n\n")
             else:
-                f.write(f"{format_timestamp(start_time/SAMPLE_RATE, 'vtt')} --> {format_timestamp(end_time/SAMPLE_RATE, 'vtt')}\n")
+                f.write(f"{format_timestamp(start_time, 'vtt')} --> {format_timestamp(end_time, 'vtt')}\n")
                 f.write(f"{text}\n\n")
 
 def build_html_output(s: str, style: str = "result_item_success"):
@@ -127,35 +123,46 @@ def process_audio(audio_path: str):
     if sr != SAMPLE_RATE:
         wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
     wav = wav.mean(dim=0)  # convert to mono
-    wav_np = wav.numpy()
-
-    # Get speech timestamps using Silero VAD
-    speech_timestamps = get_speech_ts(wav_np, vad_model, sampling_rate=SAMPLE_RATE)
-    if not speech_timestamps:
-        return (
-            build_html_output("No speech detected", "result_item_error"),
-            None,
-            None,
-            "",
-            "",
-        )
 
+    # Split audio into 30-second chunks
+    audio_length = wav.shape[0]
     timestamps_with_text = []
     transcriptions = []
 
-    for ts in speech_timestamps:
-        start, end = ts['start'], ts['end']
-        segment = wav[start:end]
-        if segment.dim() > 1:
-            segment = segment.squeeze()
-
-        inputs = processor(segment, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
+    for start_sample in range(0, audio_length, CHUNK_SAMPLES):
+        end_sample = min(start_sample + CHUNK_SAMPLES, audio_length)
+
+        # Convert sample positions to seconds
+        start_time = start_sample / SAMPLE_RATE
+        end_time = end_sample / SAMPLE_RATE
+
+        # Extract chunk
+        chunk = wav[start_sample:end_sample]
+
+        # Skip processing if chunk is too short (less than 0.5 seconds)
+        if chunk.shape[0] < 0.5 * SAMPLE_RATE:
+            continue
+
+        # Process chunk through model
+        inputs = processor(chunk, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
         with torch.no_grad():
             logits = model(**inputs).logits
         predicted_ids = torch.argmax(logits, dim=-1)
         transcription = processor.decode(predicted_ids[0])
-        transcriptions.append(transcription)
-        timestamps_with_text.append((start, end, transcription))
+
+        # Skip empty transcriptions
+        if transcription.strip():
+            transcriptions.append(transcription)
+            timestamps_with_text.append((start_time, end_time, transcription))
+
+    if not timestamps_with_text:
+        return (
+            build_html_output("No speech detected or recognized", "result_item_error"),
+            None,
+            None,
+            "",
+            "",
+        )
 
     # Generate subtitle files
     base_path = os.path.splitext(audio_path)[0]
@@ -238,4 +245,4 @@ with demo:
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
-    demo.launch(share=True)
+    demo.launch(share=True)
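
Note on the timestamp change in create_subtitle_file: the new chunking loop stores start_time and end_time already in seconds (sample index divided by SAMPLE_RATE), whereas the removed VAD path stored raw sample indices, which is why the /SAMPLE_RATE division was dropped at the write sites. As a rough illustration of how those second values become subtitle timestamps, here is a minimal sketch; format_timestamp's body is not part of this diff, so the helper below is an assumption for illustration only, not the app's actual implementation:

def format_timestamp(seconds, format_type="srt"):
    # Illustrative only: convert seconds to HH:MM:SS,mmm (SRT) or HH:MM:SS.mmm (WebVTT).
    total_ms = int(round(seconds * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, millis = divmod(rem, 1000)
    sep = "," if format_type == "srt" else "."
    return f"{hours:02d}:{minutes:02d}:{secs:02d}{sep}{millis:03d}"

# With CHUNK_SECONDS = 30, the first two SRT cues of a long file would look like:
#   1
#   00:00:00,000 --> 00:00:30,000
#   <transcription of chunk 1>
#
#   2
#   00:00:30,000 --> 00:01:00,000
#   <transcription of chunk 2>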