Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import gradio as gr | |
| from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| import datasets | |
| import asyncio | |
| import numpy as np | |
| import torch | |
| from threading import Thread | |
def make_script(shader_code):
    """
    Builds a standalone HTML page that renders `shader_code` as a Shadertoy-style shader with WebGL.

    `shader_code` is spliced verbatim into the fragment shader source, after the declarations of the
    uniforms iResolution, iMouse and iTime. The generated GLSL `main()` calls
    `mainImage(gl_FragColor, gl_FragCoord.xy)`, so the shader code must define `mainImage`.
    The page loads webgl-utils.js from webglfundamentals.org when it is viewed.
    The animation only advances while the pointer is over the canvas (mouseover/touchstart request
    frames, mouseout/touchend cancel them). On load it requests one frame and then cancels the follow-up.

    NOTE(review): the shader text ends up inside a JavaScript template literal (`const fs = \`...\``),
    so a backtick or "${" inside `shader_code` would break the page.

    Args:
        shader_code (str): GLSL source defining mainImage(out vec4, in vec2).
    Returns:
        str: the complete HTML document.
    """
    # code copied and fixed(escaping single quotes to double quotes!!!) from https://webglfundamentals.org/webgl/webgl-shadertoy.html
    # The template avoids single quotes so the result could sit inside a single-quoted srcdoc attribute.
    script = ("""
<!-- Licensed under a BSD license. See license.html for license -->
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
<title>WebGL - Shadertoy</title>
<link type="text/css" href="https://webglfundamentals.org/webgl/resources/webgl-tutorials.css" rel="stylesheet" />
<style>
.divcanvas {
position: relative;
display: inline-block;
}
canvas {
display: block;
}
.playpause {
position: absolute;
left: 10px;
top: 10px;
width: 100%;
height: 100%;
font-size: 60px;
justify-content: center;
align-items: center;
color: rgba(255, 255, 255, 0.3);
transition: opacity 0.2s ease-in-out;
}
.playpausehide,
.playpause:hover {
opacity: 0;
}
.iframe .divcanvas {
display: block;
}
</style>
</head>
<body>
<div class="divcanvas">
<canvas id="canvas"></canvas>
<div class="playpause">▶</div>
</div>
\nblank canvas here indicates that some of the shadertoy specific functions are not yet supported with this implementation (like #define I believe). you can always copy and paste the code into a shadertoy.com window to try.
</body>
<!--
for most samples webgl-utils only provides shader compiling/linking and
canvas resizing because why clutter the examples with code thats the same in every sample.
See https://webglfundamentals.org/webgl/lessons/webgl-boilerplate.html
and https://webglfundamentals.org/webgl/lessons/webgl-resizing-the-canvas.html
for webgl-utils, m3, m4, and webgl-lessons-ui.
-->
<script src="https://webglfundamentals.org/webgl/resources/webgl-utils.js"></script>
<script>
"use strict";
function main() {
// Get A WebGL context
/** @type {HTMLCanvasElement} */
const canvas = document.querySelector("#canvas");
const gl = canvas.getContext("webgl");
if (!gl) {
return;
}
const vs = `
// an attribute will receive data from a buffer
attribute vec4 a_position;
// all shaders have a main function
void main() {
// gl_Position is a special variable a vertex shader
// is responsible for setting
gl_Position = a_position;
}
`;
const fs = `
precision highp float;
uniform vec2 iResolution;
uniform vec2 iMouse;
uniform float iTime;
""" + shader_code + """
void main() {
mainImage(gl_FragColor, gl_FragCoord.xy);
}
`;
// setup GLSL program
const program = webglUtils.createProgramFromSources(gl, [vs, fs]);
// look up where the vertex data needs to go.
const positionAttributeLocation = gl.getAttribLocation(program, "a_position");
// look up uniform locations
const resolutionLocation = gl.getUniformLocation(program, "iResolution");
const mouseLocation = gl.getUniformLocation(program, "iMouse");
const timeLocation = gl.getUniformLocation(program, "iTime");
// Create a buffer to put three 2d clip space points in
const positionBuffer = gl.createBuffer();
// Bind it to ARRAY_BUFFER (think of it as ARRAY_BUFFER = positionBuffer)
gl.bindBuffer(gl.ARRAY_BUFFER, positionBuffer);
// fill it with a 2 triangles that cover clipspace
gl.bufferData(gl.ARRAY_BUFFER, new Float32Array([
-1, -1, // first triangle
1, -1,
-1, 1,
-1, 1, // second triangle
1, -1,
1, 1,
]), gl.STATIC_DRAW);
const playpauseElem = document.querySelector(".playpause");
const inputElem = document.querySelector(".divcanvas");
inputElem.addEventListener("mouseover", requestFrame);
inputElem.addEventListener("mouseout", cancelFrame);
let mouseX = 0;
let mouseY = 0;
function setMousePosition(e) {
const rect = inputElem.getBoundingClientRect();
mouseX = e.clientX - rect.left;
mouseY = rect.height - (e.clientY - rect.top) - 1; // bottom is 0 in WebGL
}
inputElem.addEventListener("mousemove", setMousePosition);
inputElem.addEventListener("touchstart", (e) => {
e.preventDefault();
playpauseElem.classList.add("playpausehide");
requestFrame();
}, {passive: false});
inputElem.addEventListener("touchmove", (e) => {
e.preventDefault();
setMousePosition(e.touches[0]);
}, {passive: false});
inputElem.addEventListener("touchend", (e) => {
e.preventDefault();
playpauseElem.classList.remove("playpausehide");
cancelFrame();
}, {passive: false});
let requestId;
function requestFrame() {
if (!requestId) {
requestId = requestAnimationFrame(render);
}
}
function cancelFrame() {
if (requestId) {
cancelAnimationFrame(requestId);
requestId = undefined;
}
}
let then = 0;
let time = 0;
function render(now) {
requestId = undefined;
now *= 0.001; // convert to seconds
const elapsedTime = Math.min(now - then, 0.1);
time += elapsedTime;
then = now;
webglUtils.resizeCanvasToDisplaySize(gl.canvas);
// Tell WebGL how to convert from clip space to pixels
gl.viewport(0, 0, gl.canvas.width, gl.canvas.height);
// Tell it to use our program (pair of shaders)
gl.useProgram(program);
// Turn on the attribute
gl.enableVertexAttribArray(positionAttributeLocation);
// Bind the position buffer.
gl.bindBuffer(gl.ARRAY_BUFFER, positionBuffer);
// Tell the attribute how to get data out of positionBuffer (ARRAY_BUFFER)
gl.vertexAttribPointer(
positionAttributeLocation,
2, // 2 components per iteration
gl.FLOAT, // the data is 32bit floats
false, // dont normalize the data
0, // 0 = move forward size * sizeof(type) each iteration to get the next position
0, // start at the beginning of the buffer
);
gl.uniform2f(resolutionLocation, gl.canvas.width, gl.canvas.height);
gl.uniform2f(mouseLocation, mouseX, mouseY);
gl.uniform1f(timeLocation, time);
gl.drawArrays(
gl.TRIANGLES,
0, // offset
6, // num vertices to process
);
requestFrame();
}
requestFrame();
requestAnimationFrame(cancelFrame);
}
main();
</script>
</html>
""")
    # Plain HTML document; make_iframe wraps it for embedding in the Gradio page.
    return script
