import nbformat
from nbformat.v4 import new_notebook, new_markdown_cell, new_code_cell
from nbconvert import HTMLExporter
from huggingface_hub import InferenceClient
from e2b_code_interpreter import Sandbox
from vllm.lora.request import LoRARequest
from traitlets.config import Config
from vllm import LLM
import re
# nbconvert HTML exporter used by update_notebook_display; "classic" selects
# nbconvert's classic notebook template.
config = Config()
html_exporter = HTMLExporter(config=config, template_name="classic")

# Loaded once at import time with LoRA support enabled, so each chat request in
# run_interactive_notebook can attach a LoRA adapter.
BASE_MODEL = LLM(model="Qwen/Qwen2.5-Coder-7B-Instruct", enable_lora=True)

# Constants
# Upper bound on model replies generated per call to run_interactive_notebook.
MAX_TURNS = 10

# Chat template text, read at import time from a path relative to the current
# working directory. An explicit encoding avoids depending on the platform locale.
# NOTE(review): `llama_template` is not referenced in this part of the file —
# confirm it is used elsewhere.
with open("llama3_template.jinja", "r", encoding="utf-8") as f:
    llama_template = f.read()
# (Result attribute, MIME type) pairs copied verbatim into an output's ``data``
# when the attribute is truthy. ``text`` is handled separately because it is
# stored as a one-element list.
_RESULT_MIME_TYPES = (
    ("html", "text/html"),
    ("markdown", "text/markdown"),
    ("png", "image/png"),
    ("svg", "image/svg+xml"),
    ("jpeg", "image/jpeg"),
    ("pdf", "application/pdf"),
    ("latex", "text/latex"),
    ("json", "application/json"),
    ("javascript", "application/javascript"),
)


def parse_exec_result_nb(execution):
    """Convert an E2B Execution object to Jupyter notebook cell output format.

    Outputs are emitted in this order: a ``stdout`` stream, a ``stderr`` stream,
    an ``error`` output, then one ``execute_result`` (main result) or
    ``display_data`` output per entry of ``execution.results`` that carries at
    least one MIME payload. Results with no payload are skipped.

    Args:
        execution: object exposing ``logs.stdout`` / ``logs.stderr`` (lists of
            strings), ``error`` (falsy, or an object with ``name``, ``value`` and
            a string ``traceback``), ``results`` and ``execution_count``.

    Returns:
        list[dict]: nbformat-v4 output dicts for a code cell's ``outputs``.
    """
    outputs = []

    logs = execution.logs
    for stream_name, chunks in (("stdout", logs.stdout), ("stderr", logs.stderr)):
        if chunks:
            outputs.append({
                "output_type": "stream",
                "name": stream_name,
                # Chunks are concatenated as-is; no separator is added.
                "text": "".join(chunks),
            })

    error = execution.error
    if error:
        outputs.append({
            "output_type": "error",
            "ename": error.name,
            "evalue": error.value,
            # nbformat stores the traceback as a list of lines.
            "traceback": (error.traceback or "").split("\n"),
        })

    for result in execution.results:
        data = {}
        text = getattr(result, "text", None)
        if text:
            data["text/plain"] = [text]  # Array for text/plain
        for attr, mime_type in _RESULT_MIME_TYPES:
            value = getattr(result, attr, None)
            if value:
                data[mime_type] = value
        if not data:
            continue

        is_main = result.is_main_result
        output = {
            "output_type": "execute_result" if is_main else "display_data",
            "metadata": {},
            "data": data,
        }
        if is_main:
            # nbformat v4 requires execution_count on execute_result; null is allowed.
            output["execution_count"] = execution.execution_count
        outputs.append(output)
    return outputs
