# jupyter-agent / app.py
# Terry Zhuo · c32a030 · 7.97 kB
# (file-viewer page header captured with the source: "raw", "history blame")
import os
import gradio as gr
from gradio.utils import get_space
from e2b_code_interpreter import Sandbox
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import re
# When gradio's get_space() returns a falsy value, try to populate the
# environment from a local .env file. python-dotenv is an optional
# dependency: if it cannot be imported the step is skipped silently.
if not get_space():
    try:
        from dotenv import load_dotenv
        load_dotenv()
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        pass
from utils import (
run_interactive_notebook,
create_base_notebook,
update_notebook_display,
update_notebook_with_cell,
update_notebook_with_markdown,
)
# Required credentials: a missing variable raises KeyError at import time.
E2B_API_KEY = os.environ["E2B_API_KEY"]
HF_TOKEN = os.environ["HF_TOKEN"]

# Default value offered for the "Max New Tokens" input.
DEFAULT_MAX_TOKENS = 512
# Upper bound on generate -> execute rounds performed for a single request.
MAX_TURNS = 10
# Sandboxes created so far, keyed by the Gradio session hash.
SANDBOXES = {}

TMP_DIR = './tmp/'
# exist_ok=True replaces the exists()/makedirs() pair, which could fail if the
# directory appeared between the check and the creation.
os.makedirs(TMP_DIR, exist_ok=True)

# Write an empty notebook so the download component has a file to point at
# before any request has been processed.
notebook_data = create_base_notebook([])[0]
with open(TMP_DIR + "jupyter-agent.ipynb", 'w', encoding='utf-8') as f:
    json.dump(notebook_data, f, indent=2)

# Read the prompt with an explicit encoding so the result does not depend on
# the platform's locale default.
with open("ds-system-prompt.txt", "r", encoding="utf-8") as f:
    DEFAULT_SYSTEM_PROMPT = f.read()
# Local model loading via transformers.
def load_model_and_tokenizer(model_name="bigcomputer/jupycoder-7b-lora-350"):
    """Load a causal language model and the tokenizer to use with it.

    The model is always loaded from ``model_name`` with ``device_map="auto"``.
    The tokenizer comes from the same id, except for
    ``bigcomputer/jupycoder-7b-lora-350``, which is paired with the
    ``Qwen/Qwen2.5-Coder-7B-Instruct`` tokenizer.

    Nothing is cached here: every call invokes ``from_pretrained`` again.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer_name = (
        "Qwen/Qwen2.5-Coder-7B-Instruct"
        if model_name == "bigcomputer/jupycoder-7b-lora-350"
        else model_name
    )
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return model, tokenizer
# Split a model reply into notebook cells (code and markdown).
def parse_model_response(response_text):
    """Split a model response into an ordered list of notebook cells.

    Text enclosed in a ```python ... ``` fence becomes a code cell; every
    other non-blank stretch of text becomes a markdown cell. Cell contents
    are stripped of surrounding whitespace.

    The fences are removed by slicing rather than by pattern substitution, so
    an opening fence followed by trailing spaces or a Windows line ending no
    longer leaves the word "python" at the top of the code cell.

    Args:
        response_text: Raw text produced by the model.

    Returns:
        A list of dicts of the form ``{"type": "code" | "markdown",
        "content": str}``, in the order the pieces appear in the response.
    """
    fence_open = '```python'
    fence_close = '```'
    cells = []
    # The capturing group makes re.split keep the fenced blocks in its output,
    # interleaved with the text between them.
    for part in re.split(r'(```python[\s\S]*?```)', response_text):
        if not part.strip():
            continue
        if part.startswith(fence_open):
            code = part[len(fence_open):]
            # An unterminated fence is not matched by the split pattern and
            # reaches this branch without a closing marker; only strip the
            # closing fence when it is really there.
            if code.endswith(fence_close):
                code = code[:-len(fence_close)]
            cells.append({"type": "code", "content": code.strip()})
        else:
            # Regular text becomes markdown.
            cells.append({"type": "markdown", "content": part.strip()})
    return cells
def _save_notebook(notebook_data, path):
    """Write ``notebook_data`` to ``path`` as indented UTF-8 JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(notebook_data, f, indent=2)