def make_iframe(shader_code): #keep a single function?
    """
    Wraps the WebGL page built by make_script(shader_code) in an <iframe> via its `srcdoc` attribute.

    The whole document is HTML-escaped (quotes included) before it is placed in the attribute.
    Previously it was inserted raw into a single-quoted attribute, so any `'` in user shader code
    (e.g. a comment like "don't") ended the attribute early, and `&`/`<` could be misparsed.
    The browser un-escapes the attribute value back to the original markup before rendering it.

    Args:
        shader_code (str): GLSL source defining mainImage(out vec4, in vec2).
    Returns:
        str: an <iframe> HTML snippet.
    """
    import html  # local import: only this helper needs attribute escaping

    script = make_script(shader_code)
    escaped_script = html.escape(script, quote=True)
    return f"""<iframe width="640" height="420" srcdoc="{escaped_script}" allowfullscreen></iframe>"""
# Markdown shown above the controls (intro_text) and below the editor (outro_text),
# plus the starter shader shown before any sample is loaded.
intro_text = """
# Welcome to the interactive shadercoding demo.
This gives you access to a filtered version of the [Shadertoys](https://huggingface.co/datasets/Vipitis/Shadertoys) dataset; only shaders that consist of a single pass are available.
It then lets you use code generation models to make alterations to parts of the shadercode.
## How To Use:
1. Load any Model for [`text-generation`](https://huggingface.co/models?pipeline_tag=text-generation) and hit ENTER.
2. Use the slider to sample a shader from the dataset.
    - The original shader will be embedded on the left; click on the title to get to the source.
    - The shadercode will be displayed on the right; it is interactive.
    - A preview of the currently displayed shadercode will be displayed on the lower left. (hover to advance time)
3. Use the dropdown to select a function to modify.
4. Press either button to make modifications to that function.
5. You can also edit the code manually.
"""
outro_text ="""
## Models to try (look at [ShaderEval](https://huggingface.co/spaces/Vipitis/ShaderEval) for an indication of how helpful they will be):
- [gpt2](https://huggingface.co/gpt2) baseline for language models, really struggles with shadercode.
- [bigscience/bloom-1b1](https://huggingface.co/bigscience/bloom-1b1) a newer and larger freely available model. Does understand a bit of code.
- [codeparrot/codeparrot-small](https://huggingface.co/codeparrot/codeparrot-small) a model trained on code, but not on shadercode. Manages to grasp the patterns.
- [salesforce/codegen-2B-multi](https://huggingface.co/salesforce/codegen-2B-multi) a larger model that indicates some potential.
- [bigcode/santacoder](https://huggingface.co/bigcode/santacoder) a model trained on a subset of [TheStack](https://huggingface.co/datasets/bigcode/the-stack), struggles with shadercode.
- [Vipitis/santacoder-finetuned-the-stack-glsl](https://huggingface.co/Vipitis/santacoder-finetuned-the-stack-glsl) fine-tuned by me on the glsl subset of [TheStack](https://huggingface.co/datasets/bigcode/the-stack), is an improvement.
- [Vipitis/santacoder-finetuned-Shadertoys](https://huggingface.co/Vipitis/santacoder-finetuned-Shadertoys) fine-tuned by me on whole shaders from [Shadertoys](https://huggingface.co/datasets/Vipitis/Shadertoys). Does overfit quite a bit with greedy decoding.
- [Vipitis/santacoder-finetuned-Shadertoys-fine](https://huggingface.co/Vipitis/santacoder-finetuned-Shadertoys-fine) fine-tuned by me on just the functions from [Shadertoys-fine](https://huggingface.co/datasets/Vipitis/Shadertoys-fine). Memorizes the exact function about half the time.
- [bigcode/starcoder](https://huggingface.co/bigcode/starcoder) a very large model which I haven't tried yet.
- **any other model you want to**
## TODO (feel free to contribute with a [Pull-Request](https://huggingface.co/Vipitis/santacoder-finetuned-the-stack-glsl/discussions?status=open&type=pull_request)):
- [x] use embedded Shadertoy for reference/attribution (done, but some errors)
- [~] working render implementation on CPU only space (as webgl via webglfundamentals, css needs fixing for iframe (or hijack Shadertoy iframe))
- [~] generate variations of return statements [ShaderEval task1](https://huggingface.co/spaces/Vipitis/ShaderEval) (needs to be reworked using the other parts)
- [x] generate whole functions (seems to work quite well)
- [] dropdown for model selection (from curated list or all supported models?)
- [] generation history stating which function and orig/generated returns. (use State ??). do it as comments in the code?
- [~] display errors/issues to the user (raise gr.Error could be one idea, but highlighting in the code would be awesome) currently adds a comment to the code.
- [~] generate whole shaders (via prompts guidance, recursive from errors) - prompt context is in progress.
- [x] accordion with generation parameters (as pipeline_kwargs?) look up starcoder playground and take "inspiration" from there (implemented for both buttons, untested)
- [] support FIM task for better model context
- [x] include some context for prompt (title, comments before a functions) - now takes all comments directly before a function as well as all comments at the beginning inside a function. (misses comments between argument list and body)
- [] gradio examples
- [] use GPU if available, respect memory restrictions.
- [x] stream model generation (maybe in a new window?) - janky solution and only sometimes hangs up
- [] 2nd iFrame needs a lot of fixing (I am not a web developer, need help) BUG:background is white, so colors are wrong. Shadertoy uses black background (or we ignore alpha).
- [] (optional) filtering the dataset by license?
### Notes:
- this is meant as a resource to show code generation for a "creative" task.
- the goal is not to replace shader artists, but to be an assistant instead.
- the space still lacks quite a lot of features, but will continue to evolve.
- this demo can be useful to sanity check evaluation results, where the academic numbers are made.
- If you create a remix with these tools, please attribute the original creator of your starting point when sharing the results. (And perhaps share in the [discussion tab](https://huggingface.co/Vipitis/santacoder-finetuned-the-stack-glsl/discussions?status=open&type=discussion) too)
"""
new_shadertoy_code = """void mainImage( out vec4 fragColor, in vec2 fragCoord )
{
// touch the slider to load a shader from the dataset or start coding from here.
vec2 uv = fragCoord/iResolution.xy;
vec3 col = 0.5 + 0.5*cos(iTime+uv.xyx+vec3(0,2,4));
fragColor = vec4(col,1.0);
}"""
# Module-level data: load the Shadertoys dataset and keep only shaders without inputs that consist of
# a single pass. The train and test splits are merged into one pool that the UI slider indexes into.
passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
single_passes = passes_dataset.filter(lambda x: not x["has_inputs"] and x["num_passes"] == 1) #could also include shaders with no extra functions.
# single_passes = single_passes.filter(lambda x: x["license"] not in "copyright") #to avoid any "do not display this" license?
all_single_passes = datasets.concatenate_datasets([single_passes["train"], single_passes["test"]])
# Size of the merged pool; valid sample indices are 0 .. num_samples - 1.
num_samples = len(all_single_passes)
import tree_sitter
from tree_sitter import Language, Parser
# Build and load the GLSL grammar. This presumably compiles a `tree-sitter-glsl` grammar checkout located
# next to the app into ./build/my-languages.so on every start-up (TODO confirm the checkout is deployed).
# NOTE(review): Language.build_library and Language(path, name) belong to older py-tree-sitter releases
# (removed in 0.22, as far as I know) — verify the pinned tree_sitter version if start-up fails here.
Language.build_library("./build/my-languages.so", ['tree-sitter-glsl'])
GLSL_LANGUAGE = Language('./build/my-languages.so', 'glsl')
# Shared parser instance used by _parse_functions.
parser = Parser()
parser.set_language(GLSL_LANGUAGE)
def grab_sample(sample_idx):
    """
    Loads one shader from the filtered dataset pool for the UI.

    Args:
        sample_idx (int | float): index into all_single_passes (the slider may deliver a float).
    Returns:
        tuple: (sample_pass, sample_code, sample_title, source_iframe, funcs), where
            sample_pass is the dataset row,
            sample_code its GLSL code,
            sample_title its title (used as the prompt text),
            source_iframe a Shadertoy embed of the original,
            funcs the top-level function nodes parsed from sample_code.
    """
    sample_pass = all_single_passes[int(sample_idx)]  # dataset rows are indexed by int
    sample_code = sample_pass["code"]
    sample_source = sample_pass["source"]
    sample_title = sample_pass["title"]
    source_iframe = construct_embed(sample_source)
    print(f"{source_iframe=}")
    # `funcs` used to be returned without ever being assigned (its computation was commented out),
    # so every call ended in NameError.
    funcs = _parse_functions(sample_code)
    return sample_pass, sample_code, sample_title, source_iframe, funcs
