# Spaces: Runtime error (status banner captured with the page; not part of the program)
| import os | |
| import gradio as gr | |
| from gradio.utils import get_space | |
| from e2b_code_interpreter import Sandbox | |
| from pathlib import Path | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import json | |
| import re | |
# When get_space() is falsy (i.e. not running as a hosted Space), load
# environment variables from a local .env file if python-dotenv is installed;
# otherwise continue with the process environment unchanged.
if not get_space():
    try:
        from dotenv import load_dotenv

        load_dotenv()
    except ImportError:  # ModuleNotFoundError is a subclass, so it is covered too
        pass
| from utils import ( | |
| run_interactive_notebook, | |
| create_base_notebook, | |
| update_notebook_display, | |
| update_notebook_with_cell, | |
| update_notebook_with_markdown, | |
| ) | |
# Both credentials are required: a missing variable raises KeyError at import.
E2B_API_KEY = os.environ["E2B_API_KEY"]
HF_TOKEN = os.environ["HF_TOKEN"]

# Initial value of the "Max New Tokens" control.
DEFAULT_MAX_TOKENS = 512
# Upper bound on generate/execute rounds within one execute_jupyter_agent call.
MAX_TURNS = 10
# Sandbox instances keyed by request.session_hash, created on first use.
SANDBOXES = {}

TMP_DIR = './tmp/'
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(TMP_DIR, exist_ok=True)

# Write an empty notebook at startup; the download widget points at this file.
notebook_data = create_base_notebook([])[0]
with open(TMP_DIR+"jupyter-agent.ipynb", 'w', encoding='utf-8') as f:
    json.dump(notebook_data, f, indent=2)

# Decode the prompt explicitly instead of relying on the platform's locale
# default, which is not UTF-8 everywhere.
with open("ds-system-prompt.txt", "r", encoding="utf-8") as f:
    DEFAULT_SYSTEM_PROMPT = f.read()
# Single-entry cache: model name -> (model, tokenizer). Loading happens on every
# agent run, so the most recently used pair is kept instead of being rebuilt.
_LOADED_MODEL = {}

# Checkpoints that are paired with a tokenizer from a different repository.
_TOKENIZER_OVERRIDES = {
    "bigcomputer/jupycoder-7b-lora-350": "Qwen/Qwen2.5-Coder-7B-Instruct",
}


def load_model_and_tokenizer(model_name="bigcomputer/jupycoder-7b-lora-350"):
    """Return the ``(model, tokenizer)`` pair for *model_name*, loading it if needed.

    The model is loaded with ``device_map="auto"``. The tokenizer comes from the
    same repository unless ``_TOKENIZER_OVERRIDES`` names another one.

    Only the most recently requested pair is cached: asking for a different
    model drops the cached entry before the new checkpoint is loaded.

    Args:
        model_name: Hub identifier of the causal language model.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    cached = _LOADED_MODEL.get(model_name)
    if cached is not None:
        return cached

    # Release the cache's reference to any previous model before loading the
    # next one, so the cache itself never holds two checkpoints at once.
    _LOADED_MODEL.clear()

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        _TOKENIZER_OVERRIDES.get(model_name, model_name)
    )

    _LOADED_MODEL[model_name] = (model, tokenizer)
    return model, tokenizer
# Matches one fenced Python block. Group 1 is the code between the opening tag
# (plus any trailing spaces/tabs and one optional newline) and the closing fence.
_CODE_FENCE_RE = re.compile(r'```python[ \t]*\n?([\s\S]*?)```')


def parse_model_response(response_text):
    """Split a model response into ordered notebook cells.

    Every terminated ```` ```python ... ``` ```` fence becomes a code cell; the
    text between fences becomes markdown cells. Whitespace-only text is dropped.
    An unterminated fence is not treated as code and stays in the markdown text.

    Args:
        response_text: Raw text produced by the model.

    Returns:
        A list of ``{"type": "code" | "markdown", "content": str}`` dicts in
        the order they appear in the response, with content stripped.
    """
    cells = []
    cursor = 0

    for fence in _CODE_FENCE_RE.finditer(response_text):
        # Text preceding this fence (since the previous one) is markdown.
        prose = response_text[cursor:fence.start()].strip()
        if prose:
            cells.append({"type": "markdown", "content": prose})

        # Take the code from the capture group rather than stripping markers
        # with a substitution, so the language tag can never leak into it.
        cells.append({"type": "code", "content": fence.group(1).strip()})
        cursor = fence.end()

    trailing = response_text[cursor:].strip()
    if trailing:
        cells.append({"type": "markdown", "content": trailing})

    return cells
def _write_notebook(notebook, path):
    """Serialize *notebook* as indented JSON to *path*."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(notebook, f, indent=2)


