import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from wgpu.utils.shadertoy import * 
from wgpu.gui.offscreen import WgpuCanvas as OffscreenCanvas, run as run_offscreen
import wgpu
import time
import ctypes
import datasets
from PIL import Image
import asyncio
import numpy as np

# reimplement the Shadertoy class with offscreen canvas!
class ShadertoyCustom(Shadertoy):
    """Shadertoy renderer whose canvas class and event-loop runner are injectable.

    Passing ``wgpu.gui.offscreen.WgpuCanvas`` and its ``run`` (as ``get_image`` does)
    renders without opening a window. The shader source is wrapped with the
    ``*_glsl`` / ``*_wgsl`` builtin, vertex and fragment snippets brought in by the
    star import of ``wgpu.utils.shadertoy``.
    """

    def __init__(self, shader_code, resolution=(800, 450), canvas_class=WgpuCanvas, run_fn=run):
        """
        Args:
            shader_code (str): Shader source; its language is read from ``self.shader_type``.
            resolution (tuple): Canvas size; also stored in the ``resolution`` uniform
                with a trailing ``1`` appended.
            canvas_class (type): Canvas class instantiated by ``_prepare_render``.
            run_fn (callable): Event-loop runner invoked by ``show``.
        """
        # These must exist before super().__init__(), in case the base initializer
        # already calls the overridden _prepare_render.
        self._canvas_class = canvas_class
        self._run_fn = run_fn
        super().__init__(shader_code, resolution)
        # NOTE(review): the base initializer probably performs this same uniform setup
        # and render/event preparation too; if so, canvas and device are built twice.
        # Confirm against the installed wgpu version before removing either copy.
        self._uniform_data = UniformArray(
            ("mouse", "f", 4),
            ("resolution", "f", 3),
            ("time", "f", 1),
            ("time_delta", "f", 1),
            ("frame", "I", 1),
        )

        self._shader_code = shader_code
        self._uniform_data["resolution"] = resolution + (1,)

        self._prepare_render()
        self._bind_events()

    def _prepare_render(self):
        """Create the canvas, device, shader modules, uniform buffer and render pipeline.

        Raises:
            ValueError: if ``self.shader_type`` is neither ``"glsl"`` nor ``"wgsl"``.
        """
        import wgpu.backends.rs  # noqa

        self._canvas = self._canvas_class(title="Shadertoy", size=self.resolution, max_fps=60)

        adapter = wgpu.request_adapter(
            canvas=self._canvas, power_preference="high-performance"
        )
        self._device = adapter.request_device()

        self._present_context = self._canvas.get_context()

        # We use "bgra8unorm" not "bgra8unorm-srgb" here because we want to let the shader fully control the color-space.
        self._present_context.configure(
            device=self._device, format=wgpu.TextureFormat.bgra8unorm
        )

        shader_type = self.shader_type
        if shader_type == "glsl":
            vertex_shader_code = vertex_code_glsl
            frag_shader_code = (
                builtin_variables_glsl + self.shader_code + fragment_code_glsl
            )
        elif shader_type == "wgsl":
            vertex_shader_code = vertex_code_wgsl
            frag_shader_code = (
                builtin_variables_wgsl + self.shader_code + fragment_code_wgsl
            )
        else:
            # Without this branch the shader-code variables stay unbound and the
            # failure surfaces later as a confusing UnboundLocalError.
            raise ValueError(f"unsupported shader type: {shader_type!r}")

        vertex_shader_program = self._device.create_shader_module(
            label="triangle_vert", code=vertex_shader_code
        )
        frag_shader_program = self._device.create_shader_module(
            label="triangle_frag", code=frag_shader_code
        )

        self._uniform_buffer = self._device.create_buffer(
            size=self._uniform_data.nbytes,
            usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST,
        )

        bind_group_layout = self._device.create_bind_group_layout(
            entries=binding_layout
        )

        self._bind_group = self._device.create_bind_group(
            layout=bind_group_layout,
            entries=[
                {
                    "binding": 0,
                    "resource": {
                        "buffer": self._uniform_buffer,
                        "offset": 0,
                        "size": self._uniform_data.nbytes,
                    },
                },
            ],
        )

        self._render_pipeline = self._device.create_render_pipeline(
            layout=self._device.create_pipeline_layout(
                bind_group_layouts=[bind_group_layout]
            ),
            vertex={
                "module": vertex_shader_program,
                "entry_point": "main",
                "buffers": [],
            },
            primitive={
                "topology": wgpu.PrimitiveTopology.triangle_list,
                "front_face": wgpu.FrontFace.ccw,
                "cull_mode": wgpu.CullMode.none,
            },
            depth_stencil=None,
            multisample=None,
            fragment={
                "module": frag_shader_program,
                "entry_point": "main",
                "targets": [
                    {
                        # Must match the format configured on the present context above.
                        "format": wgpu.TextureFormat.bgra8unorm,
                        # one/zero/add: the fragment output replaces the target (no blending).
                        "blend": {
                            "color": (
                                wgpu.BlendFactor.one,
                                wgpu.BlendFactor.zero,
                                wgpu.BlendOperation.add,
                            ),
                            "alpha": (
                                wgpu.BlendFactor.one,
                                wgpu.BlendFactor.zero,
                                wgpu.BlendOperation.add,
                            ),
                        },
                    },
                ],
            },
        )

    def show(self, time: float = 0.0):
        """Request a frame draw and hand control to the configured runner.

        ``time`` is currently unused; the ``time`` uniform is not modified here.
        """
        self._canvas.request_draw(self._draw_frame)
        self._run_fn()