def _parse_functions(in_code):
    """
    Parses `in_code` with the module-level GLSL `parser` and returns its top-level function nodes.

    Only direct children of the syntax-tree root whose node type is "function_definition" are
    returned, in source order. The nodes are tree-sitter nodes of that parse; their byte/point
    positions refer to the UTF-8 encoding of `in_code`. No comments are collected here
    (_grab_before_comments and _get_docstrings do that for a given node).
    """
    tree = parser.parse(bytes(in_code, "utf8"))
    return [node for node in tree.root_node.children if node.type == "function_definition"]
PIPE = None  # most recently loaded text-generation pipeline; updated by _make_pipeline
def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #bad default model for testing
    """
    Loads `model_cp` into a transformers text-generation pipeline and records it in PIPE.

    Both tokenizer and model are loaded with trust_remote_code=True.
    The `pipeline=PIPE` defaults of alter_return/alter_body were bound when those functions were
    defined, so they do not see this update; the UI passes the pipeline explicitly via gr.State.

    Args:
        model_cp (str): Hugging Face model id or local checkpoint path.
    Returns:
        Pipeline: the loaded pipeline.
    """
    # Without this declaration the assignment below only bound a local name and PIPE stayed None.
    global PIPE
    # TODO: pass device="cuda" when torch.cuda.is_available() and memory allows.
    tokenizer = AutoTokenizer.from_pretrained(model_cp, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_cp, trust_remote_code=True)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True)
    PIPE = pipe
    print(f"loaded model {model_cp} as a pipline")
    return pipe
