import streamlit as st
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import torch
import json
import os
import glob
from pathlib import Path
from datetime import datetime
import edge_tts
import asyncio
import base64
import requests
from collections import defaultdict
from audio_recorder_streamlit import audio_recorder
import streamlit.components.v1 as components
from urllib.parse import quote
from xml.etree import ElementTree as ET

# Initialize session state: each key maps to a zero-argument factory that
# builds its starting value, so every session gets its own fresh objects.
_SESSION_DEFAULTS = {
    'search_history': list,
    'last_voice_input': str,
    'transcript_history': list,
    'should_rerun': bool,
    'search_columns': list,
    'initial_search_done': bool,
    'tts_voice': lambda: "en-US-AriaNeural",
    'arxiv_last_query': str,
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

def fetch_dataset_info(dataset_id):
    """Return the Hugging Face Hub metadata dict for ``dataset_id``, or None.

    None is returned for any non-200 response. It is also returned when the
    request or JSON decoding raises, in which case a Streamlit warning is shown.
    """
    endpoint = f"https://huggingface.co/api/datasets/{dataset_id}"
    payload = None
    try:
        reply = requests.get(endpoint, timeout=30)
        if reply.status_code == 200:
            payload = reply.json()
    except Exception as err:
        st.warning(f"Error fetching dataset info: {err}")
    return payload

def fetch_dataset_rows(dataset_id, config="default", split="train", max_rows=100):
    """Fetch rows from a specific config and split of a dataset"""
    url = f"https://datasets-server.huggingface.co/first-rows?dataset={dataset_id}&config={config}&split={split}"
    try:
        response = requests.get(url, timeout=30)
        if response.status_code == 200:
            data = response.json()
            if 'rows' in data:
                processed_rows = []
                for row_data in data['rows']:
                    row = row_data.get('row', row_data)
                    # Process embeddings if present
                    for key in row:
                        if any(term in key.lower() for term in ['embed', 'vector', 'encoding']):
                            if isinstance(row[key], str):
                                try:
                                    row[key] = [float(x.strip()) for x in row[key].strip('[]').split(',') if x.strip()]
                                except:
                                    continue
                    row['_config'] = config
                    row['_split'] = split
                    processed_rows.append(row)
                return processed_rows
    except Exception as e:
        st.warning(f"Error fetching rows for {config}/{split}: {e}")
    return []

def search_dataset(dataset_id, search_text, include_configs=None, include_splits=None):
    """
    Search across all configurations and splits of a dataset
    
    Args:
        dataset_id (str): The Hugging Face dataset ID
        search_text (str): Text to search for in descriptions and queries
        include_configs (list): List of specific configs to search, or None for all
        include_splits (list): List of specific splits to search, or None for all
    
    Returns:
        tuple: (DataFrame of results, list of available configs, list of available splits)
    """
    # Get dataset info
    dataset_info = fetch_dataset_info(dataset_id)
    if not dataset_info:
        return pd.DataFrame(), [], []
    
    # Get available configs and splits
    configs = include_configs if include_configs else dataset_info.get('config_names', ['default'])
    all_rows = []
    available_splits = set()
    
    # Search across configs and splits
    for config in configs:
        try:
            # First fetch split info for this config
            splits_url = f"https://datasets-server.huggingface.co/splits?dataset={dataset_id}&config={config}"
            splits_response = requests.get(splits_url, timeout=30)
            if splits_response.status_code == 200:
                splits_data = splits_response.json()
                splits = [split['split'] for split in splits_data.get('splits', [])]
                if not splits:
                    splits = ['train']  # fallback to train if no splits found
                
                # Filter splits if specified
                if include_splits:
                    splits = [s for s in splits if s in include_splits]
                
                available_splits.update(splits)
                
                # Fetch and search rows for each split
                for split in splits:
                    rows = fetch_dataset_rows(dataset_id, config, split)
                    for row in rows:
                        # Search in all text fields
                        text_content = ' '.join(str(v) for v in row.values() if isinstance(v, (str, int, float)))
                        if search_text.lower() in text_content.lower():
                            row['_matched_text'] = text_content
                            row['_relevance_score'] = text_content.lower().count(search_text.lower())
                            all_rows.append(row)
        
        except Exception as e:
            st.warning(f"Error processing config {config}: {e}")
            continue
    
    # Convert to DataFrame and sort by relevance
    if all_rows:
        df = pd.DataFrame(all_rows)
        df = df.sort_values('_relevance_score', ascending=False)
        return df, configs, list(available_splits)
    
    return pd.DataFrame(), configs, list(available_splits)

class VideoSearch:
    """Semantic search over rows of the ``omegalabsinc/omega-multimodal`` dataset.

    On construction this does three things:

    - loads the sentence-transformer query model;
    - fetches the dataset rows, falling back to one hard-coded example row;
    - builds two per-row matrices, ``video_embeds`` and ``text_embeds``,
      aligned with ``self.dataset``.

    ``search`` ranks rows by the mean cosine similarity of the query embedding
    against both matrices.
    """

    # Raw vector columns; never returned in results or offered as search fields.
    EMBED_COLUMNS = ('video_embed', 'description_embed', 'audio_embed')
    # Substrings that mark a column as holding embedding vectors.
    EMBED_TERMS = ('embed', 'vector', 'encoding')

    def __init__(self):
        self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.dataset_id = "omegalabsinc/omega-multimodal"
        self.load_dataset()

    def fetch_dataset_rows(self):
        """Fetch every row of ``self.dataset_id``, or example data on failure.

        Side effect: on success, stores the column names that are neither raw
        embeddings nor internal (``_``-prefixed) in
        ``st.session_state['search_columns']``.
        """
        try:
            df, _, _ = search_dataset(
                self.dataset_id,
                "",  # empty search text matches every row
                include_configs=None,  # all configs
                include_splits=None    # all splits
            )
            if not df.empty:
                st.session_state['search_columns'] = [
                    col for col in df.columns
                    if col not in self.EMBED_COLUMNS and not col.startswith('_')
                ]
                return df
            return self.load_example_data()
        except Exception as e:
            st.warning(f"Error loading dataset: {e}")
            return self.load_example_data()

    def load_example_data(self):
        """Load example data as fallback"""
        example_data = [
            {
                "video_id": "cd21da96-fcca-4c94-a60f-0b1e4e1e29fc",
                "youtube_id": "IO-vwtyicn4",
                "description": "This video shows a close-up of an ancient text carved into a surface.",
                "views": 45489,
                "start_time": 1452,
                "end_time": 1458,
                "video_embed": [0.014160037972033024, -0.003111184574663639, -0.016604168340563774],
                "description_embed": [-0.05835828185081482, 0.02589797042310238, 0.11952091753482819]
            }
        ]
        return pd.DataFrame(example_data)

    @staticmethod
    def _parse_vector(value):
        """Return ``value`` as a list of floats, or None if it is not vector-like.

        Accepts "[1, 2]"-style strings and list/tuple/ndarray values. Raises
        ValueError or TypeError when an element is not numeric.
        """
        if isinstance(value, str):
            return [float(x.strip()) for x in value.strip('[]').split(',') if x.strip()]
        if isinstance(value, (list, tuple, np.ndarray)):
            return [float(x) for x in value]
        return None

    def _embedding_matrix(self, col, dim):
        """Return an (n_rows, dim) float matrix for ``col``, or None.

        None is returned if any row's vector is missing or unparseable, or has
        a width other than ``dim``. A partial matrix would misalign embeddings
        with dataset rows.
        """
        vectors = []
        for value in self.dataset[col]:
            try:
                vector = self._parse_vector(value)
            except (TypeError, ValueError):
                return None
            if vector is None or len(vector) != dim:
                return None
            vectors.append(vector)
        return np.asarray(vectors, dtype=float) if vectors else None

    def _row_texts(self):
        """Return one text string per dataset row for embedding with the query model.

        Uses the ``description`` column when present (missing values become "").
        Otherwise it joins the string values of all non-embedding, non-internal
        columns.
        """
        if 'description' in self.dataset.columns:
            return self.dataset['description'].fillna('').astype(str).tolist()
        text_cols = [
            col for col in self.dataset.columns
            if not str(col).startswith('_')
            and not any(term in str(col).lower() for term in self.EMBED_TERMS)
        ]
        return [
            ' '.join(v for v in values if isinstance(v, str))
            for values in self.dataset[text_cols].itertuples(index=False, name=None)
        ]

    def prepare_features(self):
        """Build ``self.video_embeds`` and ``self.text_embeds`` (one row per dataset row).

        Only embedding columns whose vectors parse for every row and match the
        query model's width are used. Comparing query vectors against vectors
        of another width makes ``cosine_similarity`` raise (e.g. the 3-d example
        data). If no compatible vectors exist, the row text is embedded with
        ``self.text_model`` instead.
        """
        dim = self.text_model.get_sentence_embedding_dimension()
        if not dim:
            dim = len(self.text_model.encode([""])[0])

        embeddings = {}
        for col in self.dataset.columns:
            if any(term in str(col).lower() for term in self.EMBED_TERMS):
                matrix = self._embedding_matrix(col, dim)
                if matrix is not None:
                    embeddings[col] = matrix

        # Same preference order as before: video_embed, else the first usable
        # embedding column; description_embed, else the video embeddings.
        video_embeds = embeddings.get('video_embed')
        if video_embeds is None and embeddings:
            video_embeds = next(iter(embeddings.values()))
        text_embeds = embeddings.get('description_embed', video_embeds)

        if text_embeds is None:
            try:
                text_embeds = np.asarray(self.text_model.encode(self._row_texts()), dtype=float)
            except Exception as e:
                st.warning(f"Error embedding dataset text: {e}")
                # Last resort so the app still renders; rankings will be arbitrary.
                text_embeds = np.random.randn(len(self.dataset), dim)

        self.text_embeds = text_embeds
        self.video_embeds = video_embeds if video_embeds is not None else text_embeds

    def load_dataset(self):
        """Fetch the rows into ``self.dataset`` and build the embedding matrices."""
        self.dataset = self.fetch_dataset_rows()
        self.prepare_features()

    def search(self, query, column=None, top_k=20):
        """Return up to ``top_k`` (at most 100) rows ranked by similarity to ``query``.

        Args:
            query: Free-text query, embedded with ``self.text_model``.
            column: Optional column name. Rows whose value in that column does
                not contain ``query`` (case-insensitive, literal match) have
                their score halved.
            top_k: Number of results requested; capped at 100.

        Returns:
            A list of dicts with ``relevance_score`` plus every non-embedding
            column, best match first.
        """
        top_k = min(top_k, 100)
        if top_k <= 0:
            return []  # a [-0:] slice below would otherwise return every row

        query_embedding = self.text_model.encode([query])[0]
        video_sims = cosine_similarity([query_embedding], self.video_embeds)[0]
        text_sims = cosine_similarity([query_embedding], self.text_embeds)[0]
        combined_sims = 0.5 * video_sims + 0.5 * text_sims

        if column and column in self.dataset.columns and column != "All Fields":
            # regex=False: queries such as "c++" or "(intro" must not raise re.error.
            mask = (self.dataset[column].astype(str)
                    .str.contains(query, case=False, regex=False)
                    .to_numpy())
            combined_sims[~mask] *= 0.5

        top_indices = np.argsort(combined_sims)[-top_k:][::-1]

        results = []
        for idx in top_indices:
            row = self.dataset.iloc[idx]
            result = {'relevance_score': float(combined_sims[idx])}
            for col in self.dataset.columns:
                if col not in self.EMBED_COLUMNS:
                    result[col] = row[col]
            results.append(result)
        return results

@st.cache_resource
def get_speech_model():
    """Return the edge-tts ``Communicate`` class used to build TTS jobs.

    Callers instantiate it as ``get_speech_model()(text, voice)``.
    """
    communicate_cls = edge_tts.Communicate
    return communicate_cls

async def generate_speech(text, voice=None):
    """Synthesize ``text`` to an MP3 file in the working directory with edge-tts.

    Args:
        text: Text to speak. None, empty or whitespace-only text produces no file.
        voice: edge-tts voice name; defaults to ``st.session_state['tts_voice']``.

    Returns:
        The path of the written MP3, or None if there was nothing to speak or
        synthesis failed. On failure the error is shown via ``st.error``.
    """
    if not text or not text.strip():
        return None
    if not voice:
        voice = st.session_state['tts_voice']
    try:
        communicate = get_speech_model()(text, voice)
        # Microseconds stop two clips generated within the same second (e.g. the
        # short and full Arxiv summaries) from overwriting each other.
        audio_file = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.mp3"
        await communicate.save(audio_file)
        return audio_file
    except Exception as e:
        st.error(f"Error generating speech: {e}")
        return None

def transcribe_audio(audio_path):
    """Speech-to-text hook for recorded audio (not implemented yet).

    ``audio_path`` is accepted for interface compatibility but ignored. A fixed
    notice is returned in place of a transcript.
    """
    notice = "ASR not implemented. Integrate a local model or another service here."
    return notice

def show_file_manager():
    """Render the file-manager tab for *.txt / *.md / *.mp3 files in the CWD.

    The user can:

    - upload a file, which is saved under its base name in the current directory;
    - delete all matching files at once;
    - preview or delete each file individually.
    """
    # ``st.experimental_rerun`` was superseded by ``st.rerun`` and later removed;
    # use whichever this Streamlit version provides.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun

    def list_files():
        # Matching files in the working directory, grouped by extension.
        return glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3")

    st.subheader("📂 File Manager")
    col1, col2 = st.columns(2)
    with col1:
        uploaded_file = st.file_uploader("Upload File", type=['txt', 'md', 'mp3'])
        if uploaded_file:
            # The uploader keeps returning the same file on every rerun until the
            # user removes it. Save each upload only once; otherwise save + rerun
            # loops forever and "Clear All Files" gets undone.
            upload_id = getattr(uploaded_file, "file_id", None) or (uploaded_file.name, uploaded_file.size)
            if st.session_state.get('_file_manager_saved_upload') != upload_id:
                # Drop any directory part so the client cannot pick a path outside the CWD.
                target = os.path.basename(uploaded_file.name)
                with open(target, "wb") as f:
                    f.write(uploaded_file.getvalue())
                st.session_state['_file_manager_saved_upload'] = upload_id
                st.success(f"Uploaded: {uploaded_file.name}")
                rerun()

    with col2:
        if st.button("🗑 Clear All Files"):
            for path in list_files():
                os.remove(path)
            st.success("All files cleared!")
            rerun()

    files = list_files()
    if files:
        st.write("### Existing Files")
        for f in files:
            name = os.path.basename(f)
            with st.expander(f"📄 {name}"):
                if f.endswith('.mp3'):
                    st.audio(f)
                else:
                    # errors='replace': a non-UTF-8 upload must not crash the whole tab.
                    with open(f, 'r', encoding='utf-8', errors='replace') as file:
                        # Per-file key: identical contents (e.g. two empty files) would
                        # otherwise give identical widget IDs -> DuplicateWidgetID.
                        st.text_area("Content", file.read(), height=100, key=f"content_{f}")
                if st.button(f"Delete {name}", key=f"del_{f}"):
                    os.remove(f)
                    rerun()

def arxiv_search(query, max_results=5):
    """Perform a simple Arxiv search using their API and return top results."""
    base_url = "http://export.arxiv.org/api/query?"
    search_url = base_url + f"search_query={quote(query)}&start=0&max_results={max_results}"
    r = requests.get(search_url)
    if r.status_code == 200:
        root = ET.fromstring(r.text)
        ns = {'atom': 'http://www.w3.org/2005/Atom'}
        entries = root.findall('atom:entry', ns)
        results = []
        for entry in entries:
            title = entry.find('atom:title', ns).text.strip()
            summary = entry.find('atom:summary', ns).text.strip()
            link = None
            for l in entry.findall('atom:link', ns):
                if l.get('type') == 'text/html':
                    link = l.get('href')
                    break
            results.append((title, summary, link))
        return results
    return []

def perform_arxiv_lookup(q, vocal_summary=True, titles_summary=True, full_audio=False):
    """Search Arxiv for ``q``, render the hits, and optionally speak them.

    Args:
        q: Free-text Arxiv query. A blank query is rejected with a warning
            instead of being sent to the API.
        vocal_summary: Play a short spoken summary after the result list.
        titles_summary: For the short summary, read all titles (True) or the
            first 200 characters of the top abstract (False).
        full_audio: Also play one clip reading every title and abstract.
    """
    if not q or not q.strip():
        st.warning("Please enter an Arxiv search query.")
        return

    results = arxiv_search(q, max_results=5)
    if not results:
        st.write("No Arxiv results found.")
        return

    st.markdown(f"**Arxiv Search Results for '{q}':**")
    for i, (title, summary, link) in enumerate(results, start=1):
        st.markdown(f"**{i}. {title}**")
        st.write(summary)
        if link:
            st.markdown(f"[View Paper]({link})")

    if vocal_summary:
        spoken_text = f"Here are some Arxiv results for {q}. "
        if titles_summary:
            spoken_text += " Titles: " + ", ".join(title for title, _, _ in results)
        else:
            # Only the start of the top abstract when titles aren't requested.
            spoken_text += " " + results[0][1][:200]
        audio_file = asyncio.run(generate_speech(spoken_text))
        if audio_file:
            st.audio(audio_file)

    if full_audio:
        full_text = " ".join(
            f"Result {i}: {title}. {summary}"
            for i, (title, summary, _) in enumerate(results, start=1)
        )
        audio_file_full = asyncio.run(generate_speech(full_text))
        if audio_file_full:
            st.write("### Full Audio")
            st.audio(audio_file_full)

def _get_video_search():
    """Return this session's ``VideoSearch``, constructing it on first use.

    Construction loads the sentence-transformer model and fetches the dataset
    over HTTP. Streamlit re-executes ``main`` on every widget interaction, so
    the instance is kept in session state rather than rebuilt on each rerun.
    """
    if 'video_search' not in st.session_state:
        st.session_state['video_search'] = VideoSearch()
    return st.session_state['video_search']


def _format_views(views):
    """Format a view count with thousands separators, falling back to str()."""
    try:
        return f"{views:,}"
    except (TypeError, ValueError):
        return str(views)


def _render_video_search_tab(search):
    """Render the video search tab: query form, ranked results, audio summaries."""
    st.subheader("Search Videos")
    col1, col2 = st.columns([3, 1])
    with col1:
        # Keep the default constant. Streamlit derives widget identity from its
        # parameters, so switching ``value`` after the first search reset the
        # box and discarded what the user had typed.
        query = st.text_input("Enter your search query:", value="ancient")
    with col2:
        search_column = st.selectbox("Search in field:",
                                     ["All Fields"] + st.session_state['search_columns'])

    col3, col4 = st.columns(2)
    with col3:
        num_results = st.slider("Number of results:", 1, 100, 20)
    with col4:
        search_button = st.button("🔍 Search")

    if (search_button or not st.session_state['initial_search_done']) and query:
        st.session_state['initial_search_done'] = True
        selected_column = None if search_column == "All Fields" else search_column
        with st.spinner("Searching..."):
            results = search.search(query, selected_column, num_results)
        # Persist results so they (and their audio buttons) survive the rerun
        # that clicking a button inside a result triggers.
        st.session_state['video_results'] = results
        st.session_state['search_history'].append({
            'query': query,
            'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'results': results[:5]
        })

    for i, result in enumerate(st.session_state.get('video_results', []), 1):
        description = str(result.get('description', ''))
        with st.expander(f"Result {i}: {description[:100]}...", expanded=(i == 1)):
            cols = st.columns([2, 1])
            with cols[0]:
                st.markdown("**Description:**")
                st.write(description)
                st.markdown(f"**Time Range:** {result.get('start_time')}s - {result.get('end_time')}s")
                st.markdown(f"**Views:** {_format_views(result.get('views'))}")

            with cols[1]:
                st.markdown(f"**Relevance Score:** {result['relevance_score']:.2%}")
                youtube_id = result.get('youtube_id')
                # isinstance check: rows lacking the field carry NaN, which is truthy.
                if isinstance(youtube_id, str) and youtube_id:
                    st.video(f"https://youtube.com/watch?v={youtube_id}&t={result.get('start_time', 0)}")

                if st.button("🔊 Audio Summary", key=f"audio_{i}"):
                    summary = f"Video summary: {description[:200]}"
                    audio_file = asyncio.run(generate_speech(summary))
                    if audio_file:
                        st.audio(audio_file)


def _render_voice_tab(search):
    """Render the voice tab: record audio, show its transcript, search with it."""
    st.subheader("Voice Input")
    st.write("🎙️ Record your voice:")
    audio_bytes = audio_recorder()
    if not audio_bytes:
        return

    audio_path = f"temp_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
    try:
        with open(audio_path, "wb") as f:
            f.write(audio_bytes)
        st.success("Audio recorded successfully!")

        voice_query = transcribe_audio(audio_path)
        st.markdown("**Transcribed Text:**")
        st.write(voice_query)
        st.session_state['last_voice_input'] = voice_query

        if st.button("🔍 Search from Voice"):
            results = search.search(voice_query, None, 20)
            for i, result in enumerate(results, 1):
                with st.expander(f"Result {i}", expanded=(i == 1)):
                    st.write(result.get('description', ''))
                    youtube_id = result.get('youtube_id')
                    if isinstance(youtube_id, str) and youtube_id:
                        st.video(f"https://youtube.com/watch?v={youtube_id}&t={result.get('start_time', 0)}")
    finally:
        # Remove the temporary recording even if transcription or search fails.
        if os.path.exists(audio_path):
            os.remove(audio_path)


def _render_arxiv_tab():
    """Render the Arxiv tab: query box, audio options, and the search button."""
    st.subheader("Arxiv Search")
    q = st.text_input("Enter your Arxiv search query:", value=st.session_state['arxiv_last_query'])
    vocal_summary = st.checkbox("🎙 Short Audio Summary", value=True)
    titles_summary = st.checkbox("🔖 Titles Only", value=True)
    full_audio = st.checkbox("📚 Full Audio Results", value=False)

    if st.button("🔍 Arxiv Search"):
        st.session_state['arxiv_last_query'] = q
        perform_arxiv_lookup(q, vocal_summary=vocal_summary, titles_summary=titles_summary, full_audio=full_audio)


def _render_dataset_search_tab():
    """Render the advanced tab: keyword search across a dataset's configs and splits."""
    st.subheader("Advanced Dataset Search")
    dataset_id = st.text_input("Dataset ID:", value="omegalabsinc/omega-multimodal")

    col1, col2 = st.columns([2, 1])
    with col1:
        search_text = st.text_input("Search text:",
                                    placeholder="Enter text to search across all fields")

    if not dataset_id:
        return
    dataset_info = fetch_dataset_info(dataset_id)
    if not dataset_info:
        st.error("Unable to fetch dataset information.")
        return

    configs = dataset_info.get('config_names', ['default'])
    with col2:
        selected_configs = st.multiselect(
            "Configurations:",
            options=configs,
            default=['default'] if 'default' in configs else None
        )
    if not selected_configs:
        return

    all_splits = set()
    for config in selected_configs:
        try:
            # params= so ids/config names are URL-encoded correctly.
            response = requests.get(
                "https://datasets-server.huggingface.co/splits",
                params={"dataset": dataset_id, "config": config},
                timeout=30,
            )
            if response.status_code == 200:
                all_splits.update(split['split'] for split in response.json().get('splits', []))
        except Exception as e:
            st.warning(f"Error fetching splits for {config}: {e}")

    selected_splits = st.multiselect(
        "Splits:",
        options=sorted(all_splits),
        default=['train'] if 'train' in all_splits else None
    )

    if not st.button("🔍 Search Dataset"):
        return
    with st.spinner("Searching dataset..."):
        results_df, _, _ = search_dataset(
            dataset_id,
            search_text,
            include_configs=selected_configs,
            include_splits=selected_splits
        )
    if results_df.empty:
        st.warning("No results found.")
        return

    st.write(f"Found {len(results_df)} results")
    # Number results by display rank, not by DataFrame index label.
    for rank, (_, row) in enumerate(results_df.iterrows(), start=1):
        with st.expander(
            f"Result {rank} (Config: {row['_config']}, Split: {row['_split']}, Score: {row['_relevance_score']})"
        ):
            # Display all fields except internal ones and raw vectors.
            for col in row.index:
                if not col.startswith('_') and not any(
                    term in col.lower() for term in ['embed', 'vector', 'encoding']
                ):
                    st.write(f"**{col}:** {row[col]}")

            youtube_id = row.get('youtube_id')
            if isinstance(youtube_id, str) and youtube_id:
                st.video(f"https://youtube.com/watch?v={youtube_id}&t={row.get('start_time', 0)}")


def _render_sidebar():
    """Render the sidebar: search history (with a clear button) and the TTS voice picker."""
    with st.sidebar:
        st.subheader("⚙️ Settings & History")
        if st.button("🗑️ Clear History"):
            # The history list is rendered below in this same run, so clearing
            # it is enough; no rerun needed (st.experimental_rerun is gone in
            # current Streamlit).
            st.session_state['search_history'] = []

        st.markdown("### Recent Searches")
        for entry in reversed(st.session_state['search_history'][-5:]):
            with st.expander(f"{entry['timestamp']}: {entry['query']}"):
                for i, result in enumerate(entry['results'], 1):
                    st.write(f"{i}. {str(result.get('description', ''))[:100]}...")

        st.markdown("### Voice Settings")
        st.selectbox("TTS Voice:",
                     ["en-US-AriaNeural", "en-US-GuyNeural", "en-GB-SoniaNeural"],
                     key="tts_voice")


def main():
    """Build the Streamlit page: five feature tabs plus the settings sidebar."""
    st.title("🎥 Video & Arxiv Search with Voice (No OpenAI/Anthropic)")

    # One VideoSearch per session instead of reloading model + data every rerun.
    search = _get_video_search()

    tab1, tab2, tab3, tab4, tab5 = st.tabs(["🔍 Search", "🎙️ Voice Input", "📚 Arxiv", "📂 Files", "🔍 Advanced Search"])

    with tab1:
        _render_video_search_tab(search)
    with tab2:
        _render_voice_tab(search)
    with tab3:
        _render_arxiv_tab()
    with tab4:
        show_file_manager()
    with tab5:
        _render_dataset_search_tab()

    _render_sidebar()


if __name__ == "__main__":
    main()