def execute_jupyter_agent(
    system_prompt, user_input, max_new_tokens, model_name, files, message_history, request: gr.Request
):
    """Run the generate-then-execute loop for one request, streaming notebook updates.

    Each turn generates a response, splits it into cells with
    ``parse_model_response``, runs the code cells in the session's sandbox and
    appends cells and execution results to both the notebook and the message
    history. The loop ends after ``MAX_TURNS`` turns or after the first turn
    that produced no code cell.

    Args:
        system_prompt: Template for the system message; formatted with a
            bullet list of uploaded file names when the history is empty.
        user_input: The user's request, appended to the history.
        max_new_tokens: Generation budget per turn (coerced to ``int``;
            ``None`` falls back to ``DEFAULT_MAX_TOKENS``).
        model_name: Identifier passed to ``load_model_and_tokenizer``.
        files: Paths of uploaded files to copy into the sandbox, or ``None``.
        message_history: List of ``{"role", "content"}`` dicts; mutated in place.
        request: Gradio request; its ``session_hash`` keys the sandbox and the
            per-session output directory.

    Yields:
        ``(notebook_html, message_history, notebook_path)`` after every cell.
    """
    session_id = request.session_hash
    if session_id not in SANDBOXES:
        SANDBOXES[session_id] = Sandbox(api_key=E2B_API_KEY)
    sbx = SANDBOXES[session_id]

    session_dir = os.path.join(TMP_DIR, session_id)
    os.makedirs(session_dir, exist_ok=True)
    notebook_path = os.path.join(session_dir, 'jupyter-agent.ipynb')

    model, tokenizer = load_model_and_tokenizer(model_name)

    # gr.Number can deliver a float (or None when the field is emptied);
    # the generation budget has to be a whole number of tokens.
    if max_new_tokens is None:
        max_new_tokens = DEFAULT_MAX_TOKENS
    max_new_tokens = int(max_new_tokens)

    # Copy uploaded files into the sandbox under their base names.
    filenames = []
    for filepath in files or []:
        name = Path(filepath).name
        with open(filepath, "rb") as file:
            print(f"uploading {filepath}...")
            sbx.files.write(name, file)
        filenames.append(name)

    # A fresh conversation starts with the system prompt listing the files.
    if len(message_history) == 0:
        message_history.append({
            "role": "system",
            "content": system_prompt.format("- " + "\n- ".join(filenames))
        })
    message_history.append({"role": "user", "content": user_input})

    # create_base_notebook's result is indexed with [0] everywhere else in this
    # file before being displayed or dumped; do the same here.
    # NOTE(review): assumes update_notebook_with_cell / update_notebook_with_markdown
    # take and return that same notebook object -- confirm against utils.
    notebook_data = create_base_notebook([])[0]

    for _ in range(MAX_TURNS):
        # NOTE(review): roles are discarded when flattening the history; confirm
        # whether the model expects tokenizer.apply_chat_template instead.
        input_text = "\n".join(msg["content"] for msg in message_history)
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
        )

        # generate() returns prompt + completion for causal LMs. Decode only the
        # completion, otherwise the prompt is re-parsed and re-run as cells.
        prompt_length = inputs["input_ids"].shape[-1]
        response_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        has_code = False
        for cell in parse_model_response(response_text):
            content = cell["content"]
            if cell["type"] == "code":
                has_code = True
                # NOTE(review): confirm `sbx.python.run` exists in the installed
                # e2b_code_interpreter; its documented entry point appears to be
                # `Sandbox.run_code`.
                result = sbx.python.run(content)
                notebook_data = update_notebook_with_cell(notebook_data, content, result)
                message_history.append({
                    "role": "assistant",
                    "content": content
                })
                # Feed the execution result back so the next turn can react to it.
                message_history.append({
                    "role": "user",
                    "content": f"Execution result:\n{result}"
                })
            else:
                notebook_data = update_notebook_with_markdown(notebook_data, content)
                message_history.append({
                    "role": "assistant",
                    "content": content
                })

            # Write the notebook before yielding its path: the File output
            # needs the file to exist, and it was previously saved only once
            # the loop had ended.
            _write_notebook(notebook_data, notebook_path)
            notebook_html = update_notebook_display(notebook_data)
            yield notebook_html, message_history, notebook_path

        # A turn without code leaves nothing to execute, so the loop stops.
        if not has_code:
            break

    # Also covers the case where the model produced no cells at all.
    _write_notebook(notebook_data, notebook_path)
