import nbformat
from nbformat.v4 import new_notebook, new_markdown_cell, new_code_cell
from nbconvert import HTMLExporter
from huggingface_hub import InferenceClient
from e2b_code_interpreter import Sandbox
from vllm.lora.request import LoRARequest
from traitlets.config import Config
from vllm import LLM
import re

# Module-level setup: everything below runs once, at import time.

# nbconvert exporter used by update_notebook_display to turn notebook dicts
# into HTML with the "classic" template.
config = Config()
html_exporter = HTMLExporter(config=config, template_name="classic")

# Base model, loaded eagerly with LoRA enabled so run_interactive_notebook can
# attach an adapter per request via LoRARequest.
# NOTE(review): constructing the LLM here makes importing this module load the
# model weights; consider lazy initialisation if the module is imported in
# contexts that do not need the model.
BASE_MODEL = LLM(model="Qwen/Qwen2.5-Coder-7B-Instruct", enable_lora=True)

# Constants
# Upper bound on generate/execute turns per run_interactive_notebook call.
MAX_TURNS = 10

# Jinja chat template read from the working directory. An explicit UTF-8
# encoding keeps decoding independent of the platform's locale default.
# TODO(review): llama_template is not passed to BASE_MODEL.chat in
# run_interactive_notebook; confirm whether it is used elsewhere.
with open("llama3_template.jinja", "r", encoding="utf-8") as f:
    llama_template = f.read()


def parse_exec_result_nb(execution):
    """Translate an E2B execution into a list of nbformat v4 cell outputs.

    Outputs are produced in this order: the stdout stream, the stderr stream,
    the error (if any), then one entry per rich result that carries at least
    one MIME representation. Results without any representation are skipped.

    Args:
        execution: Object exposing ``logs.stdout`` / ``logs.stderr`` (lists of
            strings), ``error`` (falsy, or an object with ``name``, ``value``
            and a string ``traceback``), ``results`` and ``execution_count``.

    Returns:
        list[dict]: Output dicts suitable for a code cell's ``"outputs"``.
    """
    # (result attribute, MIME type) pairs copied verbatim when truthy, in this
    # order. ``text`` is handled separately because it is wrapped in a list.
    rich_fields = (
        ("html", "text/html"),
        ("png", "image/png"),
        ("svg", "image/svg+xml"),
        ("jpeg", "image/jpeg"),
        ("pdf", "application/pdf"),
        ("latex", "text/latex"),
        ("json", "application/json"),
        ("javascript", "application/javascript"),
    )

    nb_outputs = []

    for stream_name in ("stdout", "stderr"):
        chunks = getattr(execution.logs, stream_name)
        if chunks:
            nb_outputs.append({
                'output_type': 'stream',
                'name': stream_name,
                'text': ''.join(chunks),
            })

    err = execution.error
    if err:
        nb_outputs.append({
            'output_type': 'error',
            'ename': err.name,
            'evalue': err.value,
            'traceback': err.traceback.split('\n'),
        })

    for res in execution.results:
        bundle = {}
        if res.text:
            bundle['text/plain'] = [res.text]  # text/plain stored as a list of lines
        for attr, mime in rich_fields:
            value = getattr(res, attr)
            if value:
                bundle[mime] = value
        if not bundle:
            continue

        entry = {
            'output_type': 'execute_result' if res.is_main_result else 'display_data',
            'metadata': {},
            'data': bundle,
        }
        # Only the main result carries the cell's execution counter.
        if res.is_main_result and execution.execution_count is not None:
            entry['execution_count'] = execution.execution_count
        nb_outputs.append(entry)

    return nb_outputs


# HTML for a collapsible "System" notebook cell. Filled with str.format using
# the system message text; the literal CSS braces are doubled ({{ }}) so that
# .format leaves them intact.
system_template = """\
<details>
  <summary style="display: flex; align-items: center;">
    <div class="alert alert-block alert-info" style="margin: 0; width: 100%;">
      <b>System: <span class="arrow">▶</span></b>
    </div>
  </summary>
  <div class="alert alert-block alert-info">
    {}
  </div>
</details>

<style>
details > summary .arrow {{
  display: inline-block;
  transition: transform 0.2s;
}}
details[open] > summary .arrow {{
  transform: rotate(90deg);
}}
</style>
"""