def _run_generation(model_ctx:str, pipe, gen_kwargs:dict):
    """
    Streams text generated by `pipe.model` from `model_ctx`, yielding the accumulated output.

    Args:
        model_ctx (str): The context to start generation from.
        pipe (Pipeline): The pipeline whose tokenizer and model are used.
        gen_kwargs (dict): Extra keyword arguments forwarded to `model.generate`.
    Yields:
        str: all text generated so far (the prompt is excluded via skip_prompt=True).
    Returns:
        str: the complete generated text, as the generator's return (StopIteration) value.
    """
    model_inputs = pipe.tokenizer(model_ctx, return_tensors="pt")
    # generate() runs on a worker thread and feeds decoded text into the streamer, which this
    # generator drains so the UI is not blocked. The 15 s timeout keeps the loop from waiting
    # forever if the worker thread fails.
    streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15.0)
    generate_kwargs = dict(model_inputs, streamer=streamer, **gen_kwargs)
    worker = Thread(target=pipe.model.generate, kwargs=generate_kwargs)
    worker.start()

    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output

    # The stream only ends once generate() has signalled completion, so this join returns promptly.
    # It replaces the former streamer.on_finalized_text(...) call, which merely pushed an unread
    # string into the already-drained queue.
    worker.join()
    return model_output