# Markdown source for a system-message cell; create_base_notebook fills `{}`
# with the message content.
# NOTE(review): the lone "▶" reads like a remnant of stripped markup (e.g. a
# collapsible summary) — confirm the intended rendering.
system_template = """\
System: ▶
{}
"""
# Markdown source for a user-prompt cell (messages flagged is_user_prompt);
# `{}` receives the prompt text.
user_template = """
User: {}
"""
# Text of the first markdown cell of every notebook built by create_base_notebook.
header_message = """
Let a LLM agent write and execute code inside a notebook!
"""
# CSS rule removed from the exported HTML in update_notebook_display. Because
# str.replace only matches an exact occurrence, every byte here matters.
# NOTE(review): the exporter's stylesheet may indent `display: block;`; if the
# whitespace here differs from its output, the replace silently does nothing —
# verify against the actual exported HTML.
bad_html_bad = """input[type="file"] {
display: block;
}"""
def _ensure_trailing_code_cell(cells):
    """Return the last cell of ``cells``, first appending an empty code cell if the last one is not code."""
    if not cells or cells[-1].get("cell_type") != "code":
        cells.append({
            "cell_type": "code",
            "execution_count": None,
            "metadata": {},
            "source": "",
            "outputs": [],
        })
    return cells[-1]


def create_base_notebook(messages):
    """Build an nbformat-v4 notebook dict from a chat transcript.

    The notebook always starts with a markdown cell holding ``header_message``.
    An empty transcript additionally gets one empty code cell. Messages are
    then mapped by role:

    * ``system``: markdown cell rendered from ``system_template``.
    * ``user`` with a truthy ``is_user_prompt``: markdown cell rendered from
      ``user_template``.
    * other ``user``: stdout stream output appended to the trailing code cell.
    * ``assistant`` with ``tool_calls``: code cell whose source is the content.
    * ``ipython``: replaces the trailing code cell's outputs with
      ``message["nbformat"]`` and gives it the next execution number.
    * ``assistant`` without ``tool_calls``: markdown cell.

    Output-carrying messages that follow a non-code cell get a fresh empty code
    cell to attach to, instead of failing or attaching outputs to markdown.

    Args:
        messages: list of dicts with at least ``role`` and ``content`` keys
            (``ipython`` messages also need ``nbformat``).

    Returns:
        tuple[dict, int]: the notebook dict and the number of ``ipython``
        messages seen (the last execution number assigned).

    Raises:
        ValueError: for a message whose role is not handled above.
    """
    base_notebook = {
        "metadata": {
            "kernel_info": {"name": "python3"},
            "language_info": {
                "name": "python",
                "version": "3.12",
            },
        },
        "nbformat": 4,
        "nbformat_minor": 0,
        "cells": [],
    }
    cells = base_notebook["cells"]
    cells.append({
        "cell_type": "markdown",
        "metadata": {},
        "source": header_message,
    })
    if not messages:
        cells.append({
            "cell_type": "code",
            "execution_count": None,
            "metadata": {},
            "source": "",
            "outputs": [],
        })

    code_cell_counter = 0
    for message in messages:
        role = message["role"]
        if role == "system":
            # Markdown joins single newlines into one paragraph; an inline <br>
            # keeps the message's line breaks visible in the rendered cell.
            text = system_template.format(message["content"].replace("\n", "<br>"))
            cells.append({
                "cell_type": "markdown",
                "metadata": {},
                "source": text,
            })
        elif role == "user":
            # Check if this is an actual user prompt (has is_user_prompt flag)
            if message.get("is_user_prompt", False):
                text = user_template.format(message["content"].replace("\n", "<br>"))
                cells.append({
                    "cell_type": "markdown",
                    "metadata": {},
                    "source": text,
                })
            else:
                # This is an execution output, add as code cell output
                _ensure_trailing_code_cell(cells)["outputs"].append({
                    "output_type": "stream",
                    "name": "stdout",
                    "text": message["content"],
                })
        elif role == "assistant" and "tool_calls" in message:
            cells.append({
                "cell_type": "code",
                "execution_count": None,
                "metadata": {},
                "source": message["content"],
                "outputs": [],
            })
        elif role == "ipython":
            code_cell_counter += 1
            target = _ensure_trailing_code_cell(cells)
            target["outputs"] = message["nbformat"]
            target["execution_count"] = code_cell_counter
        elif role == "assistant":
            cells.append({
                "cell_type": "markdown",
                "metadata": {},
                "source": message["content"],
            })
        else:
            raise ValueError(message)
    return base_notebook, code_cell_counter