# HTML for a user prompt cell; filled with str.format using the prompt text.
user_template = """<div class="alert alert-block alert-success">
<b>User:</b> {}
</div>
"""

# Markdown/HTML banner used as the first cell of every notebook built by
# create_base_notebook.
header_message = """<p align="center">
  <img src="https://huggingface.co/spaces/lvwerra/jupyter-agent/resolve/main/jupyter-agent.png" />
</p>


<p style="text-align:center;">Let an LLM agent write and execute code inside a notebook!</p>"""

# CSS rule that update_notebook_display strips from the exported HTML.
# NOTE(review): presumably emitted by nbconvert's "classic" template; confirm.
bad_html_bad = """input[type="file"] {
  display: block;
}"""


def _ensure_trailing_code_cell(cells):
    """Return ``cells[-1]`` if it is a code cell, else append an empty code cell and return it."""
    if cells and cells[-1].get("cell_type") == "code":
        return cells[-1]
    cell = {
        "cell_type": "code",
        "execution_count": None,
        "metadata": {},
        "source": "",
        "outputs": [],
    }
    cells.append(cell)
    return cell


def create_base_notebook(messages):
    """Build an nbformat v4 notebook dict from a list of chat messages.

    The notebook always starts with a markdown cell containing
    ``header_message``. Each message is then mapped to cells by role:

    * ``system``: markdown cell rendered with ``system_template`` (newlines in
      the content become ``<br>``).
    * ``user`` with a truthy ``is_user_prompt``: markdown cell rendered with
      ``user_template`` (newlines become ``<br>``).
    * ``user`` without that flag: treated as execution output and appended as
      a stdout stream output of the trailing code cell.
    * ``assistant`` with a ``tool_calls`` key: new code cell whose source is
      the message content.
    * ``ipython``: replaces the outputs of the trailing code cell with
      ``message["nbformat"]`` and assigns it the next execution count.
    * ``assistant`` without ``tool_calls``: markdown cell with the content.

    When an output-carrying message (non-prompt ``user`` or ``ipython``)
    arrives while the last cell is not a code cell, an empty code cell is
    created to hold the output. Without this, the output would raise
    ``KeyError`` or be attached to a markdown cell, where it is never rendered.

    Args:
        messages: Sequence of message dicts with a ``role`` key and, depending
            on the role, ``content`` / ``nbformat`` / ``tool_calls`` keys.

    Returns:
        tuple[dict, int]: The notebook dict and the number of ``ipython``
        messages processed (i.e. the last execution count assigned).

    Raises:
        ValueError: If a message has an unrecognised role.
    """
    base_notebook = {
        "metadata": {
            "kernel_info": {"name": "python3"},
            "language_info": {
                "name": "python",
                "version": "3.12",
            },
        },
        "nbformat": 4,
        "nbformat_minor": 0,
        "cells": [],
    }
    cells = base_notebook["cells"]
    cells.append({
        "cell_type": "markdown",
        "metadata": {},
        "source": header_message,
    })

    if not messages:
        # An empty conversation gets a single blank code cell after the header.
        _ensure_trailing_code_cell(cells)

    code_cell_counter = 0

    for message in messages:
        role = message["role"]
        if role == "system":
            cells.append({
                "cell_type": "markdown",
                "metadata": {},
                "source": system_template.format(message["content"].replace('\n', '<br>')),
            })
        elif role == "user":
            if message.get("is_user_prompt", False):
                cells.append({
                    "cell_type": "markdown",
                    "metadata": {},
                    "source": user_template.format(message["content"].replace('\n', '<br>')),
                })
            else:
                # Non-prompt user messages carry execution output for the latest code cell.
                _ensure_trailing_code_cell(cells)["outputs"].append({
                    "output_type": "stream",
                    "name": "stdout",
                    "text": message["content"],
                })
        elif role == "assistant" and "tool_calls" in message:
            cells.append({
                "cell_type": "code",
                "execution_count": None,
                "metadata": {},
                "source": message["content"],
                "outputs": [],
            })
        elif role == "ipython":
            code_cell_counter += 1
            cell = _ensure_trailing_code_cell(cells)
            cell["outputs"] = message["nbformat"]
            cell["execution_count"] = code_cell_counter
        elif role == "assistant":
            cells.append({
                "cell_type": "markdown",
                "metadata": {},
                "source": message["content"],
            })
        else:
            raise ValueError(message)

    return base_notebook, code_cell_counter

