|  | import gc | 
					
						
						|  | import torch | 
					
						
						|  | import torchaudio | 
					
						
						|  | import numpy as np | 
					
						
						|  | from transformers import ( | 
					
						
						|  | Wav2Vec2ForSequenceClassification, | 
					
						
						|  | AutoFeatureExtractor, | 
					
						
						|  | Wav2Vec2ForCTC, | 
					
						
						|  | AutoProcessor, | 
					
						
						|  | AutoTokenizer, | 
					
						
						|  | AutoModelForSeq2SeqLM | 
					
						
						|  | ) | 
					
						
						|  | import logging | 
					
						
						|  | from difflib import SequenceMatcher | 
					
						
						|  |  | 
					
						
						|  | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | 
					
						
						|  | logger = logging.getLogger(__name__) | 
					
						
						|  |  | 
					
						
						|  | class AudioProcessor: | 
					
						
						|  | def __init__(self, chunk_size=5, overlap=1, sample_rate=16000): | 
					
						
						|  | self.chunk_size = chunk_size | 
					
						
						|  | self.overlap = overlap | 
					
						
						|  | self.sample_rate = sample_rate | 
					
						
						|  | self.previous_text = "" | 
					
						
						|  | self.previous_lang = None | 
					
						
						|  | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | 
					
						
						|  |  | 
					
						
						|  | def load_models(self): | 
					
						
						|  | """Load all required models""" | 
					
						
						|  | logger.info("Loading MMS models...") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | lid_processor = AutoFeatureExtractor.from_pretrained("facebook/mms-lid-256") | 
					
						
						|  | lid_model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/mms-lid-256") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all") | 
					
						
						|  | mms_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | translation_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M") | 
					
						
						|  | translation_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M") | 
					
						
						|  |  | 
					
						
						|  | return { | 
					
						
						|  | 'lid': (lid_model, lid_processor), | 
					
						
						|  | 'mms': (mms_model, mms_processor), | 
					
						
						|  | 'translation': (translation_model, translation_tokenizer) | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | def identify_language(self, audio_chunk, models): | 
					
						
						|  | """Identify language of audio chunk""" | 
					
						
						|  | lid_model, lid_processor = models['lid'] | 
					
						
						|  | inputs = lid_processor(audio_chunk, sampling_rate=16000, return_tensors="pt") | 
					
						
						|  |  | 
					
						
						|  | with torch.no_grad(): | 
					
						
						|  | outputs = lid_model(inputs.input_values.to(self.device)).logits | 
					
						
						|  | lang_id = torch.argmax(outputs, dim=-1)[0].item() | 
					
						
						|  | detected_lang = lid_model.config.id2label[lang_id] | 
					
						
						|  |  | 
					
						
						|  | return detected_lang | 
					
						
						|  |  | 
					
						
						|  | def transcribe_chunk(self, audio_chunk, language, models): | 
					
						
						|  | """Transcribe audio chunk""" | 
					
						
						|  | mms_model, mms_processor = models['mms'] | 
					
						
						|  |  | 
					
						
						|  | mms_processor.tokenizer.set_target_lang(language) | 
					
						
						|  | mms_model.load_adapter(language) | 
					
						
						|  |  | 
					
						
						|  | inputs = mms_processor(audio_chunk, sampling_rate=16000, return_tensors="pt") | 
					
						
						|  |  | 
					
						
						|  | with torch.no_grad(): | 
					
						
						|  | outputs = mms_model(inputs.input_values.to(self.device)).logits | 
					
						
						|  | ids = torch.argmax(outputs, dim=-1)[0] | 
					
						
						|  | transcription = mms_processor.decode(ids) | 
					
						
						|  |  | 
					
						
						|  | return transcription | 
					
						
						|  |  | 
					
						
						|  | def translate_text(self, text, models): | 
					
						
						|  | """Translate text to English""" | 
					
						
						|  | translation_model, translation_tokenizer = models['translation'] | 
					
						
						|  |  | 
					
						
						|  | inputs = translation_tokenizer(text, return_tensors="pt") | 
					
						
						|  | inputs = inputs.to(self.device) | 
					
						
						|  |  | 
					
						
						|  | with torch.no_grad(): | 
					
						
						|  | outputs = translation_model.generate( | 
					
						
						|  | **inputs, | 
					
						
						|  | forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"), | 
					
						
						|  | max_length=100 | 
					
						
						|  | ) | 
					
						
						|  | translation = translation_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] | 
					
						
						|  |  | 
					
						
						|  | return translation | 
					
						
						|  |  | 
					
						
						|  | def process_audio(self, audio_path, translate=False): | 
					
						
						|  | """Main processing function""" | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | waveform, sample_rate = torchaudio.load(audio_path) | 
					
						
						|  | if waveform.shape[0] > 1: | 
					
						
						|  | waveform = torch.mean(waveform, dim=0) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if sample_rate != self.sample_rate: | 
					
						
						|  | waveform = torchaudio.transforms.Resample(sample_rate, self.sample_rate)(waveform) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | models = self.load_models() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | chunk_samples = int(self.chunk_size * self.sample_rate) | 
					
						
						|  | overlap_samples = int(self.overlap * self.sample_rate) | 
					
						
						|  |  | 
					
						
						|  | segments = [] | 
					
						
						|  | language_segments = [] | 
					
						
						|  |  | 
					
						
						|  | for i in range(0, len(waveform), chunk_samples - overlap_samples): | 
					
						
						|  | chunk = waveform[i:i + chunk_samples] | 
					
						
						|  | if len(chunk) < chunk_samples: | 
					
						
						|  | chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk))) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | start_time = i / self.sample_rate | 
					
						
						|  | end_time = (i + len(chunk)) / self.sample_rate | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | language = self.identify_language(chunk, models) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | language_segments.append({ | 
					
						
						|  | "language": language, | 
					
						
						|  | "start": start_time, | 
					
						
						|  | "end": end_time | 
					
						
						|  | }) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | transcription = self.transcribe_chunk(chunk, language, models) | 
					
						
						|  |  | 
					
						
						|  | segment = { | 
					
						
						|  | "start": start_time, | 
					
						
						|  | "end": end_time, | 
					
						
						|  | "language": language, | 
					
						
						|  | "text": transcription, | 
					
						
						|  | "speaker": "Speaker" | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | if translate: | 
					
						
						|  | translation = self.translate_text(transcription, models) | 
					
						
						|  | segment["translated"] = translation | 
					
						
						|  |  | 
					
						
						|  | segments.append(segment) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | torch.cuda.empty_cache() | 
					
						
						|  | gc.collect() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | merged_segments = self.merge_segments(segments) | 
					
						
						|  |  | 
					
						
						|  | return language_segments, merged_segments | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Error processing audio: {str(e)}") | 
					
						
						|  | raise | 
					
						
						|  |  | 
					
						
						|  | def merge_segments(self, segments, time_threshold=0.5, similarity_threshold=0.7): | 
					
						
						|  | """Merge similar nearby segments""" | 
					
						
						|  | if not segments: | 
					
						
						|  | return segments | 
					
						
						|  |  | 
					
						
						|  | merged = [] | 
					
						
						|  | current = segments[0] | 
					
						
						|  |  | 
					
						
						|  | for next_segment in segments[1:]: | 
					
						
						|  | if (next_segment['start'] - current['end'] <= time_threshold and | 
					
						
						|  | current['language'] == next_segment['language']): | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | matcher = SequenceMatcher(None, current['text'], next_segment['text']) | 
					
						
						|  | similarity = matcher.ratio() | 
					
						
						|  |  | 
					
						
						|  | if similarity > similarity_threshold: | 
					
						
						|  |  | 
					
						
						|  | current['end'] = next_segment['end'] | 
					
						
						|  | current['text'] = current['text'] + ' ' + next_segment['text'] | 
					
						
						|  | if 'translated' in current and 'translated' in next_segment: | 
					
						
						|  | current['translated'] = current['translated'] + ' ' + next_segment['translated'] | 
					
						
						|  | else: | 
					
						
						|  | merged.append(current) | 
					
						
						|  | current = next_segment | 
					
						
						|  | else: | 
					
						
						|  | merged.append(current) | 
					
						
						|  | current = next_segment | 
					
						
						|  |  | 
					
						
						|  | merged.append(current) | 
					
						
						|  | return merged |