Kr08 committed on
Commit 5b5fc60 · verified · Parent: e1bae5b

Create chunkedTranscriber.py

Files changed (1)
  1. chunkedTranscriber.py +401 -0
chunkedTranscriber.py ADDED
@@ -0,0 +1,401 @@
import os
import gc
import sys
import time
import logging
import torch
import torchaudio
import numpy as np
from scipy.signal import resample
from difflib import SequenceMatcher
from pyannote.audio import Pipeline
from dotenv import load_dotenv
from transformers import (
    Wav2Vec2ForSequenceClassification,
    AutoFeatureExtractor,
    Wav2Vec2ForCTC,
    AutoProcessor,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

load_dotenv()
# Auth token for the gated pyannote pipeline; assumed to be supplied via .env as HF_TOKEN.
hf_token = os.getenv("HF_TOKEN")

class ChunkedTranscriber:
    def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        self.previous_text = ""
        self.previous_lang = None
        self.speaker_diarization_pipeline = self.load_speaker_diarization_pipeline()

    def load_speaker_diarization_pipeline(self):
        """
        Load the pre-trained speaker diarization pipeline from pyannote-audio.
        """
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        return pipeline

    def diarize_audio(self, audio_path):
        """
        Perform speaker diarization on the input audio.
        """
        diarization_result = self.speaker_diarization_pipeline({"uri": "audio", "audio": audio_path})
        return diarization_result

    def load_lid_mms(self):
        model_id = "facebook/mms-lid-256"
        processor = AutoFeatureExtractor.from_pretrained(model_id)
        model = Wav2Vec2ForSequenceClassification.from_pretrained(model_id)
        return processor, model

    def language_identification(self, model, processor, chunk, device="cuda"):
        inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
        model = model.to(device)
        inputs = inputs.to(device)
        with torch.no_grad():
            outputs = model(**inputs).logits

        lang_id = torch.argmax(outputs, dim=-1)[0].item()
        detected_lang = model.config.id2label[lang_id]
        del model
        del inputs
        torch.cuda.empty_cache()
        gc.collect()
        return detected_lang

    def load_mms(self):
        model_id = "facebook/mms-1b-all"
        processor = AutoProcessor.from_pretrained(model_id)
        model = Wav2Vec2ForCTC.from_pretrained(model_id)
        return model, processor

    def mms_transcription(self, model, processor, chunk, device="cuda"):
        inputs = processor(chunk, sampling_rate=16_000, return_tensors="pt")
        model = model.to(device)
        inputs = inputs.to(device)
        with torch.no_grad():
            outputs = model(**inputs).logits

        ids = torch.argmax(outputs, dim=-1)[0]
        transcription = processor.decode(ids)
        del model
        del inputs
        torch.cuda.empty_cache()
        gc.collect()
        return transcription

    def load_T2T_translation_model(self):
        model_id = "facebook/nllb-200-distilled-600M"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        return model, tokenizer

    def text2text_translation(self, translation_model, translation_tokenizer, transcript, device="cuda"):
        tokenized_inputs = translation_tokenizer(transcript, return_tensors='pt')
        translation_model = translation_model.to(device)
        tokenized_inputs = tokenized_inputs.to(device)
        translated_tokens = translation_model.generate(
            **tokenized_inputs,
            forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
            max_length=100
        )
        del translation_model
        del tokenized_inputs
        torch.cuda.empty_cache()
        gc.collect()
        return translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    def preprocess_audio(self, audio):
        """
        Create overlapping chunks with improved timing logic.
        """
        chunk_samples = int(self.chunk_size * self.sample_rate)
        overlap_samples = int(self.overlap * self.sample_rate)

        chunks_with_times = []
        start_idx = 0

        while start_idx < len(audio):
            end_idx = min(start_idx + chunk_samples, len(audio))

            # Add one second of leading silence to the first chunk
            if start_idx == 0:
                chunk = audio[start_idx:end_idx]
                padding = torch.zeros(int(1 * self.sample_rate))
                chunk = torch.cat([padding, chunk])
            else:
                # Include overlap from the previous chunk
                actual_start = max(0, start_idx - overlap_samples)
                chunk = audio[actual_start:end_idx]

            # Pad if necessary
            if len(chunk) < chunk_samples:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            # Adjust time ranges to account for overlaps
            chunk_start_time = max(0, (start_idx / self.sample_rate) - self.overlap)
            chunk_end_time = min((end_idx / self.sample_rate) + self.overlap, len(audio) / self.sample_rate)

            chunks_with_times.append({
                'chunk': chunk,
                'start_time': start_idx / self.sample_rate,
                'end_time': end_idx / self.sample_rate,
                'transcribe_start': chunk_start_time,
                'transcribe_end': chunk_end_time
            })

            # Step forward by chunk size minus overlap for better continuity
            start_idx += (chunk_samples - overlap_samples)

        return chunks_with_times

    def merge_close_segments(self, results):
        """
        Merge segments that are close in time and share the same language.
        """
        if not results:
            return results

        merged = []
        current = results[0]

        for next_segment in results[1:]:
            # Skip empty segments
            if not next_segment['text'].strip():
                continue

            # If segments are in the same language and close in time
            if (current['detected_language'] == next_segment['detected_language'] and
                    abs(next_segment['start_time'] - current['end_time']) <= self.overlap):

                # Merge the segments
                current['text'] = current['text'] + ' ' + next_segment['text']
                current['end_time'] = next_segment['end_time']
                if 'translated' in current and 'translated' in next_segment:
                    current['translated'] = current['translated'] + ' ' + next_segment['translated']
            else:
                if current['text'].strip():  # Only add non-empty segments
                    merged.append(current)
                current = next_segment

        if current['text'].strip():  # Add the last segment if non-empty
            merged.append(current)

        return merged

    def clean_overlapping_text(self, current_text, prev_text, current_lang, prev_lang, min_overlap=3):
        """
        Improved text cleaning with language awareness and better sentence-boundary handling.
        """
        if not prev_text or not current_text:
            return current_text

        # If languages differ, don't try to merge
        if prev_lang and current_lang and prev_lang != current_lang:
            return current_text

        # Split into words
        prev_words = prev_text.split()
        curr_words = current_text.split()

        if len(prev_words) < 2 or len(curr_words) < 2:
            return current_text

        # Find matching sequences between the end of prev_text and the start of current_text
        matcher = SequenceMatcher(None, prev_words, curr_words)
        matches = list(matcher.get_matching_blocks())

        # Look for significant overlaps anchored at the start of the current text
        best_overlap = 0
        for match in matches:
            if match.b == 0 and match.size >= min_overlap and match.size > best_overlap:
                best_overlap = match.size

        if best_overlap > 0:
            # Remove overlapping words while preserving the rest of the sentence
            cleaned_words = curr_words[best_overlap:]
            if not cleaned_words:  # Everything was overlapping
                return ""
            return ' '.join(cleaned_words).strip()

        return current_text

    def process_chunk(self, chunk_data, mms_model, mms_processor, translation_model=None, translation_tokenizer=None):
        """
        Process a chunk with improved language handling.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        try:
            print(f"\n\n Chunk shape: {chunk_data['chunk'].shape}")
            # Language detection
            lid_processor, lid_model = self.load_lid_mms()
            lid_lang = self.language_identification(lid_model, lid_processor, chunk_data['chunk'], device=device)

            # Configure the processor and model for the detected language
            mms_processor.tokenizer.set_target_lang(lid_lang)
            mms_model.load_adapter(lid_lang)

            # Transcribe
            inputs = mms_processor(chunk_data['chunk'], sampling_rate=self.sample_rate, return_tensors="pt")
            inputs = inputs.to(device)
            mms_model = mms_model.to(device)

            with torch.no_grad():
                outputs = mms_model(**inputs).logits

            ids = torch.argmax(outputs, dim=-1)[0]
            transcription = mms_processor.decode(ids)

            # Clean overlapping text with language awareness
            cleaned_transcription = self.clean_overlapping_text(
                transcription,
                self.previous_text,
                lid_lang,
                self.previous_lang,
                min_overlap=3
            )

            # Update previous state
            self.previous_text = transcription
            self.previous_lang = lid_lang

            if not cleaned_transcription.strip():
                return None

            result = {
                'start_time': chunk_data['start_time'],
                'end_time': chunk_data['end_time'],
                'text': cleaned_transcription,
                'detected_language': lid_lang
            }

            # Handle translation
            if translation_model and translation_tokenizer and cleaned_transcription.strip():
                translation = self.text2text_translation(
                    translation_model,
                    translation_tokenizer,
                    cleaned_transcription,
                    device=device
                )
                result['translated'] = translation

            return result

        except Exception as e:
            print(f"Error processing chunk: {str(e)}")
            return None
        finally:
            torch.cuda.empty_cache()
            gc.collect()

    def translate_text(self, text, translation_model, translation_tokenizer, device):
        """
        Translate cleaned text using the provided translation model.
        """
        tokenized_inputs = translation_tokenizer(text, return_tensors='pt')
        tokenized_inputs = tokenized_inputs.to(device)
        translation_model = translation_model.to(device)

        translated_tokens = translation_model.generate(
            **tokenized_inputs,
            forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
            max_length=100
        )

        translation = translation_tokenizer.batch_decode(
            translated_tokens,
            skip_special_tokens=True
        )[0]

        del translation_model
        del tokenized_inputs
        torch.cuda.empty_cache()
        gc.collect()
        return translation

    def transcribe_audio(self, audio_path, translate=False):
        """
        Main transcription function with improved segment merging.
        """
        # Perform speaker diarization
        diarization_result = self.diarize_audio(audio_path)

        # Extract speaker segments
        speaker_segments = []
        for turn, _, speaker in diarization_result.itertracks(yield_label=True):
            speaker_segments.append({
                'start_time': turn.start,
                'end_time': turn.end,
                'speaker': speaker
            })
        # print(f"\n\n Speaker Segments:\n{speaker_segments}\n")

        audio = self.load_audio(audio_path)
        chunks = self.preprocess_audio(audio)

        mms_model, mms_processor = self.load_mms()
        translation_model, translation_tokenizer = None, None
        if translate:
            translation_model, translation_tokenizer = self.load_T2T_translation_model()

        # Process chunks
        results = []
        for chunk_data in chunks:
            result = self.process_chunk(
                chunk_data,
                mms_model,
                mms_processor,
                translation_model,
                translation_tokenizer
            )
            print(f"\n\nResult:\n{result}")
            if result:
                # Attach the speaker whose turn contains this chunk's start time
                for segment in speaker_segments:
                    if int(segment['start_time']) <= int(chunk_data['start_time']) < int(segment['end_time']):
                        result['speaker'] = segment['speaker']
                        break
                results.append(result)

        # Merge close segments and clean up
        merged_results = self.merge_close_segments(results)

        return merged_results

    def load_audio(self, audio_path):
        """
        Load and preprocess an audio file.
        """
        waveform, sample_rate = torchaudio.load(audio_path)

        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0)
        else:
            waveform = waveform.squeeze(0)

        # Resample if necessary
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        return waveform.float()
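

# Minimal usage sketch: "sample.wav" is a placeholder path, the models are
# downloaded on first use, and a CUDA-capable GPU is assumed by the defaults.
if __name__ == "__main__":
    transcriber = ChunkedTranscriber(chunk_size=5, overlap=1, sample_rate=16000)
    segments = transcriber.transcribe_audio("sample.wav", translate=True)
    for seg in segments:
        speaker = seg.get('speaker', 'unknown')
        print(f"[{seg['start_time']:.1f}s-{seg['end_time']:.1f}s] "
              f"{speaker} ({seg['detected_language']}): {seg['text']}")
        if 'translated' in seg:
            print(f"    -> {seg['translated']}")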