"""
Multi-Speaker Audio Analyzer
A Streamlit application that performs speaker diarization, transcription, and summarization on audio files.

Author: [Your Name]
Date: January 2025
"""

import streamlit as st
from src.models.diarization import SpeakerDiarizer
from src.models.transcription import Transcriber
from src.models.summarization import Summarizer
from src.utils.audio_processor import AudioProcessor
from src.utils.formatter import TimeFormatter
import os

# Cache for model loading
@st.cache_resource
def _load_models_cached():
    """
    Build and load every model, raising if any of them fails.

    Kept separate from load_models() so that only a fully successful load is
    stored by st.cache_resource: the cache stores return values, not raised
    exceptions, so a failed attempt is retried on the next call.

    Returns:
        tuple: (diarizer, transcriber, summarizer)

    Raises:
        ValueError: If any load_model() call returns a falsy value.
        Exception: Anything raised while constructing or loading a model
            (including a missing "hf_token" secret).
    """
    diarizer = SpeakerDiarizer(st.secrets["hf_token"])
    transcriber = Transcriber()
    summarizer = Summarizer()

    # Same load order as before: diarizer, transcriber, summarizer.
    loaded = [
        diarizer.load_model(),
        transcriber.load_model(),
        summarizer.load_model(),
    ]
    if not all(loaded):
        raise ValueError("One or more models failed to load")

    return diarizer, transcriber, summarizer

def load_models():
    """
    Load and cache all required models.

    A successful load is cached across reruns. A failed load is reported in
    the UI and is NOT cached, so a transient problem (network hiccup, token
    not yet configured) does not stay broken until the cache is cleared.

    Returns:
        tuple: (diarizer, transcriber, summarizer) or (None, None, None) if loading fails
    """
    try:
        return _load_models_cached()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error("Debug info: Check if HF token is valid and has necessary permissions")
        return None, None, None

# Keep `load_models.clear()` working for any caller that used it while
# load_models itself carried the cache decorator.
load_models.clear = _load_models_cached.clear

def process_audio(audio_file, max_duration=600):
    """
    Process the uploaded audio file through all models.

    Loads the models, writes a standardized temporary copy of the upload,
    runs diarization, transcription and summarization on it, and always
    removes the temporary copy afterwards.

    Args:
        audio_file: Uploaded audio file
        max_duration (int): Maximum duration in seconds. Accepted for
            interface compatibility; this function does not currently
            enforce it.

    Returns:
        dict: Processing results with keys "diarization", "transcription"
            and "summary", or None if model loading or any processing step
            fails (the error is shown with st.error).
    """
    tmp_path = None
    try:
        # Load models first so no temporary file is created when they are
        # unavailable.
        diarizer, transcriber, summarizer = load_models()
        if not all([diarizer, transcriber, summarizer]):
            # Return None (falsy) rather than a message string: callers test
            # the result for truthiness and then index it like a dict.
            st.error("Model loading failed")
            return None

        # Process audio file
        audio_processor = AudioProcessor()
        tmp_path = audio_processor.standardize_audio(audio_file)

        # Process with each model
        with st.spinner("Identifying speakers..."):
            diarization_result = diarizer.process(tmp_path)

        with st.spinner("Transcribing audio..."):
            transcription = transcriber.process(tmp_path)

        with st.spinner("Generating summary..."):
            summary = summarizer.process(transcription["text"])

        return {
            "diarization": diarization_result,
            "transcription": transcription,
            "summary": summary[0]["summary_text"]
        }

    except Exception as e:
        st.error(f"Error processing audio: {str(e)}")
        return None

    finally:
        # Cleanup runs on success, failure and early return alike so the
        # temporary file is never left behind.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best effort: a file that is already gone or still locked
                # must not discard results that were computed successfully.
                pass