def process_retn(retn):
    """
    Trims a generated return expression down to its first statement.

    Everything from the first ";" onward is discarded and the remainder is whitespace-stripped.
    Text without any ";" is only stripped.
    """
    expression, _separator, _rest = retn.partition(";")
    return expression.strip()
def get_full_replacement(orig_code, retn_start_idx, retn_end_idx, prediction) -> str:
    """
    Splices a generated return expression into the code and returns the full altered code.

    The characters orig_code[retn_start_idx:retn_end_idx] are replaced by the first statement of
    `prediction` (see process_retn); everything outside that span is kept unchanged.
    """
    print(f"{orig_code[retn_start_idx:retn_end_idx]=}")
    generated = process_retn(prediction)
    print(f"{generated=}")
    head = orig_code[:retn_start_idx]
    tail = orig_code[retn_end_idx:]
    return "".join((head, generated, tail))
def alter_return(orig_code, func_idx, temperature, max_new_tokens, top_p, repetition_penalty, pipeline=PIPE): #default pipeline can't be passed as gloabl?
    """
    Replaces the expression of one return statement with a generated one.

    Args:
        orig_code (str): The original code.
        func_idx (int | str): Index of the return statement to replace, counted over ALL `return`
            statements in `orig_code` in order of appearance (not per function). A dropdown label
            of the form "NN: name" is parsed to its number. Out-of-range values are clamped.
        temperature (float): The temperature to use for generation.
        max_new_tokens (int): The maximum number of tokens to generate.
        top_p (float): The top_p to use for generation.
        repetition_penalty (float): The repetition_penalty to use for generation.
        pipeline (Pipeline): The pipeline to use; the default one is loaded when None.
    Returns:
        str: The altered code, or `orig_code` unchanged when it contains no `return ...;`.
    Raises:
        gr.Error: if func_idx is neither int nor str.
    """
    import re  # local import: only needed for the keyword search below

    if pipeline is None:
        print("no pipeline found, loading default one")
        pipeline = _make_pipeline()
    if isinstance(func_idx, str):
        print(f"{func_idx=}")
        func_idx = int(func_idx.split(":")[0].strip())
    elif not isinstance(func_idx, int):
        raise gr.Error(f"func_idx must be int or str, not {type(func_idx)}")
    generation_kwargs = _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)

    # Collect (start of "return", index of its terminating ";") spans.
    # The \b anchors keep identifiers such as `returnColor` from matching. A `return` without a
    # following ";" ends the scan instead of recording an end index of -1 (which made the splice
    # below keep only the final character of the code).
    retrns = []
    for match in re.finditer(r"\breturn\b", orig_code):
        if retrns and match.start() < retrns[-1][1]:
            continue  # still inside the previous statement
        retrn_end_idx = orig_code.find(";", match.start())
        if retrn_end_idx == -1:
            break
        retrns.append((match.start(), retrn_end_idx))
    num_returns = len(retrns)
    if num_returns == 0:
        print("no return statement found, returning original code")
        return orig_code
    func_idx = int(max(0, min(func_idx, num_returns - 1)))  # clamp to valid range
    retrn_start_idx, retrn_end_idx = retrns[func_idx]

    model_context = orig_code[:retrn_start_idx]  # TODO: maximal context?
    model_inp = model_context + "return"
    pipe_generation = pipeline(model_inp, return_full_text=False, **generation_kwargs)[0]["generated_text"]

    # Replace what follows the keyword (after any whitespace) up to the ";".
    # The previous fixed offset of +7 assumed exactly one separator character: it swallowed the "("
    # of `return(x);` and duplicated the ";" of a bare `return;`.
    expr_start_idx = retrn_start_idx + len("return")
    while expr_start_idx < retrn_end_idx and orig_code[expr_start_idx].isspace():
        expr_start_idx += 1
    altered_code = get_full_replacement(orig_code, expr_start_idx, retrn_end_idx, pipe_generation)
    return altered_code
