"""Streamlit demo: click a grid patch on a dataset image and look up the most
similar patch embedding.

Loads the ``Circularmachines/batch_indexing_machine_100_small_imgs`` dataset
(train split) plus precomputed arrays from ``pred.npy``, ``pred_all.npy`` and
``keep_bool.npy`` in the working directory.
"""
import streamlit as st
from PIL import Image, ImageDraw
from streamlit_image_coordinates import streamlit_image_coordinates
import numpy as np
from datasets import load_dataset

ds = load_dataset("Circularmachines/batch_indexing_machine_100_small_imgs", split="train")
pred = np.load('pred.npy')
# NOTE(review): indexed below as pred_all[image, patch] and reduced over the
# last axis, so presumably (n_images, n_patches, embedding_dim) — confirm.
pred_all = np.load('pred_all.npy')
keep_bool = np.load('keep_bool.npy')

# Side length in pixels of one grid patch; clicks snap to this grid.
PATCH_SIZE = 16
# Grid width used to flatten a (column, row) patch position into one index.
# NOTE(review): 36 * 16 = 576 px — presumably the image side length; confirm.
PATCHES_PER_ROW = 36

#st.set_page_config(
#    page_title="Streamlit Image Coordinates: Image Update",
#    page_icon="🎯",
#    layout="wide",
#)

# Initialise per-session UI state once; Streamlit reruns this script on every
# interaction, so existing values must not be overwritten.
if "point" not in st.session_state:
    st.session_state["point"] = (200, 200)
if "img" not in st.session_state:
    st.session_state["img"] = 0
if "draw" not in st.session_state:
    st.session_state["draw"] = False
if "sideimg" not in st.session_state:
    st.session_state["sideimg"] = [0, 1, 2, 3]


def button_click():
    """Switch to a random image index in [0, 100) and clear the selection outline."""
    # NOTE(review): the bound 100 is hard-coded; it must not exceed len(ds)
    # or the first axis of pred_all — confirm.
    st.session_state["img"] = np.random.randint(100)
    st.session_state["draw"] = False


def find():
    """Write the flat index of the patch most similar to the selected one.

    Used as the ``on_click`` callback of the "Find similar parts" button.
    The selected pixel point is converted to patch-grid coordinates, its
    embedding is read from ``pred_all``, and the Euclidean distance to every
    entry of ``pred_all`` is computed. The query patch itself is excluded,
    and the flat ``argmin`` index over the distance grid is written to the page.
    """
    x, y = st.session_state["point"]
    col, row = x // PATCH_SIZE, y // PATCH_SIZE
    i = st.session_state["img"]
    # NOTE(review): flattens as col * PATCHES_PER_ROW + row (x-major); confirm
    # this matches the patch order used when pred_all was built.
    p = col * PATCHES_PER_ROW + row
    diff = np.linalg.norm(pred_all[np.newaxis, np.newaxis, i, p] - pred_all, axis=-1)
    # The query is at distance 0 from itself; without masking it, argmin
    # would simply return the query's own position.
    diff[i, p] = np.inf
    st.write(diff.argmin())


def get_ellipse_coords(point):
    """Return the ``(x0, y0, x1, y1)`` box of the patch whose top-left corner is *point*.

    Despite the name, the box is drawn with ``ImageDraw.rectangle``.
    """
    return (
        point[0],
        point[1],
        point[0] + PATCH_SIZE,
        point[1] + PATCH_SIZE,
    )


col1, col2 = st.columns([5, 1])

with col1:
    current_image = ds[st.session_state["img"]]['image']
    draw = ImageDraw.Draw(current_image)

    if st.session_state["draw"]:
        # Outline the currently selected patch.
        point = st.session_state["point"]
        coords = get_ellipse_coords(point)
        draw.rectangle(coords, outline="green", width=2)

    value = streamlit_image_coordinates(current_image, key="pil")

    if value is not None:
        # Snap the click to the top-left corner of its grid patch.
        point = (
            value["x"] // PATCH_SIZE * PATCH_SIZE,
            value["y"] // PATCH_SIZE * PATCH_SIZE,
        )
        if point != st.session_state["point"]:
            st.session_state["point"] = point
            st.session_state["draw"] = True
            # st.experimental_rerun is deprecated in favour of st.rerun and is
            # missing from newer Streamlit; prefer st.rerun when available.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()

    st.button('Change Batch', on_click=button_click)
    st.button('Find similar parts', on_click=find)
    st.write(st.session_state["img"])
    st.write(st.session_state["point"])
    st.write(st.session_state["draw"])

with col2:
    # Show the first three side images, downsampled 4x in each spatial axis.
    for i in range(3):
        st.image(np.array(ds[st.session_state["sideimg"][i]]['image'])[::4, ::4, :])