import streamlit as st
from together import Together
import os
from typing import Iterator
from PIL import Image
import base64
from PyPDF2 import PdfReader
import json  # For debugging

# The Together API key must be supplied via the environment (e.g. Streamlit
# Secrets). Fail fast at import time with an actionable message when absent,
# rather than failing later on the first API call.
API_KEY = os.getenv("TOGETHER_API_KEY")
if not API_KEY:
    raise ValueError("API key is missing! Make sure TOGETHER_API_KEY is set in the Secrets.")

@st.cache_resource
def get_client():
    """Return a Together API client, cached for the process by Streamlit."""
    client = Together(api_key=API_KEY)
    return client

def process_file(file) -> str:
    """Extract textual content from an uploaded file.

    Handles three cases based on the file's MIME type:
      * PDF   -> concatenated text of all pages (via PyPDF2)
      * image -> base64-encoded file bytes
      * other -> raw bytes decoded as UTF-8

    Args:
        file: A Streamlit UploadedFile (exposes ``.type`` and
            ``.getvalue()``), or ``None``.

    Returns:
        The extracted content, or ``""`` when ``file`` is ``None`` or
        processing fails (the error is shown to the user via ``st.error``).
    """
    if file is None:
        return ""

    try:
        if file.type == "application/pdf":
            pdf_reader = PdfReader(file)
            # extract_text() may return None for pages without a text layer;
            # coalesce to "" so concatenation never raises TypeError.
            return "".join(
                (page.extract_text() or "") + "\n" for page in pdf_reader.pages
            )
        elif file.type.startswith("image/"):
            return base64.b64encode(file.getvalue()).decode("utf-8")
        else:
            return file.getvalue().decode('utf-8')
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")
        return ""

def format_message(role: str, content: str) -> dict:
    """Build one chat message in the API's ``{"role", "content"}`` shape."""
    return dict(role=role, content=content)

def get_formatted_history(messages: list) -> list:
    """Normalize stored conversation entries into API-ready message dicts.

    Entries that are not dicts carrying both "role" and "content" keys are
    silently dropped. Unrecognized roles are coerced: "human" becomes
    "user", anything else becomes "assistant".
    """
    valid_roles = ("system", "user", "assistant")
    normalized = []
    for entry in messages:
        if not isinstance(entry, dict):
            continue
        if "role" not in entry or "content" not in entry:
            continue
        role = entry["role"]
        if role not in valid_roles:
            role = "user" if role == "human" else "assistant"
        normalized.append({"role": role, "content": entry["content"]})
    return normalized

def generate_response(
    message: str,
    history: list,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    files=None
) -> Iterator[str]:
    """Stream a chat completion from the DeepSeek model.

    Args:
        message: The latest user prompt. NOTE(review): currently unused —
            the caller passes a ``history`` that already contains it.
        history: Conversation entries (dicts with "role"/"content").
        system_message: System prompt; skipped when blank.
        max_tokens: Maximum tokens to generate.
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling parameter forwarded to the API.
        files: Optional iterable of uploaded files whose extracted content
            is appended to the final user message.

    Yields:
        Chunks of the assistant's reply as they arrive, or a single
        user-facing error string when the request fails.
    """
    client = get_client()

    try:
        messages = []

        # System prompt goes first, if provided.
        if system_message.strip():
            messages.append(format_message("system", system_message))

        # The history already includes the latest user prompt.
        formatted_history = get_formatted_history(history)

        # Attach extracted file contents to the last user message, or as a
        # new user message when the history does not end with one.
        if files:
            file_contents = []
            for file in files:
                content = process_file(file)
                if content:
                    file_contents.append(f"File content:\n{content}")
            if file_contents:
                combined = "\n\n".join(file_contents)
                if formatted_history and formatted_history[-1]["role"] == "user":
                    formatted_history[-1]["content"] += "\n\n" + combined
                else:
                    formatted_history.append(format_message("user", combined))

        messages.extend(formatted_history)

        # NOTE: the former debug st.write() that dumped the full request
        # payload (system prompt, history, file contents) into the chat UI
        # has been removed — it leaked internal state to the end user.

        try:
            stream = client.chat.completions.create(
                model="deepseek-ai/DeepSeek-R1",
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stream=True
            )

            for chunk in stream:
                delta = chunk.choices[0].delta
                if hasattr(delta, 'content') and delta.content:
                    yield delta.content

        except Exception as e:
            # Give rate-limit errors a distinct, actionable message.
            if "rate limit" in str(e).lower():
                yield "API call rate limit reached. Please try again later."
            else:
                st.error(f"Detailed API error: {str(e)}")
                yield "Sorry, please try again later."

    except Exception as e:
        st.error(f"Detailed error: {str(e)}")
        yield "An error occurred, please try again later."

def main():
    """Streamlit entry point: sidebar settings plus a streaming chat loop."""
    st.set_page_config(page_title="DeepSeek Chat", page_icon="💭", layout="wide")

    # Conversation history must survive Streamlit's script reruns.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    st.title("DeepSeek Chat")
    st.markdown("Chat with the DeepSeek AI model. You can upload files if needed.")

    # Sidebar: model/sampling settings and optional file attachments.
    with st.sidebar:
        st.header("Settings")
        system_message = st.text_area(
            "System Message",
            value="You are a deeply thoughtful AI. Consider problems thoroughly and derive correct solutions through systematic reasoning. Please answer in English.",
            height=100
        )
        max_tokens = st.slider("Max Tokens", 1, 4096, 2048)
        temperature = st.slider("Temperature", 0.0, 2.0, 0.7, 0.1)
        top_p = st.slider("Top-p", 0.0, 1.0, 0.7, 0.1)
        uploaded_file = st.file_uploader(
            "File Upload (Optional)",
            type=['txt', 'py', 'md', 'pdf', 'png', 'jpg', 'jpeg'],
            accept_multiple_files=True
        )
        st.markdown("Join our Discord community: [https://discord.gg/openfreeai](https://discord.gg/openfreeai)")

    # Replay the saved conversation so far.
    for past in st.session_state.messages:
        with st.chat_message(past["role"]):
            st.markdown(past["content"])

    # Handle a new user prompt, if one was submitted this rerun.
    if user_input := st.chat_input("What would you like to know?"):
        st.session_state.messages.append(format_message("user", user_input))

        with st.chat_message("user"):
            st.markdown(user_input)

        # Stream the assistant's reply, updating a placeholder as it grows.
        with st.chat_message("assistant"):
            placeholder = st.empty()
            reply_text = ""

            for token in generate_response(
                user_input,
                st.session_state.messages,
                system_message,
                max_tokens,
                temperature,
                top_p,
                uploaded_file
            ):
                reply_text += token
                placeholder.markdown(reply_text + "▌")

            placeholder.markdown(reply_text)

        # Persist the completed reply in the conversation history.
        st.session_state.messages.append(format_message("assistant", reply_text))

# Run the app only when executed as a script (not when imported).
if __name__ == "__main__":
    main()