| def _line_chr2char(text, line_idx, chr_idx): | |
| """ | |
| returns the character index at the given line and character index. | |
| """ | |
| lines = text.split("\n") | |
| char_idx = 0 | |
| for i in range(line_idx): | |
| char_idx += len(lines[i]) + 1 | |
| char_idx += chr_idx | |
| return char_idx | |
| def _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty): | |
| gen_kwargs = {} | |
| gen_kwargs["temperature"] = temperature | |
| gen_kwargs["max_new_tokens"] = max_new_tokens | |
| gen_kwargs["top_p"] = top_p | |
| gen_kwargs["repetition_penalty"] = repetition_penalty | |
| return gen_kwargs | |
| def _grab_before_comments(func_node): | |
| """ | |
| returns the comments that happen just before a function node | |
| """ | |
| precomment = "" | |
| last_comment_line = 0 | |
| for node in func_node.parent.children: #could you optimize where to iterated from? directon? | |
| if node.start_point[0] != last_comment_line + 1: | |
| precomment = "" | |
| if node.type == "comment": | |
| precomment += node.text.decode() + "\n" | |
| last_comment_line = node.start_point[0] | |
| elif node == func_node: | |
| return precomment | |
| return precomment | |
| def _get_docstrings(func_node): | |
| """ | |
| returns the docstring of a function node | |
| """ | |
| docstring = "" | |
| for node in func_node.child_by_field_name("body").children: | |
| if node.type == "comment" or node.type == "{": | |
| docstring += node.text.decode() + "\n" | |
| else: | |
| return docstring | |
| return docstring | |
def _char_idx_from_byte(text, byte_idx):
    """
    Converts a UTF-8 byte offset into `text` to the matching Python str index.

    tree-sitter reports positions in bytes of the parsed UTF-8 buffer, so line/column pairs only
    equal character positions for pure-ASCII code.
    """
    return len(text.encode("utf8")[:byte_idx].decode("utf8", errors="ignore"))


def _build_body_context(func_node, funcs_list, prompt, identifier_str):
    """
    Assembles the model prompt used to generate a new body for `func_node`.

    From top to bottom:
      - a language hint;
      - title and available-functions lines (only if `prompt` is non-empty);
      - the comments directly before the function;
      - the signature `identifier_str`;
      - the body's opening "{" plus its leading comments.
    """
    model_context = identifier_str  # base case
    docstring = _get_docstrings(func_node)  # "{" plus comments at the start of the body
    if docstring:
        model_context = model_context + "\n" + docstring
    model_context = _grab_before_comments(func_node) + model_context  # prepend comments
    if prompt != "":
        available = ",".join([n.child_by_field_name('declarator').text.decode() for n in funcs_list])
        model_context = f"//available functions: {available}\n" + model_context  # prepend available functions
        model_context = "//Title: " + prompt + "\n" + model_context  # prepend user prompt/title
    model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context  # prepend language hint
    return model_context