# Markdown rendered at the top of the Gradio page (see gr.Markdown(text) below).
text = """
# Welcome to the interactive shadercoding demo.
## (WIP) You can already explore the dataset a bit. (Frames are rendered on the fly; they are not part of the dataset (yet).)

This gives you access to a filtered version of the [Shadertoys](https://huggingface.co/datasets/Vipitis/Shadertoys) dataset: only shaders that consist of a single pass (and have at least one function with a return statement) are available.
In the near future there will be some buttons and sliders to generate variations of the shader code itself, and hence get some different images.
If I find an efficient way, the shaders might run in real time and be interactive.

## TODO:
 - [x] use embedded Shadertoy for reference/attribution (done, but some errors)
 - [] working render implementation on a CPU-only Space (use the browser for WebGPU? maybe via an iFrame too?)
 - [~] generate variations of return statements [ShaderEval task1](https://huggingface.co/spaces/Vipitis/ShaderEval) (missing all of the generation parameters)
 - [] generation history stating which function and orig/generated returns (use State??). Do it as comments in the code?
 - [x?] generate whole functions (only works once)
 - [] generate whole shaders (via prompts?)
"""
# Load the Shadertoys dataset and keep only rows that have no inputs and a single pass.
passes_dataset = datasets.load_dataset("Vipitis/Shadertoys")
single_passes = passes_dataset.filter(
    lambda row: not row["has_inputs"] and row["num_passes"] == 1
)  # could also include shaders with no extra functions.
# Merge both splits into one pool; grab_sample indexes rows of it by position.
all_single_passes = datasets.concatenate_datasets(
    [single_passes[split] for split in ("train", "test")]
)
num_samples = len(all_single_passes)

import tree_sitter
from tree_sitter import Language, Parser
# Build the GLSL grammar into a shared library and create the module-wide `parser`
# used by get_image and _parse_functions.
# NOTE(review): assumes the tree-sitter-glsl grammar sources are checked out at ./tree-sitter-glsl.
_GLSL_LIB_PATH = "./build/my-languages.so"
Language.build_library(_GLSL_LIB_PATH, ["tree-sitter-glsl"])
GLSL_LANGUAGE = Language(_GLSL_LIB_PATH, "glsl")
parser = Parser()
parser.set_language(GLSL_LANGUAGE)


async def get_image(code, time= 0.0, resolution=(512, 420)):
    """Render a single frame of a shader off-screen.

    Args:
        code (str): Shadertoy-style shader source.
        time (float): Value written to the shader's ``time`` uniform before drawing.
        resolution (tuple): Canvas size passed to ``ShadertoyCustom``
            (presumably (width, height), following Shadertoy convention).
    Returns:
        PIL.Image.Image | None: The frame as an RGB image, or None when tree-sitter
        reports a syntax error in ``code`` (in which case nothing is rendered).
    """
    tree = parser.parse(bytes(code, "utf8"))
    if tree.root_node.has_error:
        print("ERROR in the tree, aborting.")
        return None
    shader = ShadertoyCustom(code, resolution, OffscreenCanvas, run_offscreen) #pass offscreen canvas here.
    shader._uniform_data["time"] = time #set any time you want
    shader._canvas.request_draw(shader._draw_frame)
    frame = np.asarray(shader._canvas.draw())
    # ShadertoyCustom configures its canvas as bgra8unorm, so the read-back pixels are
    # ordered B, G, R, A. PIL would interpret them as RGBA and swap red and blue.
    # Reorder to R, G, B; the alpha channel is dropped, exactly as convert('RGB') did.
    rgb = np.ascontiguousarray(frame[..., [2, 1, 0]])
    return Image.fromarray(rgb)

