import gc
import torch
import torchaudio
import numpy as np
from transformers import (
    Wav2Vec2ForSequenceClassification, 
    AutoFeatureExtractor, 
    Wav2Vec2ForCTC, 
    AutoProcessor, 
    AutoTokenizer, 
    AutoModelForSeq2SeqLM
)
import spaces
import logging
from difflib import SequenceMatcher

# Configure the root logger (only takes effect if it has no handlers yet):
# INFO and above, one "time - logger name - level - message" line per record.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Logger for this module; AudioProcessor reports progress and errors through it.
logger = logging.getLogger(__name__)



class AudioProcessor:
    """Chunked language identification, transcription and translation of audio.

    Audio is cut into fixed-length, overlapping windows. For each window the
    MMS language-ID model picks a language label, the MMS CTC model
    transcribes the window using that label's adapter, and optionally the
    NLLB model translates the text to English. Adjacent window results are
    finally combined by :meth:`merge_segments`.
    """

    def __init__(self, chunk_size=5, overlap=1, sample_rate=16000):
        """
        Args:
            chunk_size: Window length in seconds.
            overlap: Overlap between consecutive windows in seconds; must be
                smaller than ``chunk_size``.
            sample_rate: Rate (Hz) that audio is resampled to and that is
                reported to the feature extractors.
        """
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.sample_rate = sample_rate
        # Not read anywhere in this class; kept for external callers (verify usage).
        self.previous_text = ""
        self.previous_lang = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _window_samples(self):
        """Return ``(chunk_samples, overlap_samples)`` for the configured window.

        Raises:
            ValueError: If the overlap is not smaller than the chunk, i.e. the
                step between consecutive windows would be zero or negative.
        """
        chunk_samples = int(self.chunk_size * self.sample_rate)
        overlap_samples = int(self.overlap * self.sample_rate)
        if chunk_samples - overlap_samples <= 0:
            raise ValueError(
                f"overlap ({self.overlap}s) must be smaller than chunk_size ({self.chunk_size}s)"
            )
        return chunk_samples, overlap_samples

    def load_models(self):
        """Load the language-ID, transcription and translation models.

        Returns:
            A dict with keys ``'lid'``, ``'mms'`` and ``'translation'``, each
            mapping to a ``(model, processor_or_tokenizer)`` tuple. Models are
            created on the default (CPU) device; the inference methods move
            them to ``self.device``.
        """
        logger.info("Loading MMS models...")

        # Language identification model
        lid_processor = AutoFeatureExtractor.from_pretrained("facebook/mms-lid-256")
        lid_model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/mms-lid-256")

        # Transcription model (per-language adapters loaded on demand)
        mms_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
        mms_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")

        # Translation model
        translation_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        translation_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

        return {
            'lid': (lid_model, lid_processor),
            'mms': (mms_model, mms_processor),
            'translation': (translation_model, translation_tokenizer)
        }

    @spaces.GPU(duration=60)
    def identify_language(self, audio_chunk, models):
        """Return the MMS-LID label with the highest score for ``audio_chunk``.

        Args:
            audio_chunk: Waveform samples at ``self.sample_rate``.
            models: Dict returned by :meth:`load_models`.

        Returns:
            The string from ``lid_model.config.id2label`` for the argmax class;
            :meth:`process_audio` passes it straight to :meth:`transcribe_chunk`
            as the adapter/target language.
        """
        lid_model, lid_processor = models['lid']
        # Report the real rate rather than an assumed 16 kHz.
        inputs = lid_processor(audio_chunk, sampling_rate=self.sample_rate, return_tensors="pt")
        lid_model.to(self.device)
        with torch.no_grad():
            logits = lid_model(inputs.input_values.to(self.device)).logits
        lang_id = torch.argmax(logits, dim=-1)[0].item()
        return lid_model.config.id2label[lang_id]

    @spaces.GPU(duration=60)
    def transcribe_chunk(self, audio_chunk, language, models):
        """Transcribe ``audio_chunk`` with the MMS adapter for ``language``.

        Args:
            audio_chunk: Waveform samples at ``self.sample_rate``.
            language: Adapter / target-language code (the LID label).
            models: Dict returned by :meth:`load_models`.

        Returns:
            The greedy (argmax) CTC decoding of the chunk as a string.
        """
        mms_model, mms_processor = models['mms']

        # Tokenizer vocabulary and model adapter must both match the language.
        mms_processor.tokenizer.set_target_lang(language)
        mms_model.load_adapter(language)
        mms_model.to(self.device)
        inputs = mms_processor(audio_chunk, sampling_rate=self.sample_rate, return_tensors="pt")

        with torch.no_grad():
            logits = mms_model(inputs.input_values.to(self.device)).logits
        ids = torch.argmax(logits, dim=-1)[0]
        return mms_processor.decode(ids)

    @spaces.GPU(duration=60)
    def translate_text(self, text, models):
        """Translate ``text`` to English (``eng_Latn``) with NLLB.

        Args:
            text: Source text; empty or whitespace-only text yields ``""``
                without running the model.
            models: Dict returned by :meth:`load_models`.

        Returns:
            The first decoded translation, at most 100 generated tokens long.
        """
        if not text or not text.strip():
            return ""

        translation_model, translation_tokenizer = models['translation']

        # NOTE(review): the tokenizer's src_lang is never set from the detected
        # language, so NLLB is given its default source language — confirm
        # whether an MMS-code -> NLLB-code mapping is needed here.
        inputs = translation_tokenizer(text, return_tensors="pt").to(self.device)
        translation_model.to(self.device)
        with torch.no_grad():
            outputs = translation_model.generate(
                **inputs,
                forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("eng_Latn"),
                max_length=100
            )
        return translation_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    def preprocess_audio(self, audio):
        """Split ``audio`` into overlapping chunks annotated with timing info.

        Args:
            audio: Presumably a 1-D torch tensor at ``self.sample_rate`` (it is
                sliced, concatenated with ``torch.zeros`` and padded with
                ``torch.nn.functional.pad``).

        Returns:
            A list of dicts with keys ``chunk``, ``start_time``, ``end_time``,
            ``transcribe_start`` and ``transcribe_end`` (seconds).

        Raises:
            ValueError: If ``overlap`` is not smaller than ``chunk_size``.

        Note:
            The first chunk is prefixed with one second of silence and later
            chunks begin ``overlap`` seconds before their ``start_time``, so
            chunk lengths vary and do not align exactly with
            ``start_time``/``end_time``. Nothing in this class calls it.
        """
        chunk_samples, overlap_samples = self._window_samples()
        step_samples = chunk_samples - overlap_samples
        total_samples = len(audio)

        chunks_with_times = []
        start_idx = 0

        while start_idx < total_samples:
            end_idx = min(start_idx + chunk_samples, total_samples)

            if start_idx == 0:
                # One second of lead-in silence, matching the audio's dtype/device.
                padding = torch.zeros(
                    int(1 * self.sample_rate), dtype=audio.dtype, device=audio.device
                )
                chunk = torch.cat([padding, audio[start_idx:end_idx]])
            else:
                # Reach back into the previous chunk by `overlap` seconds.
                actual_start = max(0, start_idx - overlap_samples)
                chunk = audio[actual_start:end_idx]

            # Zero-pad short (final) chunks up to the nominal length.
            if len(chunk) < chunk_samples:
                chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

            # Widen the transcription window by `overlap` on each side, clamped to the audio.
            chunk_start_time = max(0, (start_idx / self.sample_rate) - self.overlap)
            chunk_end_time = min(
                (end_idx / self.sample_rate) + self.overlap, total_samples / self.sample_rate
            )

            chunks_with_times.append({
                'chunk': chunk,
                'start_time': start_idx / self.sample_rate,
                'end_time': end_idx / self.sample_rate,
                'transcribe_start': chunk_start_time,
                'transcribe_end': chunk_end_time
            })

            start_idx += step_samples

        return chunks_with_times

    def _load_waveform(self, audio_path):
        """Load ``audio_path`` as a mono 1-D tensor resampled to ``self.sample_rate``."""
        waveform, sample_rate = torchaudio.load(audio_path)
        # torchaudio returns (channels, frames); average multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0)
        else:
            waveform = waveform.squeeze(0)

        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=self.sample_rate
            )
            waveform = resampler(waveform)
        return waveform

    @spaces.GPU(duration=60)
    def process_audio(self, audio_path, translate=False):
        """Run language ID, transcription and optional translation over a file.

        Args:
            audio_path: Anything ``torchaudio.load`` accepts.
            translate: When True, each segment also carries a ``"translated"``
                text produced by :meth:`translate_text`.

        Returns:
            ``(language_segments, merged_segments)``. ``language_segments`` has
            one ``{"language", "start", "end"}`` dict per window.
            ``merged_segments`` is :meth:`merge_segments` applied to per-window
            dicts with keys ``start``, ``end``, ``language``, ``text``,
            ``speaker`` (plus ``translated`` when requested). Times are in
            seconds and never exceed the audio duration.

        Raises:
            ValueError: If ``overlap`` is not smaller than ``chunk_size``.
            Exception: Any error while loading or processing is logged with
                its traceback and re-raised.
        """
        try:
            # Validate the window before paying for model loading.
            chunk_samples, overlap_samples = self._window_samples()
            step_samples = chunk_samples - overlap_samples

            waveform = self._load_waveform(audio_path)
            models = self.load_models()

            total_samples = len(waveform)
            segments = []
            language_segments = []

            for start_idx in range(0, total_samples, step_samples):
                end_idx = min(start_idx + chunk_samples, total_samples)
                chunk = waveform[start_idx:end_idx]
                # The last window may be short; zero-pad it to a fixed length.
                if len(chunk) < chunk_samples:
                    chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - len(chunk)))

                # Time span of the real audio only (padding excluded).
                start_time = start_idx / self.sample_rate
                end_time = end_idx / self.sample_rate

                language = self.identify_language(chunk, models)
                language_segments.append({
                    "language": language,
                    "start": start_time,
                    "end": end_time
                })

                transcription = self.transcribe_chunk(chunk, language, models)
                segment = {
                    "start": start_time,
                    "end": end_time,
                    "language": language,
                    "text": transcription,
                    "speaker": "Speaker"  # Simple speaker assignment
                }
                if translate:
                    segment["translated"] = self.translate_text(transcription, models)
                segments.append(segment)

                # Release cached GPU memory and Python garbage between windows.
                torch.cuda.empty_cache()
                gc.collect()

                # Once a window reaches the end of the audio, any later window
                # would lie entirely inside this window's tail plus padding.
                if end_idx >= total_samples:
                    break

            return language_segments, self.merge_segments(segments)

        except Exception as e:
            logger.exception("Error processing audio: %s", e)
            raise

    def merge_segments(self, segments, time_threshold=0.5, similarity_threshold=0.7):
        """Merge runs of adjacent segments that look like the same utterance.

        Two neighbours are merged when the gap between them is at most
        ``time_threshold`` seconds, they share a ``language``, and the
        ``SequenceMatcher`` ratio of their texts exceeds
        ``similarity_threshold``. Merging extends ``end`` and space-joins
        ``text`` (and ``translated`` when both segments have one).

        Args:
            segments: List of segment dicts as built by :meth:`process_audio`.
            time_threshold: Maximum gap in seconds between mergeable segments.
            similarity_threshold: Minimum text-similarity ratio (exclusive).

        Returns:
            ``segments`` unchanged if it is empty; otherwise a new list of
            segment dicts. Input dicts are shallow-copied, never mutated.
        """
        if not segments:
            return segments

        merged = []
        current = dict(segments[0])

        for next_segment in segments[1:]:
            if (next_segment['start'] - current['end'] <= time_threshold
                    and current['language'] == next_segment['language']
                    and SequenceMatcher(None, current['text'], next_segment['text']).ratio()
                    > similarity_threshold):
                # NOTE(review): similar texts are concatenated, not de-duplicated;
                # confirm this is the intended handling of overlapping windows.
                current['end'] = next_segment['end']
                current['text'] = current['text'] + ' ' + next_segment['text']
                if 'translated' in current and 'translated' in next_segment:
                    current['translated'] = current['translated'] + ' ' + next_segment['translated']
            else:
                merged.append(current)
                current = dict(next_segment)

        merged.append(current)
        return merged