def execute_code(sbx, code):
    """Run ``code`` in the sandbox ``sbx`` and summarize the result as text.

    Each stdout chunk is printed to this process's console, prefixed with
    ``stdout:``, as it arrives via the ``on_stdout`` callback.

    Args:
        sbx: sandbox object providing ``run_code(code, on_stdout=...)``.
        code: Python source to execute.

    Returns:
        tuple[str, Execution]: the text produced by ``parse_exec_result_llm``
        for the execution, and the raw execution object.
    """
    execution = sbx.run_code(code, on_stdout=lambda data: print('stdout:', data))
    # Delegate to the shared formatter rather than duplicating its logic here,
    # so the two cannot drift apart.
    return parse_exec_result_llm(execution), execution
def parse_exec_result_llm(execution):
    """Flatten an execution into one text blob for the model.

    The result is the stdout chunks joined by newlines, then the stderr chunks
    joined by newlines, then the error traceback when ``execution.error`` is not
    None. The three sections are concatenated with no separator between them.
    """
    logs = execution.logs
    pieces = ["\n".join(stream) for stream in (logs.stdout, logs.stderr) if len(stream)]
    failure = execution.error
    if failure is not None:
        pieces.append(failure.traceback)
    return "".join(pieces)
def update_notebook_display(notebook_data):
    """Render ``notebook_data`` (an nbformat-v4 dict) to an HTML string.

    The notebook is exported with the module-level ``html_exporter``. Every
    occurrence of the ``bad_html_bad`` CSS snippet is then removed from the
    exported HTML.
    """
    exported_html, _resources = html_exporter.from_notebook_node(
        nbformat.from_dict(notebook_data)
    )
    return exported_html.replace(bad_html_bad, "")
def _to_model_message(msg):
    """Return a shallow copy of ``msg`` without the display-only ``is_user_prompt`` key."""
    clean = msg.copy()
    clean.pop("is_user_prompt", None)
    return clean


def _outputs_to_text(outputs):
    """Flatten nbformat cell outputs into the plain-text transcript sent back to the model.

    Stream text, error tracebacks and ``text/plain`` payloads are kept. Image
    outputs contribute an empty entry. All pieces are joined with newlines.
    """
    parts = []
    for output in outputs:
        output_type = output.get("output_type")
        if output_type == "stream":
            parts.append(output["text"])
        elif output_type == "error":
            parts.append("\n".join(output["traceback"]))
        elif output_type in ("execute_result", "display_data"):
            data = output.get("data", {})
            if "text/plain" in data:
                plain = data["text/plain"]
                # nbformat allows a multiline string as either str or list[str];
                # joining a bare str would interleave newlines between characters.
                parts.append(plain if isinstance(plain, str) else "\n".join(plain))
            if any(key.startswith("image/") for key in data):
                # NOTE(review): an empty placeholder only adds a blank line; an
                # image marker string may have been lost here — confirm.
                parts.append("")
    return "\n".join(parts)