def grab_sample(sample_idx):
    """Load one shader from the filtered dataset and build the UI outputs for it.

    Args:
        sample_idx (int | float): Row index into ``all_single_passes``; cast to int
            because the Gradio slider may deliver a float.
    Returns:
        tuple: (dataset row, shader code, Shadertoy embed HTML, list of function
        nodes, ``gr.Dropdown.update`` listing ``(index, declarator)`` choices).
    """
    sample_pass = all_single_passes[int(sample_idx)]
    sample_code = sample_pass["code"]
    source_iframe = construct_embed(sample_pass["source"])
    print(f"{source_iframe=}")
    # Parse once and reuse the shared dropdown builder instead of duplicating it here.
    funcs, dropdown_update = list_dropdown(sample_code)
    return sample_pass, sample_code, source_iframe, funcs, dropdown_update


def _parse_functions(in_code):
    """Parse ``in_code`` with the GLSL parser and return its function definitions.

    Only direct children of the syntax tree's root are inspected; the returned
    items are tree-sitter ``function_definition`` nodes in source order.
    """
    root = parser.parse(in_code.encode("utf8")).root_node
    return list(filter(lambda node: node.type == "function_definition", root.children))

# Most recently loaded generation pipeline, updated by _make_pipeline.
# NOTE: default arguments such as `pipeline=PIPE` are evaluated once at definition
# time (when this is still None), so updating PIPE does not change those defaults.
PIPE = None

def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #bad default model for testing
    """Load ``model_cp`` as a text-generation pipeline and remember it in ``PIPE``.

    Args:
        model_cp (str): Hugging Face model checkpoint name or path.
    Returns:
        transformers.Pipeline: The newly created pipeline.
    """
    # Without this declaration the assignment below only created a local variable.
    global PIPE
    tokenizer = AutoTokenizer.from_pretrained(model_cp, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_cp, trust_remote_code=True)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, trust_remote_code=True)
    PIPE = pipe
    print(f"loaded model {model_cp} as a pipline")
    return pipe


def process_retn(retn):
    """Return the part of ``retn`` before its first ``;``, with surrounding whitespace removed."""
    head, _sep, _rest = retn.partition(";")
    return head.strip()

def get_full_replacement(orig_code, retn_start_idx, retn_end_idx, prediction) -> str:
    """
    Splice a generated return expression into the code and return the full altered code.

    ``orig_code[retn_start_idx:retn_end_idx]`` is replaced by ``process_retn(prediction)``,
    i.e. the prediction's text up to its first ``;``, stripped. Both the replaced
    span and the inserted text are printed for debugging.
    """
    print(f"{orig_code[retn_start_idx:retn_end_idx]=}")
    generated = process_retn(prediction)
    print(f"{generated=}")
    head, tail = orig_code[:retn_start_idx], orig_code[retn_end_idx:]
    return "".join((head, generated, tail))

def alter_return(orig_code, func_idx=0, pipeline=PIPE): # NOTE: PIPE is bound at definition time (None); pass the pipeline explicitly.
    """
    Replaces the expression of one ``return`` statement with a generated one.

    Despite its name, ``func_idx`` selects a return statement, not a function: it
    indexes all ``return <expr>;`` statements of ``orig_code`` in source order
    (the UI currently feeds it the time slider value as a bodge). It is cast to
    int and clamped to the valid range.

    Args:
        orig_code (str): The original code.
        func_idx (int | float): Index of the return statement to replace.
        pipeline (Pipeline): The pipeline to use for generation; a default one is
            loaded when None.
    Returns:
        str: The altered code, or ``orig_code`` unchanged when it contains no
        return statement with an expression.
    """
    import re

    # Collect every `return <expr>;` as (index of "return", index of its ";").
    # \b prevents identifiers such as `returnColor` from matching. A bare `return;`
    # has no expression to replace, so it is skipped. This is still a textual scan:
    # a "return" inside a comment counts as well.
    keyword = re.compile(r"\breturn\b")
    retrns = []
    match = keyword.search(orig_code)
    while match is not None:
        retrn_start_idx = match.start()
        retrn_end_idx = orig_code.find(";", retrn_start_idx)
        if retrn_end_idx == -1:
            break  # unterminated statement; no later ";" exists either
        if orig_code[retrn_start_idx + len("return"):retrn_end_idx].strip():
            retrns.append((retrn_start_idx, retrn_end_idx))
        match = keyword.search(orig_code, retrn_end_idx)
    num_returns = len(retrns)
    if num_returns == 0:
        print("no return statement found, returning original code")
        return orig_code

    # Only load the (expensive) default model once we know there is work to do.
    if pipeline is None:
        print("no pipeline found, loading default one")
        pipeline = _make_pipeline()

    func_idx = int(max(0, min(func_idx, num_returns - 1))) #clamp to valid range, cast to int as a bodge.
    retrn_start_idx, retrn_end_idx = retrns[func_idx]
    model_context = orig_code[:retrn_start_idx] #TODO: maximal context?
    model_inp = model_context + "return"
    new_toks = (retrn_end_idx - retrn_start_idx) * 2 #TODO: approximation, we do have early stopping? maybe also use a number instead?
    pipe_generation = pipeline(model_inp, max_new_tokens=new_toks, return_full_text=False)[0]["generated_text"] #pipeline kwargs are missing?!
    # Replace from just after "return " (keyword plus one separator character) up to the ";".
    altered_code = get_full_replacement(orig_code, retrn_start_idx + len("return "), retrn_end_idx, pipe_generation)

    return altered_code