def execute_code(sbx, code):
    """Run ``code`` in the sandbox and return its text rendering and raw result.

    Each stdout chunk reported through ``on_stdout`` is also printed locally,
    prefixed with ``'stdout:'``.

    Args:
        sbx: Sandbox exposing ``run_code`` (presumably an
            ``e2b_code_interpreter.Sandbox``).
        code: Python source to execute.

    Returns:
        tuple[str, Execution]: The plain-text rendering produced by
        ``parse_exec_result_llm`` and the execution object from ``run_code``.
    """
    execution = sbx.run_code(code, on_stdout=lambda data: print('stdout:', data))
    # Delegate to the shared renderer rather than duplicating its logic, so the
    # two text renderings of an execution cannot drift apart.
    return parse_exec_result_llm(execution), execution


def parse_exec_result_llm(execution):
    """Render an E2B execution as plain text for feeding back to the model.

    Sections appear in the order: stdout, stderr, error traceback. Chunks
    within one stream are joined with ``"\\n"``. A newline is inserted between
    non-empty sections only when the preceding text does not already end with
    one, so the last line of one section is never glued onto the first line
    of the next.

    Args:
        execution: Object exposing ``logs.stdout`` / ``logs.stderr`` (lists of
            strings) and ``error`` (``None`` or an object with a string
            ``traceback``).

    Returns:
        str: The combined text; empty if there was no output and no error.
    """
    sections = []
    if execution.logs.stdout:
        sections.append("\n".join(execution.logs.stdout))
    if execution.logs.stderr:
        sections.append("\n".join(execution.logs.stderr))
    if execution.error is not None:
        sections.append(execution.error.traceback)

    output = ""
    for section in sections:
        if not section:
            continue
        if output and not output.endswith("\n"):
            output += "\n"
        output += section
    return output
    
    
def update_notebook_display(notebook_data):
    """Render a notebook dict to an HTML string for display.

    The dict is converted with ``nbformat.from_dict`` and exported through the
    module-level ``html_exporter``. Every occurrence of ``bad_html_bad`` is
    then removed from the markup. The exporter's resources dict is discarded.
    """
    html_body, _resources = html_exporter.from_notebook_node(
        nbformat.from_dict(notebook_data)
    )
    return html_body.replace(bad_html_bad, "")

def _outputs_to_text(outputs):
    """Flatten nbformat outputs into the plain-text observation sent back to the model."""
    parts = []
    for output in outputs:
        kind = output.get('output_type')
        if kind == 'stream':
            parts.append(output['text'])
        elif kind == 'error':
            parts.append('\n'.join(output['traceback']))
        elif kind in ('execute_result', 'display_data'):
            data = output.get('data', {})
            if 'text/plain' in data:
                parts.append('\n'.join(data['text/plain']))
            # Images cannot be passed as text; leave a placeholder instead.
            if any(key.startswith('image/') for key in data):
                parts.append('<image>')
    return '\n'.join(parts)


def _last_execution_count(notebook_data):
    """Return the highest truthy execution_count among code cells, or 0 if there is none."""
    return max(
        (
            cell["execution_count"]
            for cell in notebook_data["cells"]
            if cell["cell_type"] == "code" and cell.get("execution_count")
        ),
        default=0,
    )