def run_interactive_notebook(lora_path, sampling_params, messages, sbx, notebook_data=None, max_new_tokens=512):
    """
    Run interactive notebook with model.

    Each turn asks the LoRA-adapted ``BASE_MODEL`` for a reply, for at most
    ``MAX_TURNS`` turns. The loop handles each reply as follows:

    * If the reply (stripped) equals an earlier assistant message, the loop stops.
    * Otherwise the reply is recorded as an assistant message.
    * If the reply contains a fenced block starting with three backticks and
      ``python``, only the first such block is executed in ``sbx``. Its outputs
      go into a new numbered code cell, a text transcript of them is appended
      as a user message, and a snapshot is yielded.
    * A reply without such a block is added as a markdown cell and ends the loop.

    Args:
        lora_path: Path to LoRA adapter
        sampling_params: Sampling parameters for the model
        messages: List of conversation messages; the list itself is not modified
        sbx: Sandbox environment for code execution
        notebook_data: Existing notebook data when continuing a session; it is
            mutated in place (cells are appended)
        max_new_tokens: Not read by this function; generation length is
            presumably controlled through ``sampling_params`` — confirm.

    Yields:
        tuple[str, dict, list[dict]]: (rendered HTML, notebook_data,
        display_messages). A final snapshot is yielded only when the state has
        changed since the last yield, or when nothing has been yielded yet.
    """
    if notebook_data is None:
        # First run: every user message in the incoming history is a real prompt.
        display_messages = []
        for msg in messages:
            display_msg = msg.copy()
            if msg["role"] == "user":
                display_msg["is_user_prompt"] = True
            display_messages.append(display_msg)
        notebook_data, code_cell_counter = create_base_notebook(display_messages)
    else:
        # Subsequent runs: copy the list so turns appended below do not mutate
        # the caller's ``messages``.
        display_messages = list(messages)
        # Resume numbering after the highest execution count already present.
        code_cell_counter = 0
        for cell in notebook_data["cells"]:
            if cell["cell_type"] == "code" and cell.get("execution_count"):
                code_cell_counter = max(code_cell_counter, cell["execution_count"])

    # The model only ever sees clean copies without display-only flags.
    model_messages = [_to_model_message(msg) for msg in messages]

    # Built once; the same adapter request is attached to every turn.
    lora_request = LoRARequest("lora_adapter", 1, lora_path)
    # True while the current state has not been yielded to the caller.
    pending_yield = True

    for _ in range(MAX_TURNS):
        print(model_messages)  # debug trace of the conversation sent to the model
        response_text = BASE_MODEL.chat(
            model_messages,
            sampling_params,
            lora_request=lora_request,
            add_generation_prompt=True,
        )[0].outputs[0].text

        # Stop if the model repeats an earlier assistant reply verbatim
        # (ignoring surrounding whitespace).
        if any(
            msg["role"] == "assistant" and msg["content"].strip() == response_text.strip()
            for msg in model_messages
        ):
            break

        assistant_msg = {"role": "assistant", "content": response_text}
        model_messages.append(assistant_msg.copy())
        display_messages.append(assistant_msg)
        pending_yield = True

        code_match = re.search(r'```python\n(.*?)```', response_text, re.DOTALL)
        if code_match is None:
            # No code in this turn: show the reply as markdown and stop.
            notebook_data["cells"].append({
                "cell_type": "markdown",
                "metadata": {},
                "source": response_text,
            })
            break

        code = code_match.group(1).strip()
        code_cell_counter += 1
        # The cell is appended before execution, so it stays in the notebook
        # (with no outputs) even if execution raises.
        code_cell = {
            "cell_type": "code",
            "execution_count": code_cell_counter,
            "metadata": {},
            "source": code,
            "outputs": [],
        }
        notebook_data["cells"].append(code_cell)

        _, execution = execute_code(sbx, code)
        outputs = parse_exec_result_nb(execution)

        # Execution results go back to the model as a plain-text user message.
        user_msg = {"role": "user", "content": _outputs_to_text(outputs)}
        model_messages.append(user_msg.copy())
        display_messages.append({**user_msg, "is_user_prompt": False})

        code_cell["outputs"] = outputs
        yield update_notebook_display(notebook_data), notebook_data, display_messages
        pending_yield = False

    # Covers the no-code/duplicate exits and MAX_TURNS == 0 without
    # re-rendering a state that was just yielded.
    if pending_yield:
        yield update_notebook_display(notebook_data), notebook_data, display_messages
def update_notebook_with_cell(notebook_data, code, output):
    """Append an unnumbered code cell holding ``code`` to ``notebook_data['cells']``.

    When ``output`` is truthy, the cell carries one stdout stream output whose
    text is ``str(output)``; otherwise its outputs list is empty. The notebook
    dict is modified in place and also returned.
    """
    cell_outputs = []
    if output:
        cell_outputs.append({
            "output_type": "stream",
            "name": "stdout",
            "text": str(output),
        })
    new_cell = {
        "cell_type": "code",
        "execution_count": None,
        "metadata": {},
        "source": code,
        "outputs": cell_outputs,
    }
    notebook_data['cells'].append(new_cell)
    return notebook_data
def update_notebook_with_markdown(notebook_data, markdown_text):
    """Append a markdown cell whose source is ``markdown_text``.

    The notebook dict is modified in place and also returned.
    """
    cells = notebook_data['cells']
    cells.append(dict(cell_type="markdown", metadata={}, source=markdown_text))
    return notebook_data