def execute_jupyter_agent(
    system_prompt, user_input, max_new_tokens, model_name, files, message_history, request: gr.Request
):
    """Run the generate -> execute loop for one request, streaming UI updates.

    For up to ``MAX_TURNS`` rounds the model is prompted with the accumulated
    message contents, its reply is split into cells, code cells are run in the
    session's sandbox, and the notebook is updated. The loop stops early after
    a round that produced no code cell.

    Args:
        system_prompt: Template for the system message; it is formatted with a
            bullet list of the uploaded file names.
        user_input: The user's request for this call.
        max_new_tokens: Generation budget per round (converted to ``int``).
        model_name: Identifier passed to ``load_model_and_tokenizer``.
        files: Paths of files to upload to the sandbox, or ``None``.
        message_history: List of ``{"role", "content"}`` dicts; it is extended
            in place.
        request: Gradio request; ``request.session_hash`` keys the sandbox and
            the per-session output directory.

    Yields:
        ``(notebook_html, message_history, save_dir)`` after each processed
        cell, where ``save_dir`` is the path of the session's notebook file.
    """
    session_id = request.session_hash
    if session_id not in SANDBOXES:
        SANDBOXES[session_id] = Sandbox(api_key=E2B_API_KEY)
    sbx = SANDBOXES[session_id]

    session_dir = os.path.join(TMP_DIR, session_id)
    os.makedirs(session_dir, exist_ok=True)
    # Despite its name, save_dir is the notebook *file* path for this session.
    save_dir = os.path.join(session_dir, 'jupyter-agent.ipynb')

    model, tokenizer = load_model_and_tokenizer(model_name)

    # Upload user files into the sandbox under their base names.
    filenames = []
    if files is not None:
        for filepath in files:
            upload_name = Path(filepath).name
            with open(filepath, "rb") as file:
                print(f"uploading {filepath}...")
                sbx.files.write(upload_name, file)
            filenames.append(upload_name)

    # The system message is only added at the start of a conversation.
    if not message_history:
        message_history.append({
            "role": "system",
            "content": system_prompt.format("- " + "\n- ".join(filenames))
        })
    message_history.append({"role": "user", "content": user_input})

    # create_base_notebook returns a tuple whose first element is the notebook
    # data; the rest of this file indexes it with [0] as well.
    notebook_data = create_base_notebook([])[0]

    for _ in range(MAX_TURNS):
        # NOTE(review): the prompt is a plain newline-join of message contents,
        # not the tokenizer's chat template — confirm this matches how the
        # model was trained.
        input_text = "\n".join(msg["content"] for msg in message_history)
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=0.7,
        )
        # generate() returns the prompt tokens followed by the completion for
        # decoder-only models; decode only the new tokens so the prompt is not
        # parsed back into notebook cells.
        prompt_length = inputs["input_ids"].shape[1]
        response_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        has_code = False
        for cell in parse_model_response(response_text):
            if cell["type"] == "code":
                has_code = True
                # NOTE(review): confirm `sbx.python.run` exists in the installed
                # e2b_code_interpreter SDK version.
                result = sbx.python.run(cell["content"])
                notebook_data = update_notebook_with_cell(notebook_data, cell["content"], result)
                # Feed the code and its execution result back to the model.
                message_history.append({
                    "role": "assistant",
                    "content": cell["content"]
                })
                message_history.append({
                    "role": "user",
                    "content": f"Execution result:\n{result}"
                })
            else:
                notebook_data = update_notebook_with_markdown(notebook_data, cell["content"])
                message_history.append({
                    "role": "assistant",
                    "content": cell["content"]
                })

            # Persist before yielding so the path handed to the file output
            # refers to a file that exists and reflects the current notebook.
            _save_notebook(notebook_data, save_dir)
            notebook_html = update_notebook_display(notebook_data)
            yield notebook_html, message_history, save_dir

        # A round without code means there is nothing left to execute.
        if not has_code:
            break

    # Also save when no cell was produced at all.
    _save_notebook(notebook_data, save_dir)
def clear(msg_state):
    """Reset the notebook view and the conversation.

    Args:
        msg_state: Current message history; it is not read or mutated.

    Returns:
        A tuple of the HTML for a freshly created base notebook and a new
        empty message list.
    """
    empty_notebook = create_base_notebook([])[0]
    return update_notebook_display(empty_notebook), []
# Stylesheet that stretches the app containers to the full viewport height.
# NOTE(review): in the code visible here `css` is not passed to gr.Blocks(),
# so these rules may never be applied — confirm whether that is intended.
css = """
#component-0 {
height: 100vh;
overflow-y: auto;
padding: 20px;
}
.gradio-container {
height: 100vh !important;
}
.contain {
height: 100vh !important;
}
"""
# Build the Gradio UI: notebook view, prompt box, action buttons, download
# link, file upload and advanced settings, then wire the events and launch.
with gr.Blocks() as demo:
    # Per-session conversation history shared between the callbacks.
    msg_state = gr.State(value=[])

    # Start with the rendering of an empty base notebook.
    html_output = gr.HTML(
        value=update_notebook_display(create_base_notebook([])[0])
    )

    user_input = gr.Textbox(
        value="Solve the Lotka-Volterra equation and plot the results.",
        lines=3,
        label="User input",
    )

    with gr.Row():
        generate_btn = gr.Button("Let's go!")
        clear_btn = gr.Button("Clear")

    # Points at the empty notebook written at import time until a run
    # replaces it with the session's notebook path.
    file = gr.File(
        TMP_DIR + "jupyter-agent.ipynb",
        label="Download Jupyter Notebook",
    )

    with gr.Accordion("Upload files", open=False):
        files = gr.File(label="Upload files to use", file_count="multiple")

    with gr.Accordion("Advanced Settings", open=False):
        system_input = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            elem_classes="input-box",
            lines=8,
        )
        with gr.Row():
            max_tokens = gr.Number(
                label="Max New Tokens",
                value=DEFAULT_MAX_TOKENS,
                minimum=128,
                maximum=2048,
                step=8,
                interactive=True,
            )
            model = gr.Dropdown(
                value="bigcomputer/jupycoder-7b-lora-350",
                choices=[
                    "bigcomputer/jupycoder-7b-lora-350",
                    "Qwen/Qwen2.5-Coder-7B-Instruct",
                ],
                label="Models",
            )

    # Run the agent; its yields update the notebook view, the history state
    # and the download link.
    generate_btn.click(
        fn=execute_jupyter_agent,
        inputs=[system_input, user_input, max_tokens, model, files, msg_state],
        outputs=[html_output, msg_state, file],
    )

    clear_btn.click(fn=clear, inputs=[msg_state], outputs=[html_output, msg_state])

    # On page load, strip the 'dark' class from every element that has it.
    demo.load(
        fn=None,
        inputs=None,
        outputs=None,
        js=""" () => {
if (document.querySelectorAll('.dark').length) {
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
}
}
""",
    )

demo.launch(ssr_mode=False)