from datasets.arrow_dataset import InMemoryTable |
import streamlit as st |
from PIL import Image, ImageDraw |
from streamlit_image_coordinates import streamlit_image_coordinates |
import numpy as np |
from datasets import load_dataset |
# Grid geometry: gridsize x gridsize cells of patch_size pixels each
# (16 * 32 = 512, the side length the images are resized to below).
gridsize = 16
patch_size = 32
# NOTE(review): n_patches is not referenced in the visible code; meaning unconfirmed.
n_patches = 164

# Image source displayed and cropped by the app.
ds = load_dataset("Circularmachines/batch_indexing_machine_test", split="test")

# Precomputed arrays read from the working directory.
pred = np.load('pred.npy')
# Rows of 64 values; find() treats row `image_index * gridsize**2 + cell`
# as the feature vector of that grid cell.
pred_all = np.load('pred_all.npy').reshape(-1, 64)
keep_bool = np.load('keep_bool.npy')
# Flat positions of the truthy entries of keep_bool.
keep = np.flatnonzero(keep_bool)
# Float vector of length gridsize**2 holding each kept position at its own
# index and 0 everywhere else.
keep_i = np.zeros(gridsize ** 2)
keep_i[keep] = keep
# Seed per-session defaults on the first run only; later reruns keep whatever
# is stored. "point" is a pixel position on the main image, "img" the dataset
# index being shown, "draw" whether the selected cell gets outlined.
_SESSION_DEFAULTS = {"point": (200, 200), "img": 0, "draw": False}
for _state_key, _default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default
def patch(ij):
    """Return a 3x3-cell crop of a dataset image centred on one grid cell.

    ``ij`` is a flat index over every image's grid cells::

        ij = image_index * gridsize**2 + row * gridsize + col

    The image is resized to ``gridsize * patch_size`` pixels per side (512 with
    the current constants). The crop covers the target cell plus one cell on
    every side, so the result is ``3 * patch_size`` pixels square. For cells on
    the image border the crop box extends past the image; PIL pads that area
    (black for RGB images -- confirm if it matters).
    """
    # Derive the resize side from the grid constants instead of a magic 512 so
    # the cell arithmetic below always matches the resized image.
    side = gridsize * patch_size
    # int() also converts numpy integers (e.g. from argsort), which the
    # dataset indexer does not accept.
    image_index, cell = divmod(int(ij), gridsize ** 2)
    row, col = divmod(cell, gridsize)
    image = ds[image_index]['image'].resize(size=(side, side))
    return image.crop((
        (col - 1) * patch_size,
        (row - 1) * patch_size,
        (col + 2) * patch_size,
        (row + 2) * patch_size,
    ))
if "sideimg" not in st.session_state:
    # Initial side-panel thumbnails: flat indices 0-3 map to the first four
    # cells of the top row of image 0.
    st.session_state["sideimg"] = list(map(patch, range(4)))
def button_click():
    """Callback for 'Change Batch': show a random image and clear the outline."""
    # NOTE(review): the exclusive upper bound 100 is hard-coded; presumably it
    # matches the number of images in ds / pred_all -- confirm.
    chosen = np.random.randint(100)
    st.session_state["img"] = chosen
    st.session_state["draw"] = False
def find():
    """Fill the side panel with the grid cells most similar to the selected one.

    Reads the selected pixel position (``"point"``) and image index (``"img"``)
    from session state and converts the position to a grid cell. It then ranks
    every row of ``pred_all`` by Euclidean distance to that cell's row. The four
    closest rows are rendered with ``patch`` into slots 0-3 of
    ``st.session_state["sideimg"]``. The query row has distance 0, so it
    normally fills the first slot.
    """
    px, py = st.session_state["point"]
    col, row = px // patch_size, py // patch_size
    query_index = st.session_state["img"] * gridsize ** 2 + row * gridsize + col
    # The norm is symmetric in sign, so pred_all - query equals query - pred_all here.
    distances = np.linalg.norm(pred_all - pred_all[query_index], axis=-1)
    # Rank once; previously the full array was re-sorted on every iteration.
    ranked = distances.argsort()
    for slot in range(4):
        st.session_state["sideimg"][slot] = patch(ranked[slot])
def get_ellipse_coords(point, size=None):
    """Return the inclusive bounding box ``(x0, y0, x1, y1)`` of a grid cell.

    ``point`` is the cell's top-left pixel. ``size`` is the cell side length
    and defaults to ``patch_size``. PIL's ``ImageDraw.rectangle`` includes both
    endpoints of the box, so the far corner is ``point + size - 1``.
    Using ``point + size`` would outline a box one pixel larger than the cell,
    spilling into its neighbour. Despite the name, the box is used to draw a
    rectangle.
    """
    if size is None:
        size = patch_size
    x0, y0 = point[0], point[1]
    return (x0, y0, x0 + size - 1, y0 + size - 1)
col1, col2 = st.columns([5, 1])
with col1:
    # Main view: the current image resized to the grid resolution
    # (gridsize * patch_size pixels per side -- 512 with the current constants).
    current_image = ds[st.session_state["img"]]['image'].resize(
        size=(gridsize * patch_size, gridsize * patch_size)
    )
    draw = ImageDraw.Draw(current_image)
    if st.session_state["draw"]:
        # Outline the currently selected grid cell.
        point = st.session_state["point"]
        coords = get_ellipse_coords(point)
        draw.rectangle(coords, outline="green", width=2)
    # Click position on the rendered image as a dict with "x"/"y" pixel keys;
    # a None result is skipped below.
    value = streamlit_image_coordinates(current_image, key="pil")
    if value is not None:
        # Snap the click to the top-left pixel of the grid cell it landed in.
        point = (
            value["x"] // patch_size * patch_size,
            value["y"] // patch_size * patch_size,
        )
        if point != st.session_state["point"]:
            st.session_state["point"] = point
            st.session_state["draw"] = True
            # st.experimental_rerun is deprecated in favour of st.rerun and is
            # gone from newer Streamlit releases; use st.rerun when available.
            (getattr(st, "rerun", None) or st.experimental_rerun)()
    st.button('Change Batch', on_click=button_click)
    st.button('Find similar parts', on_click=find)
    # Debug readout of the session state driving this view.
    st.write(st.session_state["img"])
    st.write(st.session_state["point"])
    st.write(st.session_state["draw"])
with col2:
    # Narrow column: the four crops currently held in session state, shown as
    # 128x128 thumbnails.
    side_images = st.session_state["sideimg"]
    for slot in range(4):
        thumbnail = side_images[slot].resize((128, 128))
        st.image(thumbnail)