def clear(msg_state):
    """Reset the UI: return HTML for an empty notebook and an empty history.

    The incoming ``msg_state`` is ignored; a new empty list replaces it.
    """
    empty_notebook = create_base_notebook([])[0]
    return update_notebook_display(empty_notebook), []
# Stylesheet that stretches the page to the full viewport height and makes the
# first component scrollable.
# NOTE(review): nothing in the visible code passes `css` to gr.Blocks(); confirm
# whether it is still meant to be applied.
css = """
#component-0 {
height: 100vh;
overflow-y: auto;
padding: 20px;
}
.gradio-container {
height: 100vh !important;
}
.contain {
height: 100vh !important;
}
"""
# Build the interface: notebook view, prompt box, action buttons, notebook
# download, file upload, and advanced settings.
with gr.Blocks() as demo:
    # Per-session conversation history shared between the two handlers.
    msg_state = gr.State(value=[])

    html_output = gr.HTML(value=update_notebook_display(create_base_notebook([])[0]))

    user_input = gr.Textbox(
        value="Solve the Lotka-Volterra equation and plot the results.", lines=3, label="User input"
    )

    with gr.Row():
        generate_btn = gr.Button("Let's go!")
        clear_btn = gr.Button("Clear")

    # Starts out pointing at the empty notebook written at import time.
    file = gr.File(TMP_DIR+"jupyter-agent.ipynb", label="Download Jupyter Notebook")

    with gr.Accordion("Upload files", open=False):
        files = gr.File(label="Upload files to use", file_count="multiple")

    with gr.Accordion("Advanced Settings", open=False):
        system_input = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            elem_classes="input-box",
            lines=8,
        )
        with gr.Row():
            max_tokens = gr.Number(
                label="Max New Tokens",
                value=DEFAULT_MAX_TOKENS,
                minimum=128,
                maximum=2048,
                step=8,
                # precision=0 makes the component hand the handler an int;
                # without it the token budget arrives as a float.
                precision=0,
                interactive=True,
            )

            model = gr.Dropdown(
                value="bigcomputer/jupycoder-7b-lora-350",
                choices=[
                    "bigcomputer/jupycoder-7b-lora-350",
                    "Qwen/Qwen2.5-Coder-7B-Instruct"
                ],
                label="Models"
            )

    # gr.Request is not listed in `inputs`; the handler declares it through its
    # type annotation.
    generate_btn.click(
        fn=execute_jupyter_agent,
        inputs=[system_input, user_input, max_tokens, model, files, msg_state],
        outputs=[html_output, msg_state, file],
    )

    clear_btn.click(fn=clear, inputs=[msg_state], outputs=[html_output, msg_state])

    # On page load, strip the `dark` class from every element that carries it.
    demo.load(
        fn=None,
        inputs=None,
        outputs=None,
        js=""" () => {
if (document.querySelectorAll('.dark').length) {
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
}
}
"""
    )

demo.launch(ssr_mode=False)