def alter_body(old_code, func_id, funcs_list: list, prompt, temperature, max_new_tokens, top_p, repetition_penalty, pipeline=PIPE):
    """
    Replaces the body of a function with a generated one, streaming progress to the UI.

    Args:
        old_code (str): The current code; assumed to be the text `funcs_list` was parsed from.
        func_id (int | str): Index into `funcs_list`, or a dropdown label "NN: name".
        funcs_list (list): Function nodes as returned by _parse_functions(old_code).
        prompt (str): The prompt (title) to use for generation; "" skips the title/function hints.
        temperature (float): The temperature to use for generation.
        max_new_tokens (int): The maximum number of tokens to generate.
        top_p (float): The top_p to use for generation.
        repetition_penalty (float): The repetition_penalty to use for generation.
        pipeline (Pipeline): The pipeline to use; the default one is loaded when None.
    Yields:
        tuple[str, Pipeline]: the text to show plus the pipeline used (to update the gr.State).
            While streaming, the text is the prompt followed by the partial generation.
            Finally it is the altered code: `old_code` with the function's body swapped for the
            first generated body. If the generation did not parse as a complete function, the
            function is instead replaced by prompt + raw generation plus a warning comment.
    Raises:
        gr.Error: if `func_id` has an unsupported type or does not index `funcs_list`.
    """
    if isinstance(func_id, str):
        print(f"{func_id=}")
        func_id = int(func_id.split(":")[0].strip())  # undo the dropdown's "NN: name" formatting
    elif not isinstance(func_id, int):
        raise gr.Error(f"func_id must be int or str, not {type(func_id)}")
    try:
        func_node = funcs_list[func_id]
    except IndexError:
        # e.g. the placeholder dropdown entry before any code has been parsed
        raise gr.Error(f"no function with index {func_id}; edit the code or load a shader to refresh the dropdown.") from None
    print(f"using for generation: {func_node=}")
    generation_kwargs = _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
    print(f"{pipeline=}")  # check if default even loaded
    if pipeline is None:
        print("no pipeline found, loading default one")
        pipeline = _make_pipeline("Vipitis/santacoder-finetuned-Shadertoys-fine")

    # tree-sitter positions are UTF-8 byte offsets; converting them keeps non-ASCII characters
    # (e.g. in comments) from shifting the splice points.
    func_start_idx = _char_idx_from_byte(old_code, func_node.start_byte)
    identifier_str = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode()
    body_node = func_node.child_by_field_name("body")
    body_start_idx = _char_idx_from_byte(old_code, body_node.start_byte)
    body_end_idx = _char_idx_from_byte(old_code, body_node.end_byte)
    print(f"{old_code[body_start_idx:body_end_idx]=}")

    model_context = _build_body_context(func_node, funcs_list, prompt, identifier_str)
    print(f"{model_context=}")

    generation = ""  # stays empty if the stream produces nothing (previously an unbound name)
    for partial in _run_generation(model_context, pipeline, generation_kwargs):
        generation = partial
        yield model_context + generation, pipeline  # live preview while streaming
    print(f"{generation=}")
    ctx_with_generation = model_context + generation
    print(f"{ctx_with_generation=}")

    try:
        first_gened_func = _parse_functions(ctx_with_generation)[0]  # only the first generated function is used
    except IndexError:
        print("generation wasn't a full function.")
        # trailing newline so the warning comment does not swallow the following code
        altered_code = old_code[:func_start_idx] + model_context + generation + "//the generation didn't complete the function!\n" + old_code[body_end_idx:]
        # Must be yielded: a generator's return value never reaches Gradio, so returning it used to
        # leave the UI showing the raw streamed text.
        yield altered_code, pipeline
        return
    print(f"{first_gened_func=}")
    generated_body = first_gened_func.child_by_field_name("body").text.decode()
    print(f"{generated_body=}")
    altered_code = old_code[:func_start_idx] + identifier_str + generated_body + old_code[body_end_idx:]
    print(f"{altered_code=}")
    yield altered_code, pipeline
def add_history(func_id, orig_rtn, gened_rtn, history):
    """
    Stores the (original, generated) return pair under `func_id` in `history`.

    `history` is mutated in place and then returned twice, e.g. to feed a State and a display
    component at once.
    """
    entry = (orig_rtn, gened_rtn)
    history[func_id] = entry
    return (history,) * 2