def main():
    """
    Main application function.

    Renders the upload widget, previews the uploaded audio, and, once the
    "Analyze Audio" button is pressed, runs process_audio() and shows the
    results in three tabs (speakers, transcription, summary).
    """
    st.title("Multi-Speaker Audio Analyzer")
    st.write("Upload an audio file (MP3/WAV) up to 5 minutes long for best performance")

    uploaded_file = st.file_uploader("Choose a file", type=["mp3", "wav"])
    if not uploaded_file:
        return

    # UploadedFile exposes its byte length directly; no need to copy the
    # whole buffer just to measure it.
    file_size = uploaded_file.size / (1024 * 1024)
    st.write(f"File size: {file_size:.2f} MB")

    # Pass the upload's own MIME type so MP3 files are not announced as WAV;
    # fall back to the previous default if the browser did not report one.
    st.audio(uploaded_file, format=uploaded_file.type or 'audio/wav')

    if not st.button("Analyze Audio"):
        return

    if file_size > 200:
        st.error("File size exceeds 200MB limit")
        return

    results = process_audio(uploaded_file)
    if not results:
        # process_audio has already reported the failure in the UI.
        return

    tab1, tab2, tab3 = st.tabs(["Speakers", "Transcription", "Summary"])

    # Display speaker timeline
    with tab1:
        display_speaker_timeline(results)

    # Display transcription
    with tab2:
        display_transcription(results)

    # Display summary
    with tab3:
        display_summary(results)

def display_speaker_timeline(results):
    """
    Render the diarization results as one three-column row per segment.

    Args:
        results: Mapping with "diarization" and "transcription" entries,
            which are handed to TimeFormatter.format_speaker_segments.
    """
    st.write("Speaker Timeline:")

    diarization = results["diarization"]
    transcription = results["transcription"]
    timeline = TimeFormatter.format_speaker_segments(diarization, transcription)

    # Nothing to draw: warn and stop.
    if not timeline:
        st.warning("No speaker segments detected")
        return

    for entry in timeline:
        # Columns are speaker | time range | text, in a 2:3:5 width ratio.
        cells = zip(
            st.columns([2, 3, 5]),
            (display_speaker_info, display_timestamp, display_text),
        )
        for column, render in cells:
            with column:
                render(entry)

        # Horizontal rule between rows.
        st.markdown("---")

def display_speaker_info(segment):
    """
    Display speaker information with color coding.

    The color is chosen from the number that follows the first underscore in
    the speaker label (labels shaped like "<prefix>_<number>"). Labels that
    do not follow that shape get a neutral marker instead of raising, so one
    unexpected label cannot abort rendering of the whole timeline.

    Args:
        segment: Mapping with a 'speaker' entry holding the speaker label.
    """
    # The first two entries match the original palette; the extra entries
    # keep a third and later speaker from sharing a color with the first two.
    colors = ['🔵', '🔴', '🟢', '🟡', '🟣', '🟠']
    label = segment['speaker']

    try:
        speaker_num = int(label.split('_')[1])
    except (AttributeError, IndexError, ValueError):
        # Not a string, no underscore, or a non-numeric suffix.
        speaker_color = '⚪'
    else:
        speaker_color = colors[speaker_num % len(colors)]

    st.write(f"{speaker_color} {label}")

def display_timestamp(segment):
    """
    Show the segment's time range as "<start> → <end>".

    Args:
        segment: Mapping with 'start' and 'end' entries; each is passed
            through TimeFormatter.format_timestamp before display.
    """
    begin, finish = (
        TimeFormatter.format_timestamp(segment[bound])
        for bound in ('start', 'end')
    )
    st.write(f"{begin} → {finish}")

def display_text(segment):
    """
    Show the segment's text in double quotes, or a placeholder when empty.

    Args:
        segment: Mapping with a 'text' entry.
    """
    spoken = segment['text']
    message = f'"{spoken}"' if spoken else "(no speech detected)"
    st.write(message)

# Start the app only when this file is run as the main script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()