File size: 13,866 Bytes
f7a9983
 
 
 
 
e19a951
f7a9983
e19a951
371a048
f7a9983
 
 
e19a951
f7a9983
371a048
 
f7a9983
5c34853
 
 
 
f7a9983
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100570e
 
dfa06cf
 
a73b1a6
100570e
 
 
 
 
 
6fef5b1
 
 
 
 
 
 
 
 
 
f7a9983
 
 
 
 
 
 
7b6df75
 
 
 
 
 
 
2018677
 
 
 
 
f7a9983
 
 
 
 
 
 
 
 
 
 
 
 
7b6df75
 
 
 
 
e9163a2
 
 
 
 
 
 
 
 
100570e
 
e9163a2
f7a9983
 
 
100570e
 
 
 
 
f7a9983
e19a951
 
 
 
 
 
 
 
 
 
 
 
 
 
100570e
 
926febf
100570e
 
 
 
 
 
f7a9983
100570e
 
 
571d707
100570e
 
926febf
100570e
 
 
 
 
 
 
 
 
 
f7a9983
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2018677
f7a9983
 
e19a951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
033f69a
e19a951
 
033f69a
e19a951
 
 
 
 
 
 
 
 
 
 
 
 
033f69a
 
e19a951
 
 
 
033f69a
e19a951
 
 
 
 
 
 
371a048
e19a951
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7a9983
e19a951
 
fcd3706
 
 
 
 
 
 
371a048
fcd3706
 
 
371a048
fcd3706
 
 
 
 
 
 
 
 
 
371a048
fcd3706
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
import nbformat
from nbformat.v4 import new_notebook, new_markdown_cell, new_code_cell
from nbconvert import HTMLExporter
from huggingface_hub import InferenceClient
from e2b_code_interpreter import Sandbox
from vllm.lora.request import LoRARequest
from traitlets.config import Config
from vllm import LLM
import re

# Shared nbconvert exporter used to render notebook dicts to HTML with the
# "classic" template (see update_notebook_display).
config = Config()
html_exporter = HTMLExporter(config=config, template_name="classic")
# The vLLM engine is constructed once, at import time, with LoRA support
# enabled so an adapter can be attached per request (run_interactive_notebook
# passes a LoRARequest to BASE_MODEL.chat).
# NOTE(review): importing this module therefore loads the full model — a heavy
# side effect; confirm this is intended for every importer.
BASE_MODEL = LLM(model="Qwen/Qwen2.5-Coder-7B-Instruct", enable_lora=True)

# Constants
# Upper bound on generate/execute iterations per run_interactive_notebook call.
MAX_TURNS = 10

# Chat template text read at import time; the path is relative, so it is
# resolved against the current working directory.
# NOTE(review): `llama_template` is not referenced in this part of the file —
# confirm whether it is still used elsewhere.
with open("llama3_template.jinja", "r") as f:
    llama_template = f.read() 


# (result attribute, MIME type) pairs copied into an output's ``data`` bundle,
# in insertion order. ``text`` is handled separately because this module stores
# it as a one-element list of strings.
_RICH_MIME_FIELDS = (
    ("html", "text/html"),
    ("png", "image/png"),
    ("svg", "image/svg+xml"),
    ("jpeg", "image/jpeg"),
    ("pdf", "application/pdf"),
    ("latex", "text/latex"),
    ("json", "application/json"),
    ("javascript", "application/javascript"),
)


def parse_exec_result_nb(execution):
    """Convert an E2B Execution object to Jupyter notebook cell output format.

    Outputs are produced in this order: a stdout stream, a stderr stream, an
    error entry, then one ``execute_result`` (for the main result) or
    ``display_data`` entry per result that carries at least one non-empty
    MIME payload. Results with no payload are skipped.

    Args:
        execution: Object exposing ``logs.stdout`` / ``logs.stderr`` (lists of
            str chunks), ``error`` (falsy, or an object with ``name``,
            ``value`` and ``traceback``), ``results`` and ``execution_count``.

    Returns:
        list[dict]: nbformat v4 output dicts for a code cell's ``outputs``.
    """
    outputs = []

    # Stream chunks are joined with no separator.
    for stream_name in ("stdout", "stderr"):
        chunks = getattr(execution.logs, stream_name)
        if chunks:
            outputs.append({
                'output_type': 'stream',
                'name': stream_name,
                'text': ''.join(chunks),
            })

    error = execution.error
    if error:
        outputs.append({
            'output_type': 'error',
            'ename': error.name,
            'evalue': error.value,
            # nbformat stores the traceback as a list of lines.
            'traceback': (error.traceback or '').split('\n'),
        })

    for result in execution.results:
        data = {}
        text = getattr(result, 'text', None)
        if text:
            data['text/plain'] = [text]  # Array for text/plain
        for attr, mime in _RICH_MIME_FIELDS:
            value = getattr(result, attr, None)
            if value:
                data[mime] = value

        if not data:
            continue

        if result.is_main_result:
            # nbformat v4 requires ``execution_count`` on every execute_result
            # (null is allowed), so always set it.
            outputs.append({
                'output_type': 'execute_result',
                'metadata': {},
                'data': data,
                'execution_count': execution.execution_count,
            })
        else:
            outputs.append({
                'output_type': 'display_data',
                'metadata': {},
                'data': data,
            })

    return outputs