def list_dropdown(in_code): #only used for auto update, not on sample pick?
    """
    Re-parses `in_code` and refreshes the function dropdown.

    Returns:
        tuple: (function nodes from _parse_functions, gr.Dropdown update whose choices are labels
            "NN: declarator", with NN the node's position right-aligned to two characters).
    """
    parsed = _parse_functions(in_code)
    labels = []
    for position, node in enumerate(parsed):
        declarator = node.child_by_field_name('declarator').text.decode()
        labels.append(f"{position:2d}: {declarator}")
    print(f"updating drop down to:{labels}")
    return parsed, gr.Dropdown.update(choices=labels)
def construct_embed(source_url):
    """
    Builds a 640x360 Shadertoy embed iframe for the shader referenced by `source_url`.

    The shader id is the text after the last "/" of the URL. The embed starts paused and muted,
    at t=0, with the Shadertoy GUI shown.
    """
    shader_id = source_url.rsplit("/", 1)[-1]
    embed_src = f"https://www.shadertoy.com/embed/{shader_id}?gui=true&t=0&paused=true&muted=true"
    return f'<iframe width="640" height="360" frameborder="0" src="{embed_src}" allowfullscreen></iframe>'
# Gradio UI: model loader, dataset slider, function picker, generation settings and buttons,
# the two render previews, and the code editor. Event wiring is at the bottom.
with gr.Blocks() as site:
    top_md = gr.Markdown(intro_text)
    model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
    # Valid dataset indices are 0 .. num_samples - 1; a maximum of num_samples allowed an
    # out-of-range pick at the right end of the slider.
    sample_idx = gr.Slider(minimum=0, maximum=num_samples - 1, value=3211, label="pick sample from dataset", step=1.0)
    func_dropdown = gr.Dropdown(value=["0: edit the Code (or load a shader) to update this dropdown"], label="choose a function to modify") #breaks if I add a string in before that? #TODO: use type="index" to get int - always gives None?
    prompt_text = gr.Textbox(value="the title used by the model has generation hint", label="prompt text", info="leave blank to skip", interactive=True)
    with gr.Accordion("Advanced settings", open=False): # from: https://huggingface.co/spaces/bigcode/bigcode-playground/blob/main/app.py
        with gr.Row():
            column_1, column_2 = gr.Column(), gr.Column()
            with column_1:
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.2, #start out at 0 to do greedy? or will there be an error?
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=265,
                    minimum=0,
                    maximum=2048, #this could be inferred from the model?
                    step=32,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
            with column_2:
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.90,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition penalty",
                    value=1.2,
                    minimum=1.0,
                    maximum=2.0,
                    step=0.05,
                    interactive=True,
                    info="Penalize repeated tokens",
                )
    with gr.Row():
        gen_return_button = gr.Button("generate an alternate return statement", label="generate return", scale=0)
        gen_func_button = gr.Button("generate an alternate function body", label="generate function", scale=1)
    with gr.Row():
        with gr.Column():
            source_embed = gr.HTML('<iframe width="640" height="360" frameborder="0" src="" allowfullscreen></iframe>', label="How this shader originally renders")
            our_embed = gr.HTML(label="glsl render of the current code")
        sample_code = gr.Code(new_shadertoy_code, label="Current Code (will update changes you generate)", language=None)
    bot_md = gr.Markdown(outro_text)
    sample_pass = gr.State(value={})
    pipe = gr.State(value=PIPE)
    pipe.value=_make_pipeline("Vipitis/santacoder-finetuned-Shadertoys-fine") # set a default like this?
    funcs = gr.State(value=[])
    # funcs.value.append(list_dropdown(sample_code.value)[0]) #to circumvent the json issue?
    # hist_state = gr.State(Value={})
    # history_table = gr.JSON()

    model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe]) # how can we trigger this on load?
    # grab_sample returns five values; the fifth (parsed functions) goes into the funcs state.
    sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, prompt_text, source_embed, funcs])
    # alter_return takes the generation settings before the pipeline; passing only
    # [code, dropdown, pipe] made every click fail with a TypeError.
    gen_return_button.click(fn=alter_return, inputs=[sample_code, func_dropdown, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code])
    gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, prompt_text, temperature, max_new_tokens, top_p, repetition_penalty, pipe], outputs=[sample_code, pipe]).then(
        fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]
    )
    sample_code.change(fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]).then(
        fn=make_iframe, inputs=[sample_code], outputs=[our_embed])

if __name__ == "__main__": #works on huggingface?
    site.queue()
    site.launch()
 
			