def _line_chr2char(text, line_idx, chr_idx):
    """
    returns the character index at the given line and character index.
    """
    lines = text.split("\n")
    char_idx = 0
    for i in range(line_idx):
        char_idx += len(lines[i]) + 1
    char_idx += chr_idx
    return char_idx

def alter_body(old_code, func_id, funcs_list, pipeline=PIPE):
    """
    Replaces the body of a function with a generated one.

    Args:
        old_code (str): The original code.
        func_id (str | int): Dropdown value selecting the function; its leading
            integer (before the first ",") indexes ``funcs_list``.
        funcs_list (list): tree-sitter ``function_definition`` nodes of ``old_code``,
            as returned by ``_parse_functions``.
        pipeline (Pipeline): The pipeline to use for generation; a default one is
            loaded when None.
    Returns:
        str: The altered code, or ``old_code`` unchanged when no function body
        can be recovered from the generation.
    """
    print(f"{func_id=}")
    # Dropdown choices are (index, name) pairs that come back stringified, e.g.
    # "0,mainImage(...)". Keep only the leading index. str() tolerates a plain int;
    # stripping "(" / "[" tolerates a "(0, ...)"-style rendering.
    func_id = int(str(func_id).split(",")[0].strip(" (["))
    func_node = funcs_list[func_id]
    print(f"using for generation: {func_node=}")

    if pipeline is None:
        print("no pipeline found, loading default one")
        pipeline = _make_pipeline("Vipitis/santacoder-finetuned-Shadertoys-fine")

    # tree-sitter points are (row, column); convert them to str offsets into old_code.
    func_start_idx = _line_chr2char(old_code, *func_node.start_point)
    body_node = func_node.child_by_field_name("body")
    body_start_idx = _line_chr2char(old_code, *body_node.start_point)
    body_end_idx = _line_chr2char(old_code, *body_node.end_point)
    print(f"{old_code[body_start_idx:body_end_idx]=}")
    model_context = old_code[:body_start_idx]
    generation = pipeline(model_context, max_new_tokens=(body_end_idx - body_start_idx)*2, return_full_text=False)[0]["generated_text"]
    print(f"{generation=}")
    # Re-parse "<original signature> + <generation>" so tree-sitter delimits the first
    # generated body; anything the model produced after it is discarded.
    gened_funcs = _parse_functions(old_code[func_start_idx:body_start_idx] + generation)
    gened_body_node = gened_funcs[0].child_by_field_name("body") if gened_funcs else None
    if gened_body_node is None:
        print("could not parse a function body from the generation, returning original code")
        return old_code
    # NOTE(review): a generation cut off by max_new_tokens may still parse (with a
    # MISSING "}" node), in which case this body text lacks its closing brace.
    generated_body = gened_body_node.text.decode()
    print(f"{generated_body=}")
    altered_code = old_code[:body_start_idx] + generated_body + old_code[body_end_idx:]
    return altered_code

def add_history(func_id, orig_rtn, gened_rtn, history):
    """Record the (original, generated) return pair for ``func_id`` in ``history``.

    ``history`` is mutated in place by item assignment, so a dict works, and so
    does a list with an existing index. The same object is returned twice,
    presumably one per Gradio output (see the commented-out history_table /
    hist_state wiring).
    """
    entry = (orig_rtn, gened_rtn)
    history[func_id] = entry
    return (history,) * 2

