File size: 4,949 Bytes
9705209 df20613 beb3e34 b73b4e1 df20613 8cc7c1a 26b8b00 31b5ef1 86efe2e 31b5ef1 540d02b 23ecd39 540d02b 23ecd39 540d02b 23ecd39 540d02b 23ecd39 31b5ef1 540d02b 83f2c3c 540d02b 8cc7c1a 26b8b00 83f2c3c faf8b4b 26b8b00 775fea9 42bc3b0 3d79c89 fdef182 2814195 33632ae 83f2c3c 9e5d34c 83f2c3c 33632ae 2814195 83f2c3c 12ae958 83f2c3c 26b8b00 b1c653f b2b0a36 775fea9 fca50af 8015cdb faf8b4b 33632ae cc301d6 1ac8545 33632ae 065bb16 b1c653f 0b03953 fdef182 83f2c3c faf8b4b 12ae958 8015cdb 8cc7c1a df5841e 33632ae df5841e e86e66f de70cb6 df5841e beb3e34 854f030 aa4905a 391663d 12ae958 3d79c89 86efe2e 12ae958 854f030 12ae958 854f030 12ae958 faf8b4b 12ae958 aca188e 12ae958 aca188e 12ae958 33632ae df5841e faf8b4b 12ae958 df5841e e2813d1 5ff37df e5bac18 e2813d1 12ae958 e5bac18 12ae958 d4c5cf6 ddac224 d4c5cf6 065bb16 12ae958 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
from datasets.arrow_dataset import InMemoryTable
import streamlit as st
from PIL import Image, ImageDraw
from streamlit_image_coordinates import streamlit_image_coordinates
import numpy as np
from datasets import load_dataset
# Load the "test" split of the dataset; this runs every time the script executes.
ds = load_dataset("Circularmachines/batch_indexing_machine_test", split="test")
# Side length, in pixels, of one square patch on the resized image.
patch_size=32
#image_size=2304
# Image side length, in pixels, handed to donut(); it matches the 512x512
# resize that is hard-coded where images are loaded further down.
image_size=512
# Patches per image side, used to convert flat patch indices to row/column.
# Equals image_size // patch_size for the values above.
gridsize=16
def donut(patch_size, img_size, lower_limit=0.40, upper_limit=1):
    """Select the ring ("donut") of patches around the image centre.

    The image is treated as a square grid of patches. Patch centres are
    expressed in patch units relative to the image centre, and a patch is
    kept when its distance from the centre lies strictly between
    ``lower_limit`` and ``upper_limit`` times half the grid width.

    Args:
        patch_size: Side length of one patch, in pixels.
        img_size: Side length of the image, in pixels.
        lower_limit: Inner radius of the ring as a fraction of the half grid.
        upper_limit: Outer radius of the ring as a fraction of the half grid.

    Returns:
        A tuple ``(coords, keep, keep_bool)``:

        * ``coords`` -- float array of shape ``(2 * g, 2 * g, 2)`` where
          ``g = img_size // 2 // patch_size``; ``coords[row, col]`` is the
          ``(x, y)`` offset of that patch centre from the image centre.
        * ``keep`` -- flat (row-major) indices of the patches inside the ring.
        * ``keep_bool`` -- boolean mask of shape ``(2 * g, 2 * g)``.

        When the image is smaller than two patches (``g == 0``) all three
        arrays are empty but keep the shapes described above.
    """
    half_grid = img_size // 2 // patch_size
    # Patch centres along one axis: -g + 0.5, ..., g - 0.5.
    axis = np.arange(-half_grid, half_grid) + 0.5
    # Default "xy" indexing gives xs[row, col] == axis[col] and
    # ys[row, col] == axis[row], i.e. rows run along y.
    xs, ys = np.meshgrid(axis, axis)
    coords = np.stack((xs, ys), axis=-1)
    norm = np.linalg.norm(coords, axis=2)
    # Anything too close to, or too far from, the centre is disregarded.
    keep_bool = (norm > half_grid * lower_limit) & (norm < half_grid * upper_limit)
    keep = np.flatnonzero(keep_bool)
    return coords, keep, keep_bool
# Build the patch grid once for the configured sizes: `keep` holds the flat
# indices of the patches inside the donut and `keep_bool` the matching 2-D mask.
coords,keep,keep_bool=donut(patch_size,image_size)
#coords_valid=coords.reshape(-1,2)[keep]
# Number of patches retained by the donut mask.
n_patches=len(keep)
#angle_sort=(-np.arctan2(coords_valid[:,0],coords_valid[:,1])).argsort()
#keep_a=keep[angle_sort]
#keep_i=np.zeros(gridsize**2)
#keep_i[keep]=keep_a
# Arrays read from the working directory. `pred` is not referenced again in
# the visible part of this file.
pred=np.load('pred.npy')
# Flattened to rows of 64 values. find() addresses a row as
# image_index * gridsize**2 + patch_index -- presumably one feature vector per
# patch (TODO confirm against whatever produced pred_all.npy).
pred_all=np.load('pred_all.npy').reshape(-1,64)
#st.set_page_config(
# page_title="Streamlit Image Coordinates: Image Update",
# page_icon="🎯",
# layout="wide",
#)
#"# :dart: Streamlit Image Coordinates: Image Update"
# Seed the per-session defaults on the first run; values already stored by an
# earlier rerun are left untouched.
for _key, _default in (("point", (200, 200)), ("img", 0), ("draw", False)):
    if _key not in st.session_state:
        st.session_state[_key] = _default
