import streamlit as st
import pandas as pd
import numpy as np
import json
import os
import glob
import random
from pathlib import Path
from datetime import datetime
import edge_tts
import asyncio
import requests
import streamlit.components.v1 as components
import base64
import re
from xml.etree import ElementTree as ET
from datasets import load_dataset

# -------------------- Configuration & Constants --------------------
USER_NAMES = [
    "Aria", "Guy", "Sonia", "Tony", "Jenny", "Davis", "Libby", "Clara", "Liam", "Natasha", "William"
]

ENGLISH_VOICES = [
    "en-US-AriaNeural", "en-US-GuyNeural", "en-GB-SoniaNeural", "en-GB-TonyNeural",
    "en-US-JennyNeural", "en-US-DavisNeural", "en-GB-LibbyNeural", "en-CA-ClaraNeural",
    "en-CA-LiamNeural", "en-AU-NatashaNeural", "en-AU-WilliamNeural"
]

# Map each user to a corresponding voice
USER_VOICES = dict(zip(USER_NAMES, ENGLISH_VOICES))

ROWS_PER_PAGE = 100
SAVED_INPUTS_DIR = "saved_inputs"
os.makedirs(SAVED_INPUTS_DIR, exist_ok=True)

SESSION_VARS = {
    'search_history': [],
    'last_voice_input': "",
    'transcript_history': [],
    'should_rerun': False,
    'search_columns': [],
    'initial_search_done': False,
    'arxiv_last_query': "",
    'dataset_loaded': False,
    'current_page': 0,
    'data_cache': None,
    'dataset_info': None,
    'nps_submitted': False,
    'nps_last_shown': None,
    'old_val': None,
    'voice_text': None,
    'user_name': random.choice(USER_NAMES),
    'max_items': 100,
    'global_voice': "en-US-AriaNeural",
    'last_arxiv_input': None
}

for var, default in SESSION_VARS.items():
    if var not in st.session_state:
        st.session_state[var] = default

def create_voice_component():
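    """Declare the custom voice-input component.

    Assumes the component's frontend lives in a local 'mycomponent'
    directory (e.g. containing an index.html) next to this script.
    """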
    mycomponent = components.declare_component(
        "mycomponent",
        path="mycomponent"
    )
    return mycomponent

def clean_for_speech(text: str) -> str:
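    """Normalize text for TTS: drop newlines, stray tags, '#' characters, and parenthesized URLs."""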
    text = text.replace("\n", " ")
    text = text.replace("</s>", " ")
    text = text.replace("#", "")
    text = re.sub(r"\(https?:\/\/[^\)]+\)", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text

async def edge_tts_generate_audio(text, voice="en-US-AriaNeural"):
    text = clean_for_speech(text)
    if not text.strip():
        return None
    communicate = edge_tts.Communicate(text, voice)
    out_fn = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.mp3"
    await communicate.save(out_fn)
    return out_fn

def speak_with_edge_tts(text, voice="en-US-AriaNeural"):
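    """Synchronous wrapper around the async edge-tts call; returns the mp3 path, or None for empty text."""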
    return asyncio.run(edge_tts_generate_audio(text, voice))

def play_and_download_audio(file_path):
    if file_path and os.path.exists(file_path):
        st.audio(file_path)
        with open(file_path, "rb") as f:
            b64 = base64.b64encode(f.read()).decode()
        fname = os.path.basename(file_path)
        dl_link = f'<a href="data:audio/mpeg;base64,{b64}" download="{fname}">Download {fname}</a>'
        st.markdown(dl_link, unsafe_allow_html=True)

def generate_filename(prefix, text):
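    """Build a timestamped, slugified markdown filename from the first 50 characters of the text."""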
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe_text = re.sub(r'[^\w\s-]', '', text[:50]).strip().lower()
    safe_text = re.sub(r'[-\s]+', '-', safe_text)
    return f"{prefix}_{timestamp}_{safe_text}.md"

def save_input_as_md(user_name, text, prefix="input"):
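    """Save a user input to SAVED_INPUTS_DIR as markdown; returns the file path (None for blank text)."""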
    if not text.strip():
        return
    fn = generate_filename(prefix, text)
    full_path = os.path.join(SAVED_INPUTS_DIR, fn)
    with open(full_path, 'w', encoding='utf-8') as f:
        f.write(f"# User: {user_name}\n")
        f.write(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(text)
    return full_path

def save_response_as_md(user_name, text, prefix="response"):
    # Same markdown format as save_input_as_md; only the filename prefix differs.
    return save_input_as_md(user_name, text, prefix=prefix)

def list_saved_inputs():
    files = sorted(glob.glob(os.path.join(SAVED_INPUTS_DIR, "*.md")))
    return files

def parse_md_file(fpath):
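    """Split a saved markdown file back into (user, timestamp, content)."""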
    user_line = ""
    ts_line = ""
    content_lines = []
    with open(fpath, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    for line in lines:
        if line.startswith("# User:"):
            user_line = line.replace("# User:", "").strip()
        elif line.startswith("**Timestamp:**"):
            ts_line = line.replace("**Timestamp:**", "").strip()
        else:
            content_lines.append(line.strip())
    content = "\n".join(content_lines).strip()
    return user_line, ts_line, content

def arxiv_search(query, max_results=3):
    """Query the arXiv Atom API and return a list of (title, truncated summary) tuples."""
    base_url = "http://export.arxiv.org/api/query"
    params = {
        # Pass the query as-is; requests handles URL encoding of spaces
        'search_query': query,
        'start': 0,
        'max_results': max_results
    }
    response = requests.get(base_url, params=params, timeout=30)
    if response.status_code == 200:
        root = ET.fromstring(response.text)
        ns = {"a": "http://www.w3.org/2005/Atom"}
        results = []
        for entry in root.findall('a:entry', ns):
            title = (entry.findtext('a:title', default="", namespaces=ns) or "").strip()
            summary = (entry.findtext('a:summary', default="", namespaces=ns) or "").strip()
            summary_short = summary[:300] + "..." if len(summary) > 300 else summary
            results.append((title, summary_short))
        return results
    return []

def summarize_arxiv_results(results):
    lines = []
    for i, (title, summary) in enumerate(results, 1):
        lines.append(f"Result {i}: {title}\n{summary}\n")
    return "\n\n".join(lines)

def simple_dataset_search(query, df):
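    """Return rows where any whitespace-separated query term appears (case-insensitively) in any text or numeric cell."""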
    if df.empty or not query.strip():
        return pd.DataFrame()
    query_terms = query.lower().split()
    matches = []
    for idx, row in df.iterrows():
        text_parts = []
        for col in df.columns:
            val = row[col]
            if isinstance(val, str):
                text_parts.append(val.lower())
            elif isinstance(val, (int, float)):
                text_parts.append(str(val))
        full_text = " ".join(text_parts)
        if any(qt in full_text for qt in query_terms):
            matches.append(row)
    if matches:
        return pd.DataFrame(matches)
    return pd.DataFrame()


@st.cache_data
def load_dataset_page(dataset_id, token, page, rows_per_page):
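    """Load one page of the dataset as a DataFrame via a train[start:end] slice; cached by Streamlit."""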
    try:
        start_idx = page * rows_per_page
        end_idx = start_idx + rows_per_page
        dataset = load_dataset(
            dataset_id,
            token=token,
            streaming=False,
            split=f'train[{start_idx}:{end_idx}]'
        )
        return pd.DataFrame(dataset)
    except Exception:
        # Fall back to an empty DataFrame on any load/auth failure
        return pd.DataFrame()

class SimpleDatasetSearcher:
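    """Thin pager over a Hugging Face dataset, authenticated with the DATASET_KEY environment variable."""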
    def __init__(self, dataset_id="tomg-group-umd/cinepile"):
        self.dataset_id = dataset_id
        self.token = os.environ.get('DATASET_KEY')
    def load_page(self, page=0):
        return load_dataset_page(self.dataset_id, self.token, page, ROWS_PER_PAGE)

def concatenate_mp3(files, output_file):
    # Naive binary concatenation of MP3 files; most players cope, but duration
    # metadata may be inaccurate since frames are not re-muxed.
    with open(output_file, 'wb') as outfile:
        for f in files:
            with open(f, 'rb') as infile:
                outfile.write(infile.read())
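
def concatenate_mp3_reencoded(files, output_file):
    # Optional, more robust alternative (a sketch, assuming pydub and ffmpeg are available):
    # decoding and re-exporting keeps frame headers and duration metadata consistent.
    from pydub import AudioSegment  # imported locally so the app still runs without pydub
    combined = AudioSegment.empty()
    for f in files:
        combined += AudioSegment.from_mp3(f)
    combined.export(output_file, format="mp3")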

def main():
    st.title("🎙️ Voice Chat & Search")

    # Sidebar
    with st.sidebar:
        # Let the user pick their display name (defaults to the randomly assigned session user)
        st.session_state['user_name'] = st.selectbox(
            "Current User:", USER_NAMES,
            index=USER_NAMES.index(st.session_state['user_name'])
        )

        st.session_state['max_items'] = st.number_input(
            "Max Items per search iteration:",
            min_value=1, max_value=1000,
            value=st.session_state['max_items']
        )
        
        st.subheader("📝 Saved Inputs & Responses")
        saved_files = list_saved_inputs()
        for fpath in saved_files:
            user, ts, content = parse_md_file(fpath)
            fname = os.path.basename(fpath)
            st.write(f"- {fname} (User: {user})")

    # Create voice component for input
    voice_component = create_voice_component()
    voice_val = voice_component(my_input_value="Start speaking...")

    # Tabs
    tab1, tab2, tab3, tab4 = st.tabs(["🗣️ Voice Chat History", "📚 ArXiv Search", "📊 Dataset Search", "⚙️ Settings"])

    # ------------------ Voice Chat History -------------------------
    with tab1:
        st.subheader("Voice Chat History")
        files = list_saved_inputs()
        conversation = []
        for fpath in files:
            user, ts, content = parse_md_file(fpath)
            conversation.append((user, ts, content, fpath))

        # Enumerate to ensure unique keys
        for i, (user, ts, content, fpath) in enumerate(reversed(conversation), start=1):
            with st.expander(f"{ts} - {user}", expanded=False):
                st.write(content)
                # Make button key unique by including i
                if st.button(f"🔊 Read Aloud {ts}-{user}", key=f"read_{i}_{fpath}"):
                    voice = USER_VOICES.get(user, "en-US-AriaNeural")
                    audio_file = speak_with_edge_tts(content, voice=voice)
                    if audio_file:
                        play_and_download_audio(audio_file)

        # Read entire conversation
        if st.button("📜 Read Conversation", key="read_conversation_all"):
            # 'conversation' is already in chronological order (saved files are sorted by timestamped name)
            mp3_files = []
            for user, ts, content, fpath in conversation:
                voice = USER_VOICES.get(user, "en-US-AriaNeural")
                audio_file = speak_with_edge_tts(content, voice=voice)
                if audio_file:
                    mp3_files.append(audio_file)
                    st.write(f"**{user} ({ts}):**")
                    play_and_download_audio(audio_file)
            
            if mp3_files:
                combined_file = f"full_conversation_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
                concatenate_mp3(mp3_files, combined_file)
                st.write("**Full Conversation Audio:**")
                play_and_download_audio(combined_file)

    # ------------------ ArXiv Search -------------------------
    with tab2:
        st.subheader("ArXiv Search")
        # The voice component may return None (or a non-string) before any speech is captured
        default_query = str(voice_val).strip() if voice_val else ""
        edited_input = st.text_area("Enter or Edit Search Query:", value=default_query, height=100)
        autorun = st.checkbox("⚡ Auto-Run", value=True)
        run_arxiv = st.button("🔍 ArXiv Search", key="run_arxiv_button")

        input_changed = (edited_input != st.session_state.get('old_val'))
        should_run_arxiv = False
        if autorun and input_changed and edited_input.strip():
            should_run_arxiv = True
        if run_arxiv and edited_input.strip():
            should_run_arxiv = True

        if should_run_arxiv and st.session_state['last_arxiv_input'] != edited_input:
            st.session_state['old_val'] = edited_input
            st.session_state['last_arxiv_input'] = edited_input
            save_input_as_md(st.session_state['user_name'], edited_input, prefix="input")
            with st.spinner("Searching ArXiv..."):
                results = arxiv_search(edited_input)
                if results:
                    summary = summarize_arxiv_results(results)
                    save_response_as_md(st.session_state['user_name'], summary, prefix="response")
                    st.write(summary)
                    # Play summary aloud
                    voice = USER_VOICES.get(st.session_state['user_name'], "en-US-AriaNeural")
                    audio_file = speak_with_edge_tts(summary, voice=voice)
                    if audio_file:
                        play_and_download_audio(audio_file)
                else:
                    st.warning("No results found on ArXiv.")

    # ------------------ Dataset Search -------------------------
    with tab3:
        st.subheader("Dataset Search")
        ds_searcher = SimpleDatasetSearcher()
        query = st.text_input("Enter dataset search query:")
        run_ds_search = st.button("Search Dataset", key="ds_search_button")
        num_results = st.slider("Max results:", 1, 100, 20, key="ds_max_results")
        
        if run_ds_search and query.strip():
            with st.spinner("Searching dataset..."):
                df = ds_searcher.load_page(0)
                results = simple_dataset_search(query, df)
                if not results.empty:
                    st.write(f"Found {len(results)} results:")
                    # Show at most num_results rows, with the first expanded
                    for i, (_, row) in enumerate(results.head(num_results).iterrows(), 1):
                        with st.expander(f"Result {i}", expanded=(i == 1)):
                            for k, v in row.items():
                                st.write(f"**{k}:** {v}")
                else:
                    st.warning("No matching results found.")

    # ------------------ Settings Tab -------------------------
    with tab4:
        st.subheader("Settings")
        if st.button("🗑️ Clear Search History", key="clear_history"):
            # Delete all files
            for fpath in list_saved_inputs():
                os.remove(fpath)
            st.session_state['search_history'] = []
            st.success("Search history cleared for everyone!")
            st.rerun()

if __name__ == "__main__":
    main()