import os
import base64
import requests
import gradio as gr
from huggingface_hub import InferenceClient
from dataclasses import dataclass
import pytesseract
from PIL import Image

@dataclass
class ChatMessage:
    """One turn of a chat: the speaker's ``role`` and the message ``content``.

    Messages are built with attribute access and then flattened with
    ``to_dict`` into the plain dictionaries passed to the chat-completion call.
    """

    role: str
    content: str

    def to_dict(self):
        """Return this message as a ``{"role": ..., "content": ...}`` dictionary."""
        payload = dict(role=self.role, content=self.content)
        return payload

class XylariaChat:
    """Gradio chat front-end for the Xylaria assistant.

    Text is answered by a Qwen model through ``InferenceClient`` using streamed
    chat completions. Uploaded images are captioned through the Hugging Face
    Inference API, and math images are read with Tesseract OCR. The resulting
    text is folded into the user's message before it is sent.

    NOTE(review): ``conversation_history`` and ``persistent_memory`` live on this
    single instance, and ``main`` serves one instance to every browser session.
    So all visitors share the same history, and any page load resets it for
    everyone (see ``demo.load`` in ``create_interface``). Moving this state into
    per-session ``gr.State`` would isolate users.
    """

    # Chat model used for every completion request.
    MODEL_ID = "Qwen/QwQ-32B-Preview"
    # Image-captioning endpoint on the Hugging Face Inference API.
    IMAGE_API_URL = "https://api-inference.huggingface.co/models/microsoft/git-large-coco"
    # Seconds to wait for the captioning endpoint before giving up.
    REQUEST_TIMEOUT = 60
    # Total token budget assumed for prompt + completion.
    CONTEXT_WINDOW = 16384
    # Head-room kept free because the prompt size is only estimated.
    TOKEN_SAFETY_MARGIN = 50
    # Upper bound on the completion length requested from the model.
    MAX_NEW_TOKENS_CAP = 10020
    # Number of history entries (user + assistant messages) kept between turns.
    MAX_HISTORY_MESSAGES = 10
    # Prefix of every error string returned by perform_math_ocr.
    _OCR_ERROR_PREFIX = "Error during Math OCR:"

    def __init__(self):
        # Securely load HuggingFace token
        self.hf_token = os.getenv("HF_TOKEN")
        if not self.hf_token:
            raise ValueError("HuggingFace token not found in environment variables")

        # Inference client for the chat model
        self.client = self._make_client()

        # Image captioning API setup
        self.image_api_url = self.IMAGE_API_URL
        self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}

        # Turns replayed to the model on each request, and key/value facts
        # injected as an extra system message on every request.
        self.conversation_history = []
        self.persistent_memory = {}

        # System prompt with more detailed instructions
        self.system_prompt = """You are a helpful and harmless assistant. You are Xylaria developed by Sk Md Saad Amin. You should think step-by-step. You should respond to image questions"""

    def _make_client(self):
        """Create an InferenceClient for MODEL_ID, authenticated with the HF token."""
        return InferenceClient(model=self.MODEL_ID, api_key=self.hf_token)

    def store_information(self, key, value):
        """Store important information in persistent memory"""
        self.persistent_memory[key] = value
        return f"Stored: {key} = {value}"

    def retrieve_information(self, key):
        """Retrieve information from persistent memory"""
        return self.persistent_memory.get(key, "No information found for this key.")

    def _clear_conversation_history(self):
        """Forget prior turns while keeping persistent memory.

        Returns:
            None, so that wiring this as a Gradio callback empties the chatbot.
        """
        self.conversation_history = []
        return None

    def reset_conversation(self):
        """
        Clear the conversation history and persistent memory, then rebuild the client.

        Returns:
            None, so that wiring this as a Gradio callback empties the chatbot.
        """
        self._clear_conversation_history()
        self.persistent_memory.clear()

        # Rebuilding the client is not strictly necessary for the API, but it
        # discards any local client state.
        try:
            self.client = self._make_client()
        except Exception as e:
            print(f"Error resetting API client: {e}")

        return None

    @staticmethod
    def _extract_caption(payload):
        """Pull ``generated_text`` out of a captioning response payload.

        The endpoint is expected to answer with a list such as
        ``[{"generated_text": "..."}]``. Any other shape yields the fallback
        text instead of raising.
        """
        if isinstance(payload, list) and payload and isinstance(payload[0], dict):
            return payload[0].get('generated_text', 'No caption generated')
        return 'No caption generated'

    def caption_image(self, image):
        """
        Caption an image with the Hugging Face image-captioning endpoint.

        Args:
            image: a file path, a base64 string (optionally a ``data:image/...``
                URI), or a file-like object with ``read()``.
        Returns:
            str: the caption, or a message starting with "Error" on failure.
        """
        try:
            if isinstance(image, str) and os.path.isfile(image):
                with open(image, "rb") as f:
                    data = f.read()
            elif isinstance(image, str):
                # Strip a data-URI header such as "data:image/png;base64," if present
                if image.startswith('data:image'):
                    image = image.split(',', 1)[1]
                data = base64.b64decode(image)
            else:
                # File-like object (unlikely with Gradio, but supported)
                data = image.read()

            response = requests.post(
                self.image_api_url,
                headers=self.image_api_headers,
                data=data,
                timeout=self.REQUEST_TIMEOUT,  # never block the chat indefinitely
            )

            if response.status_code == 200:
                return self._extract_caption(response.json())
            return f"Error captioning image: {response.status_code} - {response.text}"

        except Exception as e:
            return f"Error processing image: {str(e)}"

    def perform_math_ocr(self, image_path):
        """
        Perform OCR on an image and return the extracted text.

        Args:
            image_path (str): Path to the image file.
        Returns:
            str: Extracted text (whitespace-stripped), or a message starting
            with "Error during Math OCR:" on failure.
        """
        try:
            # The context manager closes the underlying file once OCR is done.
            with Image.open(image_path) as img:
                text = pytesseract.image_to_string(img)
            return text.strip()
        except Exception as e:
            return f"{self._OCR_ERROR_PREFIX} {e}"

    @classmethod
    def _is_ocr_error(cls, text):
        """Return True if *text* is an error string produced by perform_math_ocr."""
        return text.startswith(cls._OCR_ERROR_PREFIX)

    @classmethod
    def _max_new_tokens(cls, messages):
        """Estimate how many completion tokens fit alongside *messages*.

        The prompt size is approximated by its whitespace-separated word count,
        which likely undercounts real tokens; TOKEN_SAFETY_MARGIN adds slack.
        The result is capped at MAX_NEW_TOKENS_CAP. It is zero or negative when
        the prompt already fills CONTEXT_WINDOW.
        """
        estimated_input = sum(len(msg['content'].split()) for msg in messages)
        budget = cls.CONTEXT_WINDOW - estimated_input - cls.TOKEN_SAFETY_MARGIN
        return min(budget, cls.MAX_NEW_TOKENS_CAP)

    def get_response(self, user_input, image=None):
        """
        Start a streamed chat completion for *user_input*.

        Args:
            user_input (str): User's message.
            image (optional): Uploaded image (see caption_image); its caption
                is prepended to the message.
        Returns:
            An iterable of streamed completion chunks on success, or a str
            starting with "Error generating response:" on failure.
        """
        try:
            messages = [ChatMessage(role="system", content=self.system_prompt).to_dict()]

            # Remembered facts travel as an additional system message.
            if self.persistent_memory:
                memory_context = "Remembered Information:\n" + "\n".join(
                    f"{k}: {v}" for k, v in self.persistent_memory.items()
                )
                messages.append(ChatMessage(role="system", content=memory_context).to_dict())

            for msg in self.conversation_history:
                messages.append(ChatMessage(role=msg['role'], content=msg['content']).to_dict())

            if image:
                image_caption = self.caption_image(image)
                user_input = f"Uploaded image : {image_caption}\n\nUser's message: {user_input}"

            messages.append(ChatMessage(role="user", content=user_input).to_dict())

            max_new_tokens = self._max_new_tokens(messages)
            if max_new_tokens <= 0:
                # The prompt alone fills the budget, leaving no room for a reply.
                return (
                    "Error generating response: the conversation is too long for the "
                    "model's context window. Please clear the conversation and try again."
                )

            return self.client.chat_completion(
                messages=messages,
                model=self.MODEL_ID,
                temperature=0.7,
                max_tokens=max_new_tokens,
                top_p=0.9,
                stream=True
            )

        except Exception as e:
            print(f"Detailed error in get_response: {e}")
            return f"Error generating response: {str(e)}"

    def messages_to_prompt(self, messages):
        """
        Convert a list of ChatMessage dictionaries to a single prompt string.

        This is a simple implementation and you might need to adjust it
        based on the specific requirements of the model you are using.
        Messages with an unrecognized role are skipped.
        """
        prompt = ""
        for msg in messages:
            if msg["role"] == "system":
                prompt += f"<|system|>\n{msg['content']}<|end|>\n"
            elif msg["role"] == "user":
                prompt += f"<|user|>\n{msg['content']}<|end|>\n"
            elif msg["role"] == "assistant":
                prompt += f"<|assistant|>\n{msg['content']}<|end|>\n"
        prompt += "<|assistant|>\n"  # Start of assistant's turn
        return prompt

    def create_interface(self):
        """Build and return the Gradio Blocks UI wired to this chat instance."""

        def streaming_response(message, chat_history, image_filepath, math_ocr_image_path):
            """Stream the assistant's reply into the chatbot.

            Yields (textbox value, chat history, image, math image). Once a
            request is made, the textbox and both image inputs are cleared.
            """
            chat_history = chat_history or []

            # Nothing to send: leave the UI untouched.
            if not (message or "").strip() and not image_filepath and not math_ocr_image_path:
                yield message, chat_history, image_filepath, math_ocr_image_path
                return

            if math_ocr_image_path:
                ocr_text = self.perform_math_ocr(math_ocr_image_path)
                if self._is_ocr_error(ocr_text):
                    yield "", chat_history + [[message, ocr_text]], None, None
                    return
                message = f"Math OCR Result: {ocr_text}\n\nUser's message: {message}"

            response_stream = self.get_response(message, image_filepath or None)

            # get_response signals failure by returning a string instead of a stream
            if isinstance(response_stream, str):
                yield "", chat_history + [[message, response_stream]], None, None
                return

            full_response = ""
            updated_history = chat_history + [[message, ""]]

            try:
                for chunk in response_stream:
                    if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                        full_response += chunk.choices[0].delta.content
                        # Show the partial reply in the last chat row
                        updated_history[-1][1] = full_response
                        yield "", updated_history, None, None
            except Exception as e:
                print(f"Streaming error: {e}")
                updated_history[-1][1] = f"Error during response: {e}"
                yield "", updated_history, None, None
                return

            # A stream with no content chunks never yielded above; still show the turn.
            if not full_response:
                yield "", updated_history, None, None

            # Only completed exchanges are remembered for later turns.
            self.conversation_history.append({"role": "user", "content": message})
            self.conversation_history.append({"role": "assistant", "content": full_response})
            # Keep only the most recent entries; with an even limit, user/assistant pairs stay intact.
            self.conversation_history = self.conversation_history[-self.MAX_HISTORY_MESSAGES:]

        # Custom CSS for Inter font and improved styling
        custom_css = """
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
        body, .gradio-container {
            font-family: 'Inter', sans-serif !important;
        }
        .chatbot-container .message {
            font-family: 'Inter', sans-serif !important;
        }
        .gradio-container input,
        .gradio-container textarea,
        .gradio-container button {
            font-family: 'Inter', sans-serif !important;
        }
        /* Image Upload Styling */
        .image-container {
            border: 1px solid #ccc;
            border-radius: 8px;
            padding: 10px;
            margin-bottom: 10px;
            display: flex;
            flex-direction: column;
            align-items: center;
            gap: 10px;
            background-color: #f8f8f8;
        }
        .image-preview {
            max-width: 200px;
            max-height: 200px;
            border-radius: 8px;
        }
        .image-buttons {
            display: flex;
            gap: 10px;
        }
        .image-buttons button {
            padding: 8px 15px;
            border-radius: 5px;
            background-color: #4CAF50;
            color: white;
            border: none;
            cursor: pointer;
        }
        .image-buttons button:hover {
            background-color: #367c39;
        }
        """

        with gr.Blocks(theme='soft', css=custom_css) as demo:
            with gr.Column():
                chatbot = gr.Chatbot(
                    label="Xylaria 1.5 Senoa (EXPERIMENTAL)",
                    height=500,
                    show_copy_button=True,
                )

                # Image captioning input
                with gr.Accordion("Image Input", open=False):
                    with gr.Column():
                        img = gr.Image(
                            sources=["upload", "webcam"],
                            type="filepath",
                            label="",  # Remove label as it's redundant
                            elem_classes="image-preview",
                        )
                        with gr.Row():
                            clear_image_btn = gr.Button("Clear Image")

                # Math OCR input
                with gr.Accordion("Math Input", open=False):
                    with gr.Column():
                        math_ocr_img = gr.Image(
                            sources=["upload", "webcam"],
                            type="filepath",
                            label="Upload Image for math",
                            elem_classes="image-preview"
                        )
                        with gr.Row():
                            clear_math_ocr_btn = gr.Button("Clear Math Image")

                # Message input row
                with gr.Row():
                    with gr.Column(scale=4):
                        txt = gr.Textbox(
                            show_label=False,
                            placeholder="Type your message...",
                            container=False
                        )
                    btn = gr.Button("Send", scale=1)

                # Clear history and memory buttons
                with gr.Row():
                    clear = gr.Button("Clear Conversation")
                    clear_memory = gr.Button("Clear Memory")

                clear_image_btn.click(
                    fn=lambda: None,
                    inputs=None,
                    outputs=[img],
                    queue=False
                )
                clear_math_ocr_btn.click(
                    fn=lambda: None,
                    inputs=None,
                    outputs=[math_ocr_img],
                    queue=False
                )

                # Send button and Enter key share the same streaming handler
                send_kwargs = dict(
                    fn=streaming_response,
                    inputs=[txt, chatbot, img, math_ocr_img],
                    outputs=[txt, chatbot, img, math_ocr_img],
                )
                btn.click(**send_kwargs)
                txt.submit(**send_kwargs)

                # Clear the visible chat AND the history replayed to the model;
                # persistent memory is left intact.
                clear.click(
                    fn=self._clear_conversation_history,
                    inputs=None,
                    outputs=[chatbot],
                    queue=False
                )

                # Clear persistent memory and reset conversation
                clear_memory.click(
                    fn=self.reset_conversation,
                    inputs=None,
                    outputs=[chatbot],
                    queue=False
                )

                # Runs on every page load (not on close), so each visit starts clean.
                demo.load(self.reset_conversation, None, None)

        return demo

# Launch the interface
def main():
    """Build the Xylaria chat app and serve it with Gradio."""
    app = XylariaChat().create_interface()
    # share=True requests a public Gradio link; debug=True surfaces detailed errors.
    app.launch(debug=True, share=True)


if __name__ == "__main__":
    main()