<!doctype html>
<html lang="en">
<head>
<!-- Declare the encoding before any text content (the title) is parsed. -->
<meta charset="utf-8">
<title>Stablediffusion Infinity</title>
<!-- favicon.png is a PNG file, so it is advertised as image/png. -->
<link rel="icon" type="image/png" href="./favicon.png">
<!-- PyScript (alpha channel) stylesheet and runtime; the runtime executes the
     <py-env> and <py-script> elements found in the body. -->
<link rel="stylesheet" href="https://pyscript.net/alpha/pyscript.css">
<script defer src="https://pyscript.net/alpha/pyscript.js"></script>
<style>
/* Positioning context for the canvas layers; centred horizontally. */
#container {
  position: relative;
  margin: auto;
}
/* Every canvas layer is pinned to the container's top-left corner so the
   layers are drawn on top of each other. */
#container > canvas {
  position: absolute;
  top: 0;
  left: 0;
}
/* Elements with this class are never displayed; the page script reads and
   writes some of them by id (for example #mode and #setup). */
.control {
  display: none;
}
</style>

</head>
<body>
<div>
  <!-- Hidden controls: every element below carries class="control", which the
       page stylesheet sets to display:none. The page script attaches click
       listeners to the buttons and writes to #mode and #setup. -->
  <button class="control" id="export" type="button">Export</button>
  <button class="control" id="outpaint" type="button">Outpaint</button>
  <button class="control" id="undo" type="button">Undo</button>
  <button class="control" id="commit" type="button">Commit</button>
  <button class="control" id="transfer" type="button">Transfer</button>
  <button class="control" id="upload" type="button">Upload</button>
  <button class="control" id="draw" type="button">Draw</button>
  <input class="control" id="mode" type="text" value="✥">
  <input class="control" id="setup" type="text" value="0">
  <textarea class="control" id="selbuffer" name="selbuffer" rows="1"></textarea>
  <fieldset class="control">
    <div>
      <input id="mode0" name="mode" type="radio" value="0" checked>
      <label for="mode0">SelBox</label>
    </div>
    <div>
      <input id="mode1" name="mode" type="radio" value="1">
      <label for="mode1">Image</label>
    </div>
    <div>
      <input id="mode2" name="mode" type="radio" value="2">
      <label for="mode2">Brush</label>
    </div>
  </fieldset>
</div>
<div>
  <!-- Five canvas layers; the stylesheet positions each one absolutely at the
       container's top-left corner. The page script sets the container width. -->
  <div id="container">
    <canvas id="canvas0"></canvas>
    <canvas id="canvas1"></canvas>
    <canvas id="canvas2"></canvas>
    <canvas id="canvas3"></canvas>
    <canvas id="canvas4"></canvas>
  </div>
</div>
<!-- PyScript environment declaration: lists the numpy and Pillow packages
     and the local file ./canvas.py, which the script below imports as the
     `canvas` module. The list syntax is whitespace-sensitive; keep the
     indentation of the nested `paths` entry. -->
<py-env>
- numpy
- Pillow
- paths:
  - ./canvas.py
</py-env>

<py-script>
from pyodide import to_js, create_proxy
from PIL import Image
import io
import time
import base64
import numpy as np
from js import (
    console,
    document,
    parent,
    devicePixelRatio,
    ImageData,
    Uint8ClampedArray,
    CanvasRenderingContext2D as Context2d,
    requestAnimationFrame,
    window
)


from canvas import InfCanvas



# One-element holder for the active InfCanvas. Handlers read slot 0 to reach
# the current canvas; the draw functions overwrite it when a canvas is built.
base_lst = [None]
async def draw_canvas(width: int = 1500, height: int = 600) -> None:
    """Build the start-up canvas and publish it in ``base_lst[0]``.

    Args:
        width: Canvas width in pixels; also applied as the CSS width of
            the ``#container`` element. Defaults to 1500.
        height: Canvas height in pixels. Defaults to 600.

    The defaults reproduce the sizes that used to be hard-coded, so calling
    ``draw_canvas()`` without arguments behaves as before.
    """
    canvas = InfCanvas(width, height)
    document.querySelector("#container").style.width = f"{width}px"
    # Initial paint sequence, in the order the InfCanvas methods are invoked:
    # input wiring first, then background, buffer and selection box.
    canvas.setup_mouse()
    canvas.clear_background()
    canvas.draw_buffer()
    canvas.draw_selection_box()
    base_lst[0] = canvas