# Collapsible HTML wrapper for the system prompt (create_base_notebook fills it
# with str.format and one positional argument). Literal CSS braces are doubled
# ({{ / }}) so they survive str.format.
system_template = """\
<details>
  <summary style="display: flex; align-items: center;">
    <div class="alert alert-block alert-info" style="margin: 0; width: 100%;">
      <b>System: <span class="arrow">▶</span></b>
    </div>
  </summary>
  <div class="alert alert-block alert-info">
    {}
  </div>
</details>

<style>
details > summary .arrow {{
  display: inline-block;
  transition: transform 0.2s;
}}
details[open] > summary .arrow {{
  transform: rotate(90deg);
}}
</style>
"""

# HTML wrapper for a user prompt; filled via str.format with one argument.
user_template = """<div class="alert alert-block alert-success">
<b>User:</b> {}
</div>
"""

# Markdown/HTML placed as the first cell of every notebook built by
# create_base_notebook.
header_message = """<p align="center">
  <img src="https://huggingface.co/spaces/lvwerra/jupyter-agent/resolve/main/jupyter-agent.png" />
</p>


<p style="text-align:center;">Let a LLM agent write and execute code inside a notebook!</p>"""

# CSS snippet removed from the exporter's HTML in update_notebook_display via
# str.replace, so it must match the exported text byte-for-byte.
# NOTE(review): presumably stripped because it overrides the host page's file
# input styling — confirm.
bad_html_bad = """input[type="file"] {
  display: block;
}"""


def create_base_notebook(messages):
    """Build an nbformat v4 notebook dict from a chat message history.

    The notebook always starts with the ``header_message`` markdown cell.
    Messages are then mapped to cells:

    * ``system`` -> markdown cell rendered with ``system_template``.
    * ``user`` with a truthy ``is_user_prompt`` -> markdown cell rendered with
      ``user_template``.
    * ``user`` without that flag -> stdout stream output appended to the most
      recent code cell (execution feedback, not a prompt).
    * ``assistant`` with a ``tool_calls`` key -> code cell whose source is the
      message content.
    * ``assistant`` without ``tool_calls`` -> markdown cell with the content.
    * ``ipython`` -> replaces the outputs of the most recent code cell with
      ``message["nbformat"]`` and gives it the next execution count.

    Prompt text is HTML-escaped before insertion into the templates, and
    newlines are rendered as ``<br>``. If an output-bearing message arrives
    when the last cell is not a code cell, an empty code cell is appended to
    hold the output.

    Args:
        messages: List of message dicts with at least ``role`` and, depending
            on the role, ``content`` / ``nbformat``.

    Returns:
        tuple[dict, int]: The notebook dict and the number of ``ipython``
        messages processed (the last execution count assigned).

    Raises:
        ValueError: If a message has an unsupported role.
    """
    import html

    def _code_cell(source):
        # Fresh, unexecuted code cell.
        return {
            "cell_type": "code",
            "execution_count": None,
            "metadata": {},
            "source": source,
            "outputs": [],
        }

    def _markdown_cell(source):
        return {"cell_type": "markdown", "metadata": {}, "source": source}

    def _prompt_html(text):
        # Escape first so text like "a<b" or "<script>" is shown literally
        # instead of being parsed as markup, then keep line breaks visible.
        return html.escape(text, quote=False).replace('\n', '<br>')

    base_notebook = {
        "metadata": {
            "kernel_info": {"name": "python3"},
            "language_info": {
                "name": "python",
                "version": "3.12",
            },
        },
        "nbformat": 4,
        "nbformat_minor": 0,
        "cells": [_markdown_cell(header_message)],
    }
    cells = base_notebook["cells"]

    def _last_code_cell():
        # Outputs can only live on a code cell; if the previous cell is not one,
        # add an empty code cell rather than raising KeyError or writing
        # "outputs" onto a markdown cell.
        if cells[-1]["cell_type"] != "code":
            cells.append(_code_cell(""))
        return cells[-1]

    if not messages:
        # An empty conversation still gets an empty code cell to start from.
        cells.append(_code_cell(""))

    code_cell_counter = 0

    for message in messages:
        role = message["role"]
        if role == "system":
            cells.append(_markdown_cell(
                system_template.format(_prompt_html(message["content"]))))
        elif role == "user":
            if message.get("is_user_prompt", False):
                cells.append(_markdown_cell(
                    user_template.format(_prompt_html(message["content"]))))
            else:
                # Execution output fed back to the model: show it under the
                # code cell that produced it.
                _last_code_cell()["outputs"].append({
                    "output_type": "stream",
                    "name": "stdout",
                    "text": message["content"],
                })
        elif role == "assistant":
            if "tool_calls" in message:
                # Tool-call messages may carry no content; use an empty source.
                cells.append(_code_cell(message["content"] or ""))
            else:
                cells.append(_markdown_cell(message["content"]))
        elif role == "ipython":
            code_cell_counter += 1
            cell = _last_code_cell()
            cell["outputs"] = message["nbformat"]
            cell["execution_count"] = code_cell_counter
        else:
            raise ValueError(message)

    return base_notebook, code_cell_counter