def run_interactive_notebook(lora_path, sampling_params, messages, sbx, notebook_data=None, max_new_tokens=512):
    """Alternate model generation and sandbox execution, streaming notebook updates.

    Each turn asks ``BASE_MODEL`` (with the LoRA adapter at ``lora_path``) for
    a response. If the response contains a fenced ```python block, the first
    such block is executed in ``sbx``. It is then appended to the notebook as a
    code cell with its outputs, and a text summary of the outputs is fed back
    to the model as a ``user`` message.

    A response without code is appended as a markdown cell and ends the run.
    The run also stops when the model repeats an earlier assistant message
    verbatim (ignoring surrounding whitespace), or after ``MAX_TURNS`` turns.

    Args:
        lora_path: Path to the LoRA adapter.
        sampling_params: Sampling parameters passed to ``BASE_MODEL.chat``.
        messages: List of conversation messages. When ``notebook_data`` is
            given, this list is used as ``display_messages`` and is appended
            to in place.
        sbx: Sandbox used for code execution (via ``execute_code``).
        notebook_data: Existing notebook dict when continuing a session; when
            ``None`` a new notebook is built from ``messages``.
        max_new_tokens: Currently unused by this function; the generation
            length is presumably controlled through ``sampling_params``.

    Yields:
        tuple: ``(html, notebook_data, display_messages)``. One tuple is
        yielded after every executed code cell. A final tuple is yielded when
        the run ends with changes not yet yielded, or when nothing has been
        yielded at all. The last tuple always reflects the final state.
    """
    if notebook_data is None:
        # First run: tag user messages as real prompts for display, keep
        # untouched copies for the model.
        display_messages = []
        model_messages = []
        for msg in messages:
            display_msg = msg.copy()
            if msg["role"] == "user":
                display_msg["is_user_prompt"] = True
            display_messages.append(display_msg)
            model_messages.append(msg.copy())
        notebook_data, code_cell_counter = create_base_notebook(display_messages)
    else:
        # Continuing a session: strip display-only flags before sending to the model.
        display_messages = messages
        model_messages = [
            {key: value for key, value in msg.items() if key != "is_user_prompt"}
            for msg in messages
        ]
        code_cell_counter = _last_execution_count(notebook_data)

    # The adapter is identical for every turn, so build the request once.
    lora_request = LoRARequest("lora_adapter", 1, lora_path)

    # True while the current state has not been yielded yet. It starts True so
    # the caller always receives at least one update.
    pending_update = True

    for _ in range(MAX_TURNS):
        print(model_messages)
        response_text = BASE_MODEL.chat(
            model_messages,
            sampling_params,
            lora_request=lora_request,
            add_generation_prompt=True,
        )[0].outputs[0].text

        # A verbatim repeat of an earlier assistant message ends the run.
        if any(
            msg["role"] == "assistant" and msg["content"].strip() == response_text.strip()
            for msg in model_messages
        ):
            break

        assistant_msg = {"role": "assistant", "content": response_text}
        model_messages.append(assistant_msg.copy())
        display_messages.append(assistant_msg)

        code_match = re.search(r'```python\n(.*?)```', response_text, re.DOTALL)
        if code_match is None:
            # No code: render the response as markdown and stop.
            notebook_data["cells"].append({
                "cell_type": "markdown",
                "metadata": {},
                "source": response_text,
            })
            pending_update = True
            break

        code = code_match.group(1).strip()
        code_cell_counter += 1
        notebook_data["cells"].append({
            "cell_type": "code",
            "execution_count": code_cell_counter,
            "metadata": {},
            "source": code,
            "outputs": [],
        })

        _, execution = execute_code(sbx, code)
        outputs = parse_exec_result_nb(execution)

        # The execution result goes back to the model as a user turn; the
        # display copy is flagged so it renders as output, not as a prompt.
        user_msg = {"role": "user", "content": _outputs_to_text(outputs)}
        model_messages.append(user_msg.copy())
        display_messages.append({**user_msg, "is_user_prompt": False})

        notebook_data["cells"][-1]["outputs"] = outputs

        pending_update = False
        yield update_notebook_display(notebook_data), notebook_data, display_messages

    # Single final update, only if the latest state has not been yielded yet.
    if pending_update:
        yield update_notebook_display(notebook_data), notebook_data, display_messages

def update_notebook_with_cell(notebook_data, code, output):
    """Append a code cell holding ``code`` to ``notebook_data``.

    If ``output`` is truthy, it is stringified and attached as a single stdout
    stream output; otherwise the cell has no outputs. The cell has no
    execution count. Mutates ``notebook_data`` in place and returns it.
    """
    cell_outputs = []
    if output:
        cell_outputs.append({
            "output_type": "stream",
            "name": "stdout",
            "text": str(output),
        })

    new_cell = dict(
        cell_type="code",
        execution_count=None,
        metadata={},
        source=code,
        outputs=cell_outputs,
    )
    notebook_data['cells'].append(new_cell)
    return notebook_data

def update_notebook_with_markdown(notebook_data, markdown_text):
    """Append ``markdown_text`` to ``notebook_data`` as a new markdown cell.

    Mutates ``notebook_data`` in place and returns the same object.
    """
    notebook_data['cells'].append(
        dict(cell_type="markdown", metadata={}, source=markdown_text)
    )
    return notebook_data