Create chunkedTranscriber.py
chunkedTranscriber.py +401 -0
chunkedTranscriber.py
ADDED
@@ -0,0 +1,401 @@
import os
import gc
import sys
import time
import logging

import torch
import torchaudio
import numpy as np
from scipy.signal import resample
from difflib import SequenceMatcher

from pyannote.audio import Pipeline
from dotenv import load_dotenv
from transformers import (
    Wav2Vec2ForSequenceClassification,
    AutoFeatureExtractor,
    Wav2Vec2ForCTC,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

load_dotenv()
# Hugging Face access token for the gated pyannote pipeline.
# The environment variable name HF_TOKEN is an assumption; adjust it to
# whatever key your .env file actually uses.
hf_token = os.getenv("HF_TOKEN")

class ChunkedTranscriber:
    def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        self.previous_text = ""
        self.previous_lang = None
        self.speaker_diarization_pipeline = self.load_speaker_diarization_pipeline()

    def load_speaker_diarization_pipeline(self):
        """
        Load the pre-trained speaker diarization pipeline from pyannote-audio.
        """
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        return pipeline

    def diarize_audio(self, audio_path):
        """
        Perform speaker diarization on the input audio.
        """
        diarization_result = self.speaker_diarization_pipeline({"uri": "audio", "audio": audio_path})
        return diarization_result

    def load_lid_mms(self):
        model_id = "facebook/mms-lid-256"
        processor = AutoFeatureExtractor.from_pretrained(model_id)
        model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
        return processor, model

    def language_identification(self, model, processor, chunk, device="cuda"):
        inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
        model.to(device)
        inputs.to(device)
        with torch.no_grad():
            outputs = model(**inputs).logits

        lang_id = torch.argmax(outputs, dim=-1)[0].item()
        detected_lang = model.config.id2label[lang_id]
        del model
        del inputs
        torch.cuda.empty_cache()
        gc.collect()
        return detected_lang

    def load_mms(self):
        model_id = "facebook/mms-1b-all"
        processor = AutoProcessor.from_pretrained(model_id)
        model = Wav2Vec2ForCTC.from_pretrained(model_id)
        return model, processor

    def mms_transcription(self, model, processor, chunk, device="cuda"):
        inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
        model.to(device)
        inputs.to(device)
        with torch.no_grad():
            outputs = model(**inputs).logits

        ids = torch.argmax(outputs, dim=-1)[0]
        transcription = processor.decode(ids)
        del model
        del inputs
        torch.cuda.empty_cache()
        gc.collect()
        return transcription

    def load_T2T_translation_model(self):
        model_id = "facebook/nllb-200-distilled-600M"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        return model, tokenizer

    def text2text_translation(self, translation_model, translation_tokenizer, transcript, device="cuda"):
        # model, tokenizer = load_translation_model()
        tokenized_inputs = translation_tokenizer(transcript, return_tensors='pt')
        translation_model.to(device)
        tokenized_inputs.to(device)
        translated_tokens = translation_model.generate(
            **tokenized_inputs,
            forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
            max_length=100)
        del translation_model
        del tokenized_inputs
        torch.cuda.empty_cache()
        gc.collect()
        return translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    def preprocess_audio(self, audio):
        """
        Create overlapping chunks with improved timing logic
        """
        chunk_samples = int(self.chunk_size * self.sample_rate)
        overlap_samples = int(self.overlap * self.sample_rate)

        chunks_with_times = []
        start_idx = 0

        while start_idx < len(audio):
            end_idx = min(start_idx + chunk_samples, len(audio))

            # Add padding for first chunk
            if start_idx == 0:
                chunk = audio[start_idx:end_idx]
                padding = torch.zeros(int(1 * self.sample_rate))
                chunk = torch.cat([padding, chunk])
            else:
                # Include overlap from previous chunk
                actual_start = max(0, start_idx - overlap_samples)
                chunk = audio[actual_start:end_idx]

            # Pad if necessary
            if len(chunk) < chunk_samples:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            # Adjust time ranges to account for overlaps
            chunk_start_time = max(0, (start_idx / self.sample_rate) - self.overlap)
            chunk_end_time = min((end_idx / self.sample_rate) + self.overlap, len(audio) / self.sample_rate)

            chunks_with_times.append({
                'chunk': chunk,
                'start_time': start_idx / self.sample_rate,
                'end_time': end_idx / self.sample_rate,
                'transcribe_start': chunk_start_time,
                'transcribe_end': chunk_end_time
            })

            # Move to next chunk with smaller step size for better continuity
            start_idx += (chunk_samples - overlap_samples)

        return chunks_with_times

    def merge_close_segments(self, results):
        """
        Merge segments that are close in time and have the same language
        """
        if not results:
            return results

        merged = []
        current = results[0]

        for next_segment in results[1:]:
            # Skip empty segments
            if not next_segment['text'].strip():
                continue

            # If segments are in the same language and close in time
            if (current['detected_language'] == next_segment['detected_language'] and
                    abs(next_segment['start_time'] - current['end_time']) <= self.overlap):

                # Merge the segments
                current['text'] = current['text'] + ' ' + next_segment['text']
                current['end_time'] = next_segment['end_time']
                if 'translated' in current and 'translated' in next_segment:
                    current['translated'] = current['translated'] + ' ' + next_segment['translated']
            else:
                if current['text'].strip():  # Only add non-empty segments
                    merged.append(current)
                current = next_segment

        if current['text'].strip():  # Add the last segment if non-empty
            merged.append(current)

        return merged

    def clean_overlapping_text(self, current_text, prev_text, current_lang, prev_lang, min_overlap=3):
        """
        Improved text cleaning with language awareness and better sentence boundary handling
        """
        if not prev_text or not current_text:
            return current_text

        # If languages are different, don't try to merge
        if prev_lang and current_lang and prev_lang != current_lang:
            return current_text

        # Split into words
        prev_words = prev_text.split()
        curr_words = current_text.split()

        if len(prev_words) < 2 or len(curr_words) < 2:
            return current_text

        # Find matching sequences at the end of prev_text and start of current_text
        matcher = SequenceMatcher(None, prev_words, curr_words)
        matches = list(matcher.get_matching_blocks())

        # Look for significant overlaps
        best_overlap = 0
        overlap_size = 0

        for match in matches:
            # Check if the match is at the start of current text
            if match.b == 0 and match.size >= min_overlap:
                if match.size > overlap_size:
                    best_overlap = match.size
                    overlap_size = match.size

        if best_overlap > 0:
            # Remove overlapping content while preserving sentence integrity
            cleaned_words = curr_words[best_overlap:]
            if not cleaned_words:  # If everything was overlapping
                return ""
            return ' '.join(cleaned_words).strip()

        return current_text

    def process_chunk(self, chunk_data, mms_model, mms_processor, translation_model=None, translation_tokenizer=None):
        """
        Process chunk with improved language handling
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        try:
            print(f"\n\n Chunk shape: {chunk_data['chunk'].shape}")
            # Language detection
            lid_processor, lid_model = self.load_lid_mms()
            lid_lang = self.language_identification(lid_model, lid_processor, chunk_data['chunk'])

            # Configure processor
            mms_processor.tokenizer.set_target_lang(lid_lang)
            mms_model.load_adapter(lid_lang)

            # Transcribe
            inputs = mms_processor(chunk_data['chunk'], sampling_rate=self.sample_rate, return_tensors="pt")
            inputs = inputs.to(device)
            mms_model = mms_model.to(device)

            with torch.no_grad():
                outputs = mms_model(**inputs).logits

            ids = torch.argmax(outputs, dim=-1)[0]
            transcription = mms_processor.decode(ids)

            # Clean overlapping text with language awareness
            cleaned_transcription = self.clean_overlapping_text(
                transcription,
                self.previous_text,
                lid_lang,
                self.previous_lang,
                min_overlap=3
            )

            # Update previous state
            self.previous_text = transcription
            self.previous_lang = lid_lang

            if not cleaned_transcription.strip():
                return None

            result = {
                'start_time': chunk_data['start_time'],
                'end_time': chunk_data['end_time'],
                'text': cleaned_transcription,
                'detected_language': lid_lang
            }

            # Handle translation
            if translation_model and translation_tokenizer and cleaned_transcription.strip():
                translation = self.text2text_translation(
                    translation_model,
                    translation_tokenizer,
                    cleaned_transcription
                )
                result['translated'] = translation

            return result

        except Exception as e:
            print(f"Error processing chunk: {str(e)}")
            return None
        finally:
            torch.cuda.empty_cache()
            gc.collect()

    def translate_text(self, text, translation_model, translation_tokenizer, device):
        """
        Translate cleaned text using the provided translation model.
        """
        tokenized_inputs = translation_tokenizer(text, return_tensors='pt')
        tokenized_inputs = tokenized_inputs.to(device)
        translation_model = translation_model.to(device)

        translated_tokens = translation_model.generate(
            **tokenized_inputs,
            forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
            max_length=100
        )

        translation = translation_tokenizer.batch_decode(
            translated_tokens,
            skip_special_tokens=True
        )[0]

        del translation_model
        del tokenized_inputs
        torch.cuda.empty_cache()
        gc.collect()
        return translation

    def transcribe_audio(self, audio_path, translate=False):
        """
        Main transcription function with improved segment merging
        """
        # Perform speaker diarization
        diarization_result = self.diarize_audio(audio_path)

        # Extract speaker segments
        speaker_segments = []

        for turn, _, speaker in diarization_result.itertracks(yield_label=True):
            speaker_segments.append({
                'start_time': turn.start,
                'end_time': turn.end,
                'speaker': speaker
            })
        # print(f"\n\n Speaker Segments:\n{speaker_segments}\n")

        audio = self.load_audio(audio_path)
        chunks = self.preprocess_audio(audio)

        mms_model, mms_processor = self.load_mms()
        translation_model, translation_tokenizer = None, None
        if translate:
            translation_model, translation_tokenizer = self.load_T2T_translation_model()

        # Process chunks
        results = []
        for chunk_data in chunks:
            result = self.process_chunk(
                chunk_data,
                mms_model,
                mms_processor,
                translation_model,
                translation_tokenizer
            )
            print(f"\n\nResult:\n{result}")
            if result:
                for segment in speaker_segments:
                    if int(segment['start_time']) <= int(chunk_data['start_time']) < int(segment['end_time']):
                        result['speaker'] = segment['speaker']
                        break
                results.append(result)
            # results.append(result)

        # Merge close segments and clean up
        merged_results = self.merge_close_segments(results)

        return merged_results

    def load_audio(self, audio_path):
        """
        Load and preprocess audio file.
        """
        waveform, sample_rate = torchaudio.load(audio_path)

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0)
        else:
            waveform = waveform.squeeze(0)

        # Resample if necessary
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        return waveform.float()
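
For reference, a minimal usage sketch of the class above. It is not part of this commit: the audio filename is a placeholder, and it assumes the pyannote token is available via .env (see the hf_token note at the top of the file).

if __name__ == "__main__":
    # Hypothetical driver: transcribe a local file and print each merged segment.
    transcriber = ChunkedTranscriber(chunk_size=5, overlap=1, sample_rate=16000)
    segments = transcriber.transcribe_audio("sample.wav", translate=True)
    for seg in segments:
        speaker = seg.get('speaker', 'UNKNOWN')
        print(f"[{seg['start_time']:.1f}-{seg['end_time']:.1f}] {speaker} "
              f"({seg['detected_language']}): {seg['text']}")
        if 'translated' in seg:
            print(f"    EN: {seg['translated']}")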