async def draw_canvas_func(event):
    """Rebuild the canvas using the sizes entered in the parent page.

    The width, height and selection size are read from the
    ``#canvas_width``, ``#canvas_height`` and ``#selection_size`` inputs
    inside the parent document's ``gradio-app`` element. If those inputs
    cannot be reached, or any of them does not hold an integer, all three
    values fall back to 1024 x 768 with a selection size of 384.

    Args:
        event: The triggering DOM event (unused).
    """
    try:
        app = parent.document.querySelector("gradio-app")
        # Parsing happens inside the guarded region so that a blank or
        # non-numeric field selects the defaults instead of raising later.
        width = int(app.querySelector("#canvas_width input").value)
        height = int(app.querySelector("#canvas_height input").value)
        selection_size = int(app.querySelector("#selection_size input").value)
    except Exception:
        # Only ordinary errors trigger the fallback; BaseException subclasses
        # such as asyncio.CancelledError are allowed to propagate.
        width = 1024
        height = 768
        selection_size = 384
    document.querySelector("#container").style.width = f"{width}px"
    canvas = InfCanvas(width, height, selection_size=selection_size)
    canvas.setup_mouse()
    canvas.clear_background()
    canvas.draw_buffer()
    canvas.draw_selection_box()
    # Replace the previously active canvas for all other handlers.
    base_lst[0] = canvas
  
async def export_func(event):
    """Offer the exported canvas as a PNG download.

    The array returned by the active canvas's ``export()`` is encoded with
    its ``numpy_to_base64`` helper and attached to a temporary ``<a>``
    element as a ``data:`` URL; clicking that element starts the download.
    The file name carries a ``YYYYmmdd_HHMMSS`` timestamp.

    Args:
        event: The triggering event (unused).
    """
    canvas = base_lst[0]
    pixels = canvas.export()
    base64_str = canvas.numpy_to_base64(pixels)
    time_str = time.strftime("%Y%m%d_%H%M%S")
    console.log(f"Canvas saved to outpaint_{time_str}.png")
    # The anchor is never inserted into the document; it exists only to
    # carry the download name and the data URL for the synthetic click.
    anchor = document.createElement("a")
    anchor.download = f"outpaint_{time_str}.png"
    anchor.href = "data:image/png;base64," + base64_str
    anchor.click()

async def outpaint_func(event):
    """Write a received image into the current selection.

    Args:
        event: Event whose ``data[1]`` holds the base64 string that the
            canvas's ``base64_to_numpy`` helper decodes.
    """
    canvas = base_lst[0]
    encoded = event.data[1]
    patch = canvas.base64_to_numpy(encoded)
    canvas.fill_selection(patch)
    # Redraw the selection box after the selection content changed.
    canvas.draw_selection_box()

async def undo_func(event):
    """Discard an uncommitted selection edit, if there is one.

    When the active canvas reports ``sel_dirty``, layer 2 is cleared, the
    selection buffer is replaced by a copy of ``sel_buffer_bak`` and the
    dirty flag is reset. Otherwise nothing happens.

    Args:
        event: The triggering event (unused).
    """
    canvas = base_lst[0]
    if not canvas.sel_dirty:
        return
    canvas.canvas[2].clear()
    # Copy the backup so later edits cannot modify it through sel_buffer.
    canvas.sel_buffer = canvas.sel_buffer_bak.copy()
    canvas.sel_dirty = False

async def commit_func(event):
    """Commit the pending selection edit to the canvas buffer.

    Calls ``write_selection_to_buffer()`` on the active canvas only when
    its ``sel_dirty`` flag is set.

    Args:
        event: The triggering event (unused).
    """
    canvas = base_lst[0]
    if not canvas.sel_dirty:
        return
    canvas.write_selection_to_buffer()

async def transfer_func(event):
    """Post the current selection contents to the parent window.

    The selection is refreshed from the canvas buffer, encoded with the
    canvas's ``numpy_to_base64`` helper and posted as the message
    ``["transfer", <encoded string>]``.

    Args:
        event: The triggering event (unused).
    """
    canvas = base_lst[0]
    canvas.read_selection_from_buffer()
    encoded = canvas.numpy_to_base64(canvas.sel_buffer)
    message = to_js(["transfer", str(encoded)])
    # NOTE(review): the "*" target origin delivers the message to the parent
    # regardless of its origin; confirm this is intended.
    parent.postMessage(message, "*")

