import torch
import spaces
import gradio as gr
from threading import Thread
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    BitsAndBytesConfig, 
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList
)
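
# Hugging Face Spaces demo: a streaming chat UI for DeepSeek-R1-Distill-Qwen-14B,
# loaded with 4-bit NF4 quantization so the 14B model fits in a single GPU's memory.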

MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"

DEFAULT_SYSTEM_PROMPT = """You are an Expert Reasoning Assistant. Follow these steps:
[Understand]: Analyze key elements and clarify objectives
[Plan]: Outline step-by-step methodology
[Reason]: Execute plan with detailed analysis
[Verify]: Check logic and evidence
[Conclude]: Present structured conclusion"""
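# The bracketed stage tags above are what format_response() highlights in the chat UI.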

CSS = """
.gr-chatbot { min-height: 500px; border-radius: 15px; }
.special-tag { color: #2ecc71; font-weight: 600; }
footer { display: none !important; }
"""

class StopOnTokens(StoppingCriteria):
    # Take the EOS id explicitly instead of relying on the global tokenizer.
    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return input_ids[0][-1].item() == self.eos_token_id

def initialize_model():
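    # NF4 4-bit weights with double quantization; matmuls run in bfloat16
    # via bnb_4bit_compute_dtype.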
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
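    # Ensure a pad token is defined by reusing EOS.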
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )

    return model, tokenizer

def format_response(text):
    # Highlight every reasoning-stage tag from the system prompt, not just a subset.
    for tag in ("[Understand]", "[Plan]", "[Reason]", "[Verify]", "[Conclude]"):
        text = text.replace(tag, f'\n<strong class="special-tag">{tag}</strong>\n')
    return text

@spaces.GPU
def generate_response(message, chat_history, system_prompt, temperature, max_tokens):
    # Create conversation history for model
    conversation = [{"role": "system", "content": system_prompt}]
    for user_msg, bot_msg in chat_history:
        conversation.extend([
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": bot_msg}
        ])
    conversation.append({"role": "user", "content": message})

    # Tokenize the conversation; return_dict also returns the attention mask.
    inputs = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    # Set up streaming; skip_prompt keeps the echoed conversation out of the output.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        # Sampling must be enabled for temperature to take effect; a slider
        # value of 0 falls back to greedy decoding.
        do_sample=temperature > 0,
        temperature=temperature if temperature > 0 else 1.0,
        pad_token_id=tokenizer.eos_token_id,
        stopping_criteria=StoppingCriteriaList([StopOnTokens(tokenizer.eos_token_id)])
    )

    # Run generate() on a background thread so this function can stream tokens as they arrive.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    # Initialize response buffer
    partial_message = ""
    new_history = chat_history + [(message, "")]
    
    # Stream response
    for new_token in streamer:
        partial_message += new_token
        formatted = format_response(partial_message)
        new_history[-1] = (message, formatted + "▌")
        yield new_history

    # Final update without cursor
    new_history[-1] = (message, format_response(partial_message))
    yield new_history

# Load the model and tokenizer once at startup; generate_response reuses them as globals.
model, tokenizer = initialize_model()

with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    <h1 align="center">🧠 AI Reasoning Assistant</h1>
    <p align="center">DeepSeek-R1-Distill-Qwen-14B</p>
    """)
    
    chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot")
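    # History is exchanged as (user, assistant) tuple pairs, matching generate_response.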
    msg = gr.Textbox(label="Your Question", placeholder="Type your question...")
    
    with gr.Accordion("⚙️ Settings", open=False):
        system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions")
        temperature = gr.Slider(0, 1, value=0.7, step=0.05, label="Creativity (temperature)")
        max_tokens = gr.Slider(128, 4096, value=2048, step=1, label="Max Response Length")

    clear = gr.Button("Clear History")
    
    msg.submit(
        generate_response,
        [msg, chatbot, system_prompt, temperature, max_tokens],
        [chatbot],
        show_progress="hidden"
    ).then(lambda: "", None, msg, queue=False)  # clear the textbox after sending
    # Returning None resets the Chatbot component's history.
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.queue().launch()