def execute_code(sbx, code):
    """Run ``code`` in the sandbox and return ``(text_output, execution)``.

    While the code runs, each stdout chunk is echoed to this process's stdout
    with a ``stdout:`` prefix. The text summary comes from
    ``parse_exec_result_llm``, so the two functions cannot drift apart.

    Args:
        sbx: Sandbox exposing ``run_code(code, on_stdout=...)``.
        code: Python source to execute.

    Returns:
        tuple[str, Execution]: Flattened text output and the raw execution.
    """
    execution = sbx.run_code(code, on_stdout=lambda data: print('stdout:', data))
    return parse_exec_result_llm(execution), execution


def parse_exec_result_llm(execution):
    """Flatten an E2B execution result into plain text for the model.

    Sections are emitted in this order: stdout chunks, stderr chunks, then the
    error traceback. The chunks within a stream are joined with newlines, and
    the sections are separated by a newline, so that e.g. stdout ``"a"`` and
    stderr ``"b"`` do not run together as ``"ab"``. Empty sections are omitted.

    Args:
        execution: Object exposing ``logs.stdout`` / ``logs.stderr`` (lists of
            str) and ``error`` (None or an object with ``traceback``).

    Returns:
        str: The combined text, or ``""`` when there is no output.
    """
    sections = []
    if execution.logs.stdout:
        sections.append("\n".join(execution.logs.stdout))
    if execution.logs.stderr:
        sections.append("\n".join(execution.logs.stderr))
    if execution.error is not None and execution.error.traceback:
        sections.append(execution.error.traceback)
    return "\n".join(sections)
    
    
def update_notebook_display(notebook_data):
    """Render a notebook dict to an HTML string for display.

    The dict is turned into a NotebookNode and exported with the module-level
    ``html_exporter``. The exporter's resources are discarded, and every
    occurrence of the ``bad_html_bad`` CSS snippet is removed from the HTML.
    """
    node = nbformat.from_dict(notebook_data)
    rendered, _resources = html_exporter.from_notebook_node(node)
    return rendered.replace(bad_html_bad, "")

def _strip_display_flags(msg):
    """Return a shallow copy of ``msg`` without UI-only keys (``is_user_prompt``)."""
    clean = msg.copy()
    clean.pop("is_user_prompt", None)
    return clean


def _outputs_to_llm_text(outputs):
    """Flatten nbformat outputs into the plain-text feedback sent to the model.

    Stream text and error tracebacks are included as-is. Rich results
    contribute their ``text/plain`` payload, plus an ``<image>`` placeholder
    when any ``image/*`` MIME type is present. Parts are joined with newlines.
    """
    parts = []
    for output in outputs:
        kind = output.get('output_type')
        if kind == 'stream':
            text = output['text']
            # nbformat allows str or list[str] for multiline text.
            parts.append(text if isinstance(text, str) else '\n'.join(text))
        elif kind == 'error':
            parts.append('\n'.join(output['traceback']))
        elif kind in ('execute_result', 'display_data'):
            data = output.get('data', {})
            if 'text/plain' in data:
                plain = data['text/plain']
                # Joining a plain str would insert a newline between characters.
                parts.append(plain if isinstance(plain, str) else '\n'.join(plain))
            if any(key.startswith('image/') for key in data):
                parts.append('<image>')
    return '\n'.join(parts)