async def upload_func(event):
    """Replace the canvas buffer with an uploaded image, centred.

    The image is decoded from ``event.data[1]`` with the canvas's
    ``base64_to_numpy`` helper. Any uncommitted selection edit is discarded,
    the buffer is zeroed, and the image is written into the middle of the
    buffer. An image larger than the canvas in either dimension is
    centre-cropped to the canvas size first. Without the crop the placement
    offsets become negative, the slice wraps around from the end of the
    buffer, and the assignment fails on a shape mismatch.

    Args:
        event: Event whose ``data[1]`` holds the base64-encoded image.
    """
    base = base_lst[0]
    base64_str = event.data[1]
    arr = base.base64_to_numpy(base64_str)
    h, w, _ = arr.shape
    # Visible extent of the image once limited to the canvas size.
    # assumes base.buffer spans at least base.height x base.width pixels,
    # as the centring arithmetic already implied; TODO confirm in canvas.py
    fit_h = min(h, base.height)
    fit_w = min(w, base.width)
    # Offsets of the kept region inside the source image; both are 0 when
    # the image already fits, so arr is then unchanged.
    top = (h - fit_h) // 2
    left = (w - fit_w) // 2
    arr = arr[top:top + fit_h, left:left + fit_w]
    # Offsets of the image inside the canvas; never negative after the crop.
    yo = (base.height - fit_h) // 2
    xo = (base.width - fit_w) // 2
    if base.sel_dirty:
        # Same rollback that the undo handler performs.
        base.canvas[2].clear()
        base.sel_buffer = base.sel_buffer_bak.copy()
        base.sel_dirty = False
    base.buffer_dirty = True
    # Zero the buffer in place so existing references to it stay valid.
    base.buffer *= 0
    # The first three channels and the last channel are copied separately,
    # exactly as before.
    base.buffer[yo:yo + fit_h, xo:xo + fit_w, 0:3] = arr[:, :, 0:3]
    base.buffer[yo:yo + fit_h, xo:xo + fit_w, -1] = arr[:, :, -1]
    base.draw_buffer()

  

# Attach a click listener to each hidden control button. Every handler is
# wrapped with create_proxy before it is handed to addEventListener.
# NOTE(review): outpaint_func and upload_func read event.data[1]; confirm
# that the click events reaching them actually carry a data payload.
for _selector, _handler in (
    ("#export", export_func),
    ("#undo", undo_func),
    ("#commit", commit_func),
    ("#outpaint", outpaint_func),
    ("#upload", upload_func),
    ("#transfer", transfer_func),
    ("#draw", draw_canvas_func),
):
    document.querySelector(_selector).addEventListener("click", create_proxy(_handler))

async def setup_func():
    """Set the value of the hidden ``#setup`` input to ``"1"``.

    NOTE(review): presumably a readiness flag read by the embedding page;
    confirm against the code that reads ``#setup``.
    """
    setup_flag = document.querySelector("#setup")
    setup_flag.value = "1"

async def message_func(event):
    """Route a ``message`` event received by this window.

    ``event.data[0]`` names the message kind:

    * ``"click"``: ``event.data[1]`` selects the export, commit or undo
      handler; any other value is ignored.
    * ``"upload"``, ``"outpaint"``, ``"transfer"``: forwarded to the
      handler of the same name.
    * ``"mode"``: ``event.data[1]`` is written to the ``#mode`` input.

    Unknown kinds are ignored.

    Args:
        event: The message event; ``data`` must be indexable.
    """
    message = event.data
    kind = message[0]

    if kind == "click":
        target = message[1]
        click_routes = (
            ("export", export_func),
            ("commit", commit_func),
            ("undo", undo_func),
        )
        for name, handler in click_routes:
            if target == name:
                await handler(event)
                break
        return

    if kind == "mode":
        document.querySelector("#mode").value = message[1]
        return

    # The remaining kinds pass the whole event to their handler.
    payload_routes = (
        ("upload", upload_func),
        ("outpaint", outpaint_func),
        ("transfer", transfer_func),
    )
    for name, handler in payload_routes:
        if kind == name:
            await handler(event)
            return
    
# Route every "message" event received by this window through message_func.
window.addEventListener("message",create_proxy(message_func))
import asyncio

# Start-up: run setup_func (sets the #setup input to "1") and draw_canvas
# (builds the initial canvas) together. This is a top-level await, so the
# script depends on the host runtime allowing await outside a function.
# The gathered results are not used.
_ = await asyncio.gather(
  setup_func(),draw_canvas()
)
</py-script>

</body>
</html>