def patch(ij):
    """Return a 3x3-patch crop centred on one patch of a dataset image.

    Args:
        ij: Flat patch index over the whole dataset, laid out as
            ``image_index * gridsize**2 + row * gridsize + column``.

    Returns:
        The crop of the resized image spanning one patch on every side of the
        addressed patch, i.e. ``3 * patch_size`` pixels per side. For patches
        on the image border the crop box extends past the image; how that
        area is filled is left to the image's ``crop`` method.
    """
    image_index, patch_index = divmod(ij, gridsize ** 2)
    row, col = divmod(patch_index, gridsize)
    # Resize to the configured image_size rather than a literal 512 so the
    # crop geometry stays consistent with patch_size/gridsize if it changes.
    resized = ds[int(image_index)]['image'].resize(size=(image_size, image_size))
    left = (col - 1) * patch_size
    upper = (row - 1) * patch_size
    return resized.crop((left, upper, left + 3 * patch_size, upper + 3 * patch_size))
# First-run placeholders for the side column: the crops for flat patch
# indices 0..3 and the matching labels 0..3.
if "sideimg" not in st.session_state:
    st.session_state["sideimg"] = list(map(patch, range(4)))
if "sideix" not in st.session_state:
    st.session_state["sideix"] = list(range(4))
def button_click():
    """Button callback: switch to a random image and hide the selection box.

    Stores a random index in the range 0..99 under ``"img"`` and clears the
    ``"draw"`` flag in the session state.
    """
    state = st.session_state
    state["draw"] = False
    state["img"] = np.random.randint(100)
def find():
    """Button callback: show the patches most similar to the selected one.

    Reads the selected pixel position (``"point"``) and image index
    (``"img"``) from the session state, looks up the matching row of
    ``pred_all`` and ranks every row by Euclidean distance to it. Walking the
    ranking from nearest to farthest, the first hit from each distinct batch
    is cropped with ``patch`` and written into ``"sideimg"`` until four
    batches have been collected. The batch numbers are stored in ``"sideix"``.

    The batch of a row is ``row // gridsize**2 // 20`` -- presumably 20
    consecutive images form one batch (TODO confirm). The query row has
    distance 0 to itself, so it is expected to occupy the first slot.

    If fewer than four distinct batches exist, the remaining slots keep their
    previous images and labels.
    """
    state = st.session_state
    patches_per_image = gridsize ** 2

    # Convert the stored pixel position into grid column / row.
    col = state["point"][0] // patch_size
    row = state["point"][1] // patch_size
    query_row = state["img"] * patches_per_image + row * gridsize + col

    diff = np.linalg.norm(pred_all[np.newaxis, query_row, :] - pred_all, axis=-1)
    # Sort the distances once and reuse the ranking for every candidate.
    ranking = diff.argsort()

    batches = []
    for candidate in ranking:
        batch = candidate // patches_per_image // 20
        if batch in batches:
            continue
        state["sideimg"][len(batches)] = patch(candidate)
        batches.append(batch)
        if len(batches) == 4:
            break

    # Keep four labels even when fewer than four batches were found, because
    # the side column always renders four slots.
    state["sideix"] = batches + list(state["sideix"][len(batches):4])
# for i in range(4):
# st.session_state["sideimg"][i]+=1
# st.image(ds[0]['image'])
def get_ellipse_coords(point):
    """Return the box of one patch whose top-left corner is *point*.

    Args:
        point: Indexable ``(x, y)`` pair giving the top-left corner.

    Returns:
        ``(left, upper, right, lower)`` where the right/lower edges lie
        ``patch_size`` beyond the left/upper edges.
    """
    left = point[0]
    upper = point[1]
    return (left, upper, left + patch_size, upper + patch_size)
# Two columns with a 5:1 width ratio: the clickable image and, next to it,
# the similar-patch thumbnails.
col1, col2 = st.columns([5,1])
with col1:
# Fetch the currently selected dataset image and resize it to 512x512.
current_image=ds[st.session_state["img"]]['image'].resize(size=(512,512))
draw = ImageDraw.Draw(current_image)
if st.session_state["draw"]:
# Outline the selected patch with a green rectangle (a single point is
# stored, not a list of points).
#for point in st.session_state["points"]:
point=st.session_state["point"]
coords = get_ellipse_coords(point)
draw.rectangle(coords, outline="green",width=2)
# Render the image and read back the last click position, if any.
value = streamlit_image_coordinates(current_image, key="pil")
if value is not None:
# Snap the click to the top-left corner of the patch it falls in.
point = value["x"]//patch_size*patch_size, value["y"]//patch_size*patch_size
# Only store the point and rerun when the selection actually changed.
if point != st.session_state["point"]:
st.session_state["point"]=point
st.session_state["draw"]=True
# NOTE(review): st.experimental_rerun is deprecated in recent Streamlit
# releases in favour of st.rerun -- confirm against the pinned version.
st.experimental_rerun()
#subcol1, subcol2 = st.columns(2)
#with subcol1:
#st.button('Previous Frame', on_click=button_click)
# Buttons wired to the callbacks defined above.
st.button('Change Image', on_click=button_click)
st.button('Find similar parts', on_click=find)
#st.write(st.session_state["img"])
#st.write(st.session_state["point"])
#st.write(st.session_state["draw"])
with col2:
    # Four thumbnails, each followed by its batch label; the first slot is
    # labelled as the current batch, the rest as other batches.
    for slot in range(4):
        st.image(st.session_state["sideimg"][slot].resize((128, 128)))
        prefix = "current batch: " if slot == 0 else "other batch: "
        st.write(prefix + str(st.session_state["sideix"][slot]))
#st.write(st.session_state["sideix"][i])
|