def run_interactive_notebook(lora_path, sampling_params, messages, sbx, notebook_data=None, max_new_tokens=512):
    """Drive a generate -> execute loop and stream the evolving notebook.

    Each turn asks the base model, with the LoRA adapter at ``lora_path``, for
    the next assistant message. The turn then ends in one of three ways:

    * If the reply contains a fenced ```python block, the first such block is
      executed in ``sbx``. It is appended as a code cell with its outputs, and
      a text summary of those outputs is fed back to the model as a ``user``
      message.
    * A reply without code is appended as a markdown cell and ends the run.
    * A reply identical (ignoring surrounding whitespace) to an earlier
      assistant message ends the run without being recorded.

    At most ``MAX_TURNS`` turns are run.

    Args:
        lora_path: Path to the LoRA adapter used for every generation.
        sampling_params: Sampling parameters passed to ``BASE_MODEL.chat``.
        messages: Conversation so far. On a fresh run (``notebook_data`` is
            None) each message is copied. When continuing a session, this list
            itself is used as the display history and is appended to in place.
        sbx: Sandbox used to execute code (needs ``run_code``).
        notebook_data: Existing notebook dict when continuing a session. When
            None, a new notebook is built from ``messages``.
        max_new_tokens: Unused; generation length is governed by
            ``sampling_params``. Kept for backward compatibility.

    Yields:
        tuple: ``(html, notebook_data, display_messages)``. One is yielded
        after every executed code cell. One final update is yielded if the run
        ends on a state not yet yielded (a no-code reply, or nothing yielded
        at all). The same state is never yielded twice in a row.
    """
    if notebook_data is None:
        # Fresh session: tag user messages as prompts for rendering.
        display_messages = []
        for msg in messages:
            display_msg = msg.copy()
            if msg["role"] == "user":
                display_msg["is_user_prompt"] = True
            display_messages.append(display_msg)
        notebook_data, code_cell_counter = create_base_notebook(display_messages)
    else:
        display_messages = messages
        # Continue numbering after the highest execution count already shown.
        code_cell_counter = max(
            [0] + [
                cell["execution_count"]
                for cell in notebook_data["cells"]
                if cell["cell_type"] == "code" and cell.get("execution_count")
            ]
        )
    # The model never sees display-only keys.
    model_messages = [_strip_display_flags(msg) for msg in messages]

    # The adapter request is identical every turn, so build it once.
    lora_request = LoRARequest("lora_adapter", 1, lora_path)

    # True while the notebook/messages hold changes that have not been yielded.
    needs_yield = True
    for _ in range(MAX_TURNS):
        print(model_messages)
        response_text = BASE_MODEL.chat(
            model_messages,
            sampling_params,
            lora_request=lora_request,
            add_generation_prompt=True
        )[0].outputs[0].text

        # A verbatim repeat of an earlier assistant message means the model is
        # looping; stop without recording it.
        stripped = response_text.strip()
        if any(
            msg["role"] == "assistant" and msg["content"].strip() == stripped
            for msg in model_messages
        ):
            break

        assistant_msg = {"role": "assistant", "content": response_text}
        model_messages.append(assistant_msg.copy())
        display_messages.append(assistant_msg)
        needs_yield = True

        code_match = re.search(r'```python\n(.*?)```', response_text, re.DOTALL)
        if code_match is None:
            # Plain prose ends the run; show it as markdown.
            notebook_data["cells"].append({
                "cell_type": "markdown",
                "metadata": {},
                "source": response_text
            })
            break

        code = code_match.group(1).strip()
        code_cell_counter += 1
        code_cell = {
            "cell_type": "code",
            "execution_count": code_cell_counter,
            "metadata": {},
            "source": code,
            "outputs": []
        }
        notebook_data["cells"].append(code_cell)

        _, execution = execute_code(sbx, code)
        outputs = parse_exec_result_nb(execution)

        # Execution feedback goes back to the model as a user turn; the display
        # copy is flagged so it renders as cell output, not as a prompt.
        feedback = {"role": "user", "content": _outputs_to_llm_text(outputs)}
        model_messages.append(feedback.copy())
        display_messages.append({**feedback, "is_user_prompt": False})

        code_cell["outputs"] = outputs

        yield update_notebook_display(notebook_data), notebook_data, display_messages
        needs_yield = False

    if needs_yield:
        yield update_notebook_display(notebook_data), notebook_data, display_messages

def update_notebook_with_cell(notebook_data, code, output):
    """Append a code cell with source ``code`` to ``notebook_data``.

    A truthy ``output`` is stringified and stored as a single stdout stream
    output. Otherwise the cell has no outputs. The cell has no execution
    count. The notebook is mutated in place and also returned.
    """
    cell_outputs = []
    if output:
        cell_outputs.append({
            "output_type": "stream",
            "name": "stdout",
            "text": str(output),
        })
    notebook_data['cells'].append({
        "cell_type": "code",
        "execution_count": None,
        "metadata": {},
        "source": code,
        "outputs": cell_outputs,
    })
    return notebook_data

def update_notebook_with_markdown(notebook_data, markdown_text):
    """Append a markdown cell holding ``markdown_text`` to ``notebook_data``.

    The notebook is mutated in place and the same object is returned.
    """
    cells = notebook_data['cells']
    cells.append(dict(cell_type="markdown", metadata={}, source=markdown_text))
    return notebook_data