import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import datasets
import torch
from threading import Thread

def make_script(shader_code):
    # code copied and adapted (single quotes escaped to double quotes!) from https://webglfundamentals.org/webgl/webgl-shadertoy.html
    script = ("""
<!-- Licensed under a BSD license. See license.html for license -->
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
<title>WebGL - Shadertoy</title>
<link type="text/css" href="https://webglfundamentals.org/webgl/resources/webgl-tutorials.css" rel="stylesheet" />
<style>
.divcanvas {
  position: relative;
  display: inline-block;
} 
canvas {
  display: block;
}
.playpause {
  position: absolute;
  left: 10px;
  top: 10px;
  width: 100%;
  height: 100%;
  font-size: 60px;
  justify-content: center;
  align-items: center;
  color: rgba(255, 255, 255, 0.3);
  transition: opacity 0.2s ease-in-out;
}
.playpausehide,
.playpause:hover {
  opacity: 0;
}
.iframe .divcanvas {
  display: block;
}
</style>
</head>
<body>
<div class="divcanvas">
  <canvas id="canvas"></canvas>
  <div class="playpause">▶</div>
</div>
  <p>A blank canvas here indicates that some Shadertoy-specific features (like #define) are not yet supported by this implementation. You can always copy and paste the code into a shadertoy.com window to try it.</p>
</body>
<!--
for most samples webgl-utils only provides shader compiling/linking and
canvas resizing because why clutter the examples with code that's the same in every sample.
See https://webglfundamentals.org/webgl/lessons/webgl-boilerplate.html
and https://webglfundamentals.org/webgl/lessons/webgl-resizing-the-canvas.html
for webgl-utils, m3, m4, and webgl-lessons-ui.
-->
<script src="https://webglfundamentals.org/webgl/resources/webgl-utils.js"></script>
<script>
"use strict";

function main() {
  // Get A WebGL context
  /** @type {HTMLCanvasElement} */
  const canvas = document.querySelector("#canvas");
  const gl = canvas.getContext("webgl");
  if (!gl) {
    return;
  }

  const vs = `
    // an attribute will receive data from a buffer
    attribute vec4 a_position;

    // all shaders have a main function
    void main() {

      // gl_Position is a special variable a vertex shader
      // is responsible for setting
      gl_Position = a_position;
    }
  `;

  const fs = `
    precision highp float;

    uniform vec2 iResolution;
    uniform vec2 iMouse;
    uniform float iTime;

    """ + shader_code + """
    void main() {
      mainImage(gl_FragColor, gl_FragCoord.xy);
    }
  `;

  // setup GLSL program
  const program = webglUtils.createProgramFromSources(gl, [vs, fs]);

  // look up where the vertex data needs to go.
  const positionAttributeLocation = gl.getAttribLocation(program, "a_position");

  // look up uniform locations
  const resolutionLocation = gl.getUniformLocation(program, "iResolution");
  const mouseLocation = gl.getUniformLocation(program, "iMouse");
  const timeLocation = gl.getUniformLocation(program, "iTime");

  // Create a buffer to put three 2d clip space points in
  const positionBuffer = gl.createBuffer();

  // Bind it to ARRAY_BUFFER (think of it as ARRAY_BUFFER = positionBuffer)
  gl.bindBuffer(gl.ARRAY_BUFFER, positionBuffer);

  // fill it with a 2 triangles that cover clipspace
  gl.bufferData(gl.ARRAY_BUFFER, new Float32Array([
    -1, -1,  // first triangle
     1, -1,
    -1,  1,
    -1,  1,  // second triangle
     1, -1,
     1,  1,
  ]), gl.STATIC_DRAW);

  const playpauseElem = document.querySelector(".playpause");
  const inputElem = document.querySelector(".divcanvas");
  inputElem.addEventListener("mouseover", requestFrame);
  inputElem.addEventListener("mouseout", cancelFrame);

  let mouseX = 0;
  let mouseY = 0;

  function setMousePosition(e) {
    const rect = inputElem.getBoundingClientRect();
    mouseX = e.clientX - rect.left;
    mouseY = rect.height - (e.clientY - rect.top) - 1;  // bottom is 0 in WebGL
  }

  inputElem.addEventListener("mousemove", setMousePosition);
  inputElem.addEventListener("touchstart", (e) => {
    e.preventDefault();
    playpauseElem.classList.add("playpausehide");
    requestFrame();
  }, {passive: false});
  inputElem.addEventListener("touchmove", (e) => {
    e.preventDefault();
    setMousePosition(e.touches[0]);
  }, {passive: false});
  inputElem.addEventListener("touchend", (e) => {
    e.preventDefault();
    playpauseElem.classList.remove("playpausehide");
    cancelFrame();
  }, {passive: false});

  let requestId;
  function requestFrame() {
    if (!requestId) {
      requestId = requestAnimationFrame(render);
    }
  }
  function cancelFrame() {
    if (requestId) {
      cancelAnimationFrame(requestId);
      requestId = undefined;
    }
  }

  let then = 0;
  let time = 0;
  function render(now) {
    requestId = undefined;
    now *= 0.001;  // convert to seconds
    const elapsedTime = Math.min(now - then, 0.1);
    time += elapsedTime;
    then = now;

    webglUtils.resizeCanvasToDisplaySize(gl.canvas);

    // Tell WebGL how to convert from clip space to pixels
    gl.viewport(0, 0, gl.canvas.width, gl.canvas.height);

    // Tell it to use our program (pair of shaders)
    gl.useProgram(program);

    // Turn on the attribute
    gl.enableVertexAttribArray(positionAttributeLocation);

    // Bind the position buffer.
    gl.bindBuffer(gl.ARRAY_BUFFER, positionBuffer);

    // Tell the attribute how to get data out of positionBuffer (ARRAY_BUFFER)
    gl.vertexAttribPointer(
        positionAttributeLocation,
        2,          // 2 components per iteration
        gl.FLOAT,   // the data is 32bit floats
        false,      // don't normalize the data
        0,          // 0 = move forward size * sizeof(type) each iteration to get the next position
        0,          // start at the beginning of the buffer
    );

    gl.uniform2f(resolutionLocation, gl.canvas.width, gl.canvas.height);
    gl.uniform2f(mouseLocation, mouseX, mouseY);
    gl.uniform1f(timeLocation, time);

    gl.drawArrays(
        gl.TRIANGLES,
        0,     // offset
        6,     // num vertices to process
    );

    requestFrame();
  }

  requestFrame();
  requestAnimationFrame(cancelFrame);
}

main();
</script>
</html>


""")
    return script
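# Rough usage sketch (assumption: any GLSL snippet that defines the Shadertoy-style
# entry point works here, since the fragment shader template above calls
# mainImage(gl_FragColor, gl_FragCoord.xy)):
# example_shader = "void mainImage(out vec4 fragColor, in vec2 fragCoord) { fragColor = vec4(1.0); }"
# html_doc = make_script(example_shader)  # a full HTML document, ready for an iframe srcdoc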

def make_iframe(shader_code): #keep a single function?
    script = make_script(shader_code)
    return f"""<iframe width="640" height="420" srcdoc='{script}' allowfullscreen></iframe>"""
    

intro_text = """
# Welcome to the interactive shadercoding demo.
This gives you access to a filtered version of the [Shadertoys](https://huggingface.co/datasets/Vipitis/Shadertoys) dataset; only shaders that consist of a single pass are available.
It then lets you use code generation models to make alterations to parts of the shader code.

## How To Use:
1. Load any model for [`text-generation`](https://huggingface.co/models?pipeline_tag=text-generation) and hit ENTER.
2. Use the slider to sample a shader from the dataset.
  - The original shader will be embedded on the left; click its title to get to the source.
  - The shader code will be displayed on the right; it is interactive.
  - A preview of the currently displayed shader code renders on the lower left (hover to advance time).
3. Use the dropdown to select a function to modify.
4. Press either button to make modifications to that function.
5. You can also edit the code manually.
"""

outro_text ="""
## Models to try (look at [ShaderEval](https://huggingface.co/spaces/Vipitis/ShaderEval) for an indication of how helpful they will be):
- [gpt2](https://huggingface.co/gpt2) baseline for language models, really struggles with shader code.
- [bigscience/bloom-1b1](https://huggingface.co/bigscience/bloom-1b1) a newer and larger freely available model. Does understand a bit of code.
- [codeparrot/codeparrot-small](https://huggingface.co/codeparrot/codeparrot-small) a model trained on code, but not on shader code. Manages to grasp the patterns.
- [salesforce/codegen-2B-multi](https://huggingface.co/salesforce/codegen-2B-multi) a larger model that shows some potential.
- [bigcode/santacoder](https://huggingface.co/bigcode/santacoder) a model trained on a subset of [TheStack](https://huggingface.co/datasets/bigcode/the-stack), struggles with shader code.
- [Vipitis/santacoder-finetuned-the-stack-glsl](https://huggingface.co/Vipitis/santacoder-finetuned-the-stack-glsl) fine-tuned by me on the glsl subset of [TheStack](https://huggingface.co/datasets/bigcode/the-stack), is an improvement.
- [Vipitis/santacoder-finetuned-Shadertoys](https://huggingface.co/Vipitis/santacoder-finetuned-Shadertoys) fine-tuned by me on whole shaders from [Shadertoys](https://huggingface.co/datasets/Vipitis/Shadertoys). Does overfit quite a bit with greedy decoding.
- [Vipitis/santacoder-finetuned-Shadertoys-fine](https://huggingface.co/Vipitis/santacoder-finetuned-Shadertoys-fine) fine-tuned by me on just functions from [Shadertoys-fine](https://huggingface.co/datasets/Vipitis/Shadertoys-fine). Memorizes the exact function about half the time.
- [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) a very large model which I haven't tried yet.
- **any other model you want to try**

## TODO (feel free to contribute with a [Pull-Request](https://huggingface.co/Vipitis/santacoder-finetuned-the-stack-glsl/discussions?status=open&type=pull_request)):
 - [x] use embedded Shadertoy for reference/attribution (done, but some errors)
 - [~] working render implementation on CPU only space (as webgl via webglfundamentals, css needs fixing for iframe (or hijack Shadertoy iframe))
 - [~] generate variations of return statements [ShaderEval task1](https://huggingface.co/spaces/Vipitis/ShaderEval) (needs to be reworked using the other parts)
 - [x] generate whole functions (seems to work quite well)
 - [ ] dropdown for model selection (from a curated list or all supported models?)
 - [ ] generation history stating which function and orig/generated returns (use State?). do it as comments in the code?
 - [~] display errors/issues to the user (raise gr.Error could be one idea, but highlighting in the code would be awesome) currently adds a comment to the code.
 - [ ] generate whole shaders (via prompt guidance, recursive from errors)
 - [x] accordion with generation parameters (as pipeline_kwargs?) look up the starcoder playground and take "inspiration" from there (implemented for both buttons, untested)
 - [ ] support FIM task for better model context
 - [~] include some context for the prompt (title, comments before a function) - now works with the first comment inside a function body (has to be first)
 - [ ] gradio examples
 - [ ] use GPU if available, respect memory restrictions
 - [~] stream model generation (maybe in a new window?) - WIP for body gen right now -> janky solution works.

### Notes:
 - this is meant as a resource to show code generation for a "creative" task.
 - the goal is not to replace shader artists; it aims to be an assistant instead.
 - the space still lacks quite a lot of features, but will continue to evolve.
 - this demo can be useful to sanity-check the evaluation results behind the academic numbers.
 - If you create a remix with these tools, please attribute the original creator of your starting point when sharing the results. (And perhaps share in the [discussion tab](https://huggingface.co/Vipitis/santacoder-finetuned-the-stack-glsl/discussions?status=open&type=discussion) too)
"""

new_shadertoy_code = """void mainImage( out vec4 fragColor, in vec2 fragCoord )
{
    // touch the slider to load a shader from the dataset or start coding from here.
    vec2 uv = fragCoord/iResolution.xy;
    vec3 col = 0.5 + 0.5*cos(iTime+uv.xyx+vec3(0,2,4));
    fragColor = vec4(col,1.0);
}"""

passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
single_passes = passes_dataset.filter(lambda x: not x["has_inputs"] and x["num_passes"] == 1) #could also include shaders with no extra functions.
all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
num_samples = len(all_single_passes)    

from tree_sitter import Language, Parser
Language.build_library("./build/my-languages.so", ['tree-sitter-glsl'])
GLSL_LANGUAGE = Language('./build/my-languages.so', 'glsl')
parser = Parser()
parser.set_language(GLSL_LANGUAGE)
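# NOTE: building the grammar at startup assumes the tree-sitter-glsl repository has
# already been cloned next to this script (e.g. `git clone https://github.com/theHamsta/tree-sitter-glsl`).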

def grab_sample(sample_idx):
    sample_pass = all_single_passes[sample_idx]
    sample_code = sample_pass["code"]
    sample_source = sample_pass["source"]
    sample_title = sample_pass["title"]
    sample_author = sample_pass["author"]
    source_iframe = construct_embed(sample_source)
    print(f"{source_iframe=}")
    # funcs = _parse_functions(sample_code)
    # func_identifiers = [f"{idx:2d}: {n.child_by_field_name('declarator').text.decode()}" for idx, n in enumerate(funcs)]
    # print(f"updating drop down to:{func_identifiers}")
    return sample_pass, sample_code, source_iframe # three values to match the outputs wired up below #, gr.Dropdown.update(choices=func_identifiers), sample_title, sample_author


def _parse_functions(in_code):
    """
    Returns all functions in the code as their tree-sitter nodes.
    Includes any comment made directly after the function definition or directly after the #copilot trigger.
    """
    tree = parser.parse(bytes(in_code, "utf8"))
    funcs = [n for n in tree.root_node.children if n.type == "function_definition"]

    return funcs
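# Sketch of the expected output (field names follow the tree-sitter GLSL grammar used above):
# funcs = _parse_functions(new_shadertoy_code)
# funcs[0].child_by_field_name("declarator").text.decode()
# # -> "mainImage( out vec4 fragColor, in vec2 fragCoord )"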

PIPE = None

def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #bad default model for testing
    global PIPE # without this, the assignment below would only create a local and never update the module-level default
    # if torch.cuda.is_available():
    #     device = "cuda"
    # else:
    #     device = "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_cp, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_cp, trust_remote_code=True)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True) #, device=device)
    PIPE = pipe # update the global default
    print(f"loaded model {model_cp} as a pipeline")
    return pipe

def _run_generation(model_ctx:str, pipe, gen_kwargs:dict):
    """
    Text generation function
    Args:
        model_ctx (str): The context to start generation from.
        pipe (Pipeline): The pipeline to use for generation.
        gen_kwargs (dict): The generation kwargs.
    Yields:
        str: The cumulative generated text so far (streamed chunk by chunk).
    """
    # Tokenize the model_context
    model_inputs = pipe.tokenizer(model_ctx, return_tensors="pt")

    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
    # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
    streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15.0)
    generate_kwargs = dict(model_inputs, streamer=streamer, **gen_kwargs)
    t = Thread(target=pipe.model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer, and update the model output.
    model_output = ""
    for new_text in streamer:
        # print("step", end="")
        model_output += new_text
        yield model_output
    print("stream reached the end.")
    return model_output # a generator's return value only surfaces via StopIteration; callers simply iterate
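# Usage sketch (hypothetical values; as a generator it yields the cumulative text):
# pipe = _make_pipeline()
# for partial in _run_generation("void mainImage", pipe, {"max_new_tokens": 32}):
#     print(partial)  # grows with every streamed chunk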

def process_retn(retn):
    return retn.split(";")[0].strip()
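# e.g. process_retn(" vec4(col, 1.0);\n}") -> "vec4(col, 1.0)"
# (the model continues after "return", so everything up to the first ";" is kept)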

def get_full_replacement(orig_code, retn_start_idx, retn_end_idx, prediction) -> str:
    """
    Splices the generated return statement into the code and returns the full altered code.
    """
    print(f"{orig_code[retn_start_idx:retn_end_idx]=}")
    generated = process_retn(prediction)
    print(f"{generated=}")
    variation = orig_code[:retn_start_idx] + generated + orig_code[retn_end_idx:]
    return variation
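# Worked example: in "return col;" the expression "col" spans indices 7..10, so
# get_full_replacement("return col;", 7, 10, " vec3(uv, 1.0); }") -> "return vec3(uv, 1.0);"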

def alter_return(orig_code, func_idx, temperature, max_new_tokens, top_p, repetition_penalty, pipeline=PIPE): #default pipeline can't be passed as a global?
    """
    Replaces the return statement of a function with a generated one.
    Args:
        orig_code (str): The original code.
        func_idx (int): The index of the return statement to replace (clamped to the number of returns found).
        temperature, max_new_tokens, top_p, repetition_penalty: generation settings.
        pipeline (Pipeline): The pipeline to use for generation.
    Returns:
        str: The altered code.
    """
    if pipeline is None:
        print("no pipeline found, loading default one")
        pipeline = _make_pipeline()
    
    if isinstance(func_idx, str):
        print(f"{func_idx=}")
        func_idx = int(func_idx.split(":")[0].strip())
    elif isinstance(func_idx, int):
        pass
    else:
        raise gr.Error(f"func_idx must be int or str, not {type(func_idx)}")

    generation_kwargs = _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)

    retrns = []
    retrn_start_idx = orig_code.find("return")
    while retrn_start_idx != -1:
        retrn_end_idx = orig_code.find(";", retrn_start_idx)
        retrns.append((retrn_start_idx, retrn_end_idx))
        retrn_start_idx = orig_code.find("return", retrn_end_idx)
    num_returns = len(retrns)
    if num_returns == 0:
        print("no return statement found, returning original code")
        return orig_code
    func_idx = int(max(0, min(func_idx, num_returns - 1))) #clamp to valid range, cast to int as a bodge.
    retrn_start_idx, retrn_end_idx = retrns[func_idx]
    model_context = orig_code[:retrn_start_idx] #TODO: maximal context?
    model_inp = model_context + "return"
    pipe_generation = pipeline(model_inp, return_full_text=False, **generation_kwargs)[0]["generated_text"] #pipeline kwargs are missing?!
    altered_code = get_full_replacement(orig_code, retrn_start_idx+7, retrn_end_idx, pipe_generation) # +7 skips past "return " (the six-letter keyword plus the space)
    
    return altered_code

def _line_chr2char(text, line_idx, chr_idx):
    """
    Converts a (line, character) position into a flat character index into the text.
    """
    lines = text.split("\n")
    char_idx = 0
    for i in range(line_idx):
        char_idx += len(lines[i]) + 1
    char_idx += chr_idx
    return char_idx
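# e.g. _line_chr2char("ab\ncd", line_idx=1, chr_idx=1) -> 4
# ("ab" plus its newline contribute 3 characters, one more lands on the "d")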

def _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty):
    return {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
    }
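# e.g. _combine_generation_kwargs(0.7, 160, 0.85, 1.2) returns
# {"temperature": 0.7, "max_new_tokens": 160, "top_p": 0.85, "repetition_penalty": 1.2}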

def alter_body(old_code, func_id, funcs_list: list, temperature, max_new_tokens, top_p, repetition_penalty, pipeline=PIPE):
    """
    Replaces the body of a function with a generated one.
    Args:
        old_code (str): The original code.
        func_id (int | str): The index (or "idx: name" label) of the function whose body gets replaced.
        funcs_list (list): The parsed function nodes of the current code.
        temperature, max_new_tokens, top_p, repetition_penalty: generation settings.
        pipeline (Pipeline): The pipeline to use for generation.
    Returns:
        str: The altered code.
    """
    if isinstance(func_id, str):
        print(f"{func_id=}")
        func_id = int(func_id.split(":")[0].strip()) #undo their string casting?
    elif isinstance(func_id, int):
        pass
    else:
        raise gr.Error(f"func_id must be int or str, not {type(func_id)}")
    func_node = funcs_list[func_id]
    print(f"using for generation: {func_node=}")
    
    generation_kwargs = _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
    
    print(f"{pipeline=}") # check if default even loaded
    if pipeline is None:
        print("no pipeline found, loading default one")
        pipeline = _make_pipeline("Vipitis/santacoder-finetuned-Shadertoys-fine")

    func_start_idx = _line_chr2char(old_code, func_node.start_point[0], func_node.start_point[1])
    identifier_str = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode() #func_start_idx:body_start_idx?
    body_node = func_node.child_by_field_name("body")
    body_start_idx = _line_chr2char(old_code, body_node.start_point[0], body_node.start_point[1])
    body_end_idx = _line_chr2char(old_code, body_node.end_point[0], body_node.end_point[1])
    print(f"{old_code[body_start_idx:body_end_idx]=}")
    model_context = identifier_str # base case
    # add any comments at the beginning of the function to the model_context
    body_children = body_node.children
    if len(body_children) > 1 and body_children[1].type == "comment": # guard: a tiny body might not have a second child
        # print(body_children[1].text.decode())
        model_context += " { \n  " + body_children[1].text.decode()
        print(f"{model_context=}")
    # generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
    generation = _run_generation(model_context, pipeline, generation_kwargs)
    for i in generation:
        print(f"{i=}")
        yield model_context + i, pipeline #fix in between, do all the stuff in the end?
    generation = i[:] # the last yielded chunk is the cumulative output, so this holds the full generation
    print(f"{generation=}")
    ctx_with_generation = model_context + generation
    print(f"{ctx_with_generation=}")
    try:
        #strip the body
        first_gened_func = _parse_functions(ctx_with_generation)[0] # truncate generation to a single function?
    except IndexError:
        print("generation wasn't a full function.")
        altered_code = old_code[:func_start_idx] + model_context + generation + "//the generation didn't complete the function!\n" + old_code[body_end_idx:] #needs a newline to break out of the comment.
        return altered_code, pipeline
        # raise gr.Error(f"didn't generate a full function: {generation!r}]")
    print(f"{first_gened_func=}")
    generated_body = first_gened_func.child_by_field_name("body").text.decode()
    print(f"{generated_body=}")
    altered_code = old_code[:func_start_idx] + identifier_str + generated_body + old_code[body_end_idx:]
    print(f"{altered_code=}") #we get here successfully
    yield altered_code, pipeline #yield once so it updates? -> works... gg but doesn't seem to do it for the dropdown
    return altered_code, pipeline #never gets used by the code block? maybe I need to yield it first? but works in the ov_notebook

def add_history(func_id, orig_rtn, gened_rtn, history):
    # is this a list? or a JSON dict?
    history[func_id] =  (orig_rtn, gened_rtn)
    return history, history

def list_dropdown(in_code): #only used for auto update, not on sample pick?
    funcs = _parse_functions(in_code)
    func_identifiers = [f"{idx:2d}: {n.child_by_field_name('declarator').text.decode()}" for idx, n in enumerate(funcs)]
    # funcs = [n for n in funcs] #wrapped as set to avoid json issues?
    print(f"updating drop down to:{func_identifiers}")
    return funcs, gr.Dropdown.update(choices=func_identifiers)

def construct_embed(source_url):
    shader_id = source_url.split("/")[-1]
    return f'<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/{shader_id}?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>'
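# e.g. construct_embed("https://www.shadertoy.com/view/abc123") embeds shader id "abc123"
# (a made-up id; the last URL segment is used as the Shadertoy embed id)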

with gr.Blocks() as site:
    top_md = gr.Markdown(intro_text)
    model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
    sample_idx = gr.Slider(minimum=0, maximum=num_samples - 1, value=3211, label="pick sample from dataset", step=1.0) # -1 keeps the index in range of the dataset
    func_dropdown = gr.Dropdown(value=["0: edit the Code (or load a shader) to update this dropdown"], label="choose a function to modify") #breaks if I add a string in before that? #TODO: use type="index" to get int - always gives None?
    with gr.Accordion("Advanced settings", open=False): # from: https://huggingface.co/spaces/bigcode/bigcode-playground/blob/main/app.py
        with gr.Row():
            column_1, column_2 = gr.Column(), gr.Column()
            with column_1:
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.0, #start out at 0 to do greedy? or will there be an error?
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=160,
                    minimum=0,
                    maximum=2048, #this could be inferred from the model?
                    step=32,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
            with column_2:
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.85,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition penalty",
                    value=1.2,
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    interactive=True,
                    info="Penalize repeated tokens",
                )
    with gr.Row():
        gen_return_button = gr.Button("generate an alternate return statement", label="generate return")
        gen_func_button = gr.Button("generate an alternate function body", label="generate function")
    with gr.Row():
        with gr.Column():
            source_embed = gr.HTML('<iframe width="640" height="360" frameborder="0" src="" allowfullscreen></iframe>', label="How this shader originally renders")
            our_embed = gr.HTML(label="glsl render of the current code")
        sample_code = gr.Code(new_shadertoy_code, label="Current Code (will update changes you generate)", language=None)
    bot_md = gr.Markdown(outro_text)
    sample_pass = gr.State(value={})
    pipe = gr.State(value=PIPE)
    pipe.value = _make_pipeline("Vipitis/santacoder-finetuned-Shadertoys-fine") # set a default like this?
    funcs = gr.State(value=[])
    # funcs.value.append(list_dropdown(sample_code.value)[0]) #to circumvent the json issue?
    # hist_state = gr.State(Value={})
    # history_table = gr.JSON()
    
    model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe]) # how can we trigger this on load? (site.load(...) could be one option)
    sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, source_embed])
    gen_return_button.click(fn=alter_return, inputs=[sample_code, func_dropdown, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code]) # pass the generation settings so the inputs match alter_return's signature
    gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code, pipe]).then(
        fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]
    )
    sample_code.change(fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]).then(
        fn=make_iframe, inputs=[sample_code], outputs=[our_embed])
site.queue()
site.launch()