def list_dropdown(in_code):
    """Parse ``in_code`` and build the function-picker dropdown for it.

    Returns:
        tuple: (list of ``function_definition`` nodes, ``gr.Dropdown.update`` whose
        choices are ``(index, declarator text)`` pairs in source order).
    """
    funcs = _parse_functions(in_code)
    func_identifiers = []
    for position, node in enumerate(funcs):
        declarator = node.child_by_field_name("declarator")
        func_identifiers.append((position, declarator.text.decode()))
    print(f"updating drop down to:{func_identifiers}")
    return funcs, gr.Dropdown.update(choices=func_identifiers)

def construct_embed(source_url):
    """
    Build the HTML for an embedded Shadertoy player of ``source_url``.

    The shader id is the last path segment of the URL
    (``https://www.shadertoy.com/view/WsBcWV`` -> ``WsBcWV``). A query string,
    fragment or trailing slash is ignored. The id is percent-encoded so an
    unexpected dataset value cannot break out of the ``src`` attribute.

    Args:
        source_url (str): Shadertoy page URL (or bare shader id).
    Returns:
        str: An ``<iframe>`` element (GUI shown, t=0, paused, muted).
    """
    from urllib.parse import quote, urlsplit

    last_segment = urlsplit(source_url).path.rstrip("/").split("/")[-1]
    shader_id = quote(last_segment, safe="")
    return f'<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/{shader_id}?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>'

with gr.Blocks() as site:
    text_md = gr.Markdown(text)
    model_cp = gr.Textbox(value="Vipitis/santacoder-finetuned-Shadertoys-fine", label="Model Checkpoint (Enter to load!)", interactive=True)
    # Valid dataset indices are 0..num_samples-1. A maximum of num_samples let the
    # slider pick one row past the end, which crashed grab_sample.
    sample_idx = gr.Slider(minimum=0, maximum=num_samples - 1, value=3211, label="pick sample from dataset", step=1.0)
    func_dropdown = gr.Dropdown(label="choose a function to modify") #breaks if I add a string in before that?
    with gr.Row():
        gen_return_button = gr.Button("generate an alternate return statement", label="generate return")
        gen_func_button = gr.Button("generate an alternate function body", label="generate function")
        # update_funcs_button = gr.Button("update functions", label="update functions")
    render_button = gr.Button("render frame0 (can crash the space on invalid shadercode)", label="render frame")
    time_slider = gr.Slider(minimum=0, maximum=10, value=0, label="time (update on release, also used to pick other functions as a bodge)", step=0.02)
    with gr.Row():
        with gr.Column():
            source_embed = gr.HTML('<iframe width="640" height="360" frameborder="0" src="https://www.shadertoy.com/embed/WsBcWV?gui=true&t=0&paused=true&muted=true" allowfullscreen></iframe>', label="How this shader originally renders")
            rendered_frame = gr.Image(shape=(512, 420), label="rendered frame preview", type="pil") #colors are messed up?
        sample_code = gr.Code(label="Current Code (will update changes you generate)", language=None)

    # Session state: selected dataset row, loaded pipeline, parsed function nodes.
    sample_pass = gr.State(value={})
    pipe = gr.State(value=PIPE)
    funcs = gr.State(value=[])
    # hist_state = gr.State(Value={})
    # history_table = gr.JSON()

    # Event wiring.
    model_cp.submit(fn=_make_pipeline, inputs=[model_cp], outputs=[pipe])
    sample_idx.release(fn=grab_sample, inputs=[sample_idx], outputs=[sample_pass, sample_code, source_embed, funcs, func_dropdown])
    # sample_idx.release(fn=list_dropdown, inputs=[sample_code], outputs=[funcs, func_dropdown]) #use multiple event handles to call other functions! seems to not work really well. always messes up
    # The time slider doubles as the return-statement index for alter_return (bodge).
    gen_return_button.click(fn=alter_return, inputs=[sample_code, time_slider, pipe], outputs=[sample_code])
    gen_func_button.click(fn=alter_body, inputs=[sample_code, func_dropdown, funcs, pipe], outputs=[sample_code])
    # run_button.click(fn=add_history, inputs=[time_slider, sample_pass, sample_code, hist_state], outputs=[history_table, hist_state])
    # sample_idx.release(fn=construct_embed, inputs=[sample_idx], outputs=[source_embed]) #twice to make have different outputs?
    time_slider.release(fn=lambda code, time: asyncio.run(get_image(code, time)), inputs=[sample_code, time_slider], outputs=rendered_frame)
    render_button.click(fn=lambda code: asyncio.run(get_image(code)), inputs=[sample_code], outputs=rendered_frame)
    # run_button.click(fn=print, inputs=[model_cp, sample_idx], outputs=output)
site.launch()