Spaces:
Build error
Build error
| from datasets.arrow_dataset import InMemoryTable | |
| import streamlit as st | |
| from PIL import Image, ImageDraw | |
| from streamlit_image_coordinates import streamlit_image_coordinates | |
| import numpy as np | |
| from datasets import load_dataset | |
# Test split of the dataset; each record exposes a PIL image under the
# 'image' key (see how patch() and the main view read it).
ds = load_dataset("Circularmachines/batch_indexing_machine_test", split="test")

# Every image is treated as a gridsize x gridsize grid of square patches that
# are patch_size pixels wide (16 * 32 = 512, the size images are resized to).
gridsize = 16
n_patches = 164  # not referenced anywhere in the visible code
patch_size = 32

# Precomputed arrays shipped next to the app.
pred = np.load("pred.npy")  # not referenced anywhere in the visible code
# Flattened to rows of 64 values; find() indexes it as
# image_index * gridsize**2 + patch_index, i.e. one row per grid cell.
pred_all = np.load("pred_all.npy").reshape(-1, 64)

# keep: flat indices of the truthy entries of the keep_bool mask.
keep_bool = np.load("keep_bool.npy")
keep = np.flatnonzero(keep_bool)
# keep_i: float array with one slot per grid cell; kept cells store their own
# flat index, every other cell stays 0.
keep_i = np.zeros(gridsize * gridsize)
keep_i[keep] = keep
| #st.set_page_config( | |
| # page_title="Streamlit Image Coordinates: Image Update", | |
| # page_icon="🎯", | |
| # layout="wide", | |
| #) | |
| #"# :dart: Streamlit Image Coordinates: Image Update" | |
# Seed the per-session values on the first run only; later reruns keep
# whatever the callbacks and the click handler have stored.
#   point - (x, y) pixel position of the selected patch on the 512x512 image
#   img   - index of the dataset image currently shown
#   draw  - whether the selection rectangle should be drawn
_SESSION_DEFAULTS = {"point": (200, 200), "img": 0, "draw": False}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
def patch(ij):
    """Return the 3x3-patch neighbourhood around one grid cell as a PIL image.

    ``ij`` is a flat index over all grid cells of all images: the quotient by
    ``gridsize**2`` selects the dataset image and the remainder selects the
    cell inside that image (row-major: ``row * gridsize + column``).

    The image is resized to 512x512 and a box of three patches per side,
    centred on the selected cell, is cropped out. For cells on the border the
    box reaches outside the image; that case is left to PIL's ``crop``.
    """
    image_index, cell = divmod(ij, gridsize**2)
    row, column = divmod(cell, gridsize)

    resized = ds[int(image_index)]['image'].resize(size=(512, 512))

    # One patch of margin on every side of the selected cell.
    left = (column - 1) * patch_size
    upper = (row - 1) * patch_size
    right = (column + 2) * patch_size
    lower = (row + 2) * patch_size
    return resized.crop((left, upper, right, lower))
# First run only: fill the four side-panel slots with the neighbourhoods of
# flat cell indices 0..3 so the panel has something to show before a search.
if "sideimg" not in st.session_state:
    st.session_state["sideimg"] = list(map(patch, range(4)))
def button_click():
    """Callback for 'Change Batch': pick a random image and clear the marker.

    Stores a random index in [0, 100) under ``"img"`` and resets ``"draw"``
    so no selection rectangle is drawn on the newly shown image.
    """
    next_image = np.random.randint(100)
    st.session_state["draw"] = False
    st.session_state["img"] = next_image
def find():
    """Callback for 'Find similar parts': fill the side panel with matches.

    Converts the selected pixel position into a grid cell, looks up that
    cell's row in ``pred_all`` and ranks every row of ``pred_all`` by
    Euclidean distance to it. The four closest cells are rendered with
    ``patch()`` and written into ``st.session_state["sideimg"]`` in place.

    The query cell is itself part of the search set (distance 0), so it is
    normally among the returned matches.
    """
    # Pixel position -> (column, row) of the grid cell that contains it.
    column, row = (coord // patch_size for coord in st.session_state["point"][:2])

    # Flat row index into pred_all: image-major, then row-major inside an image.
    query = st.session_state["img"] * gridsize**2 + row * gridsize + column

    # Distance from the query vector to every vector (broadcast over rows).
    distances = np.linalg.norm(pred_all[query] - pred_all, axis=-1)

    # Sort once; the original re-sorted the full array on every loop iteration.
    nearest = distances.argsort()
    for slot in range(4):
        st.session_state["sideimg"][slot] = patch(nearest[slot])
| # for i in range(4): | |
| # st.session_state["sideimg"][i]+=1 | |
| # st.image(ds[0]['image']) | |
def get_ellipse_coords(point):
    """Return the box ``(x0, y0, x1, y1)`` of one patch anchored at ``point``.

    Despite the name, ``point`` is used as the top-left corner (not a centre)
    and the result is a ``patch_size`` x ``patch_size`` square; the caller
    passes it to ``ImageDraw.rectangle``.
    """
    left = point[0]
    top = point[1]
    return (left, top, left + patch_size, top + patch_size)
# Page layout: a wide column with the clickable image and controls, and a
# narrow column with the four "similar part" thumbnails.
# NOTE(review): indentation was lost in the source this was recovered from;
# the nesting under `with col1:` is reconstructed - confirm that the buttons
# and debug read-outs are meant to live inside the first column.
col1, col2 = st.columns([5, 1])

with col1:
    # int(): button_click stores a NumPy integer; patch() casts the same way
    # before indexing the dataset.
    current_image = ds[int(st.session_state["img"])]['image'].resize(size=(512, 512))
    draw = ImageDraw.Draw(current_image)

    if st.session_state["draw"]:
        # Outline the currently selected patch.
        coords = get_ellipse_coords(st.session_state["point"])
        draw.rectangle(coords, outline="green", width=2)

    value = streamlit_image_coordinates(current_image, key="pil")

    if value is not None:
        # Snap the clicked pixel to the top-left corner of its grid cell.
        point = (
            value["x"] // patch_size * patch_size,
            value["y"] // patch_size * patch_size,
        )
        if point != st.session_state["point"]:
            st.session_state["point"] = point
            st.session_state["draw"] = True
            # Rerun so the rectangle is drawn at the new position.
            # st.experimental_rerun is deprecated in favour of st.rerun;
            # prefer the new name and fall back for old Streamlit versions.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()

    st.button('Change Batch', on_click=button_click)
    st.button('Find similar parts', on_click=find)

    # Debug read-outs of the current session state.
    st.write(st.session_state["img"])
    st.write(st.session_state["point"])
    st.write(st.session_state["draw"])

with col2:
    # Thumbnails of the most recent search results (or the initial defaults).
    for side_image in st.session_state["sideimg"][:4]:
        st.image(side_image.resize((128, 128)))