File size: 4,949 Bytes
9705209
df20613
beb3e34
b73b4e1
 
df20613
8cc7c1a
26b8b00
31b5ef1
 
86efe2e
31b5ef1
540d02b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23ecd39
540d02b
 
23ecd39
 
540d02b
23ecd39
540d02b
23ecd39
31b5ef1
540d02b
 
 
83f2c3c
540d02b
8cc7c1a
 
 
 
 
26b8b00
 
 
83f2c3c
 
faf8b4b
 
26b8b00
775fea9
 
42bc3b0
3d79c89
 
 
fdef182
2814195
33632ae
 
83f2c3c
9e5d34c
83f2c3c
33632ae
 
2814195
83f2c3c
 
12ae958
83f2c3c
26b8b00
b1c653f
 
 
b2b0a36
775fea9
fca50af
8015cdb
 
faf8b4b
33632ae
cc301d6
1ac8545
 
 
33632ae
 
065bb16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1c653f
0b03953
fdef182
83f2c3c
 
faf8b4b
 
 
12ae958
8015cdb
8cc7c1a
df5841e
 
33632ae
df5841e
e86e66f
 
de70cb6
 
df5841e
beb3e34
854f030
aa4905a
391663d
12ae958
3d79c89
86efe2e
12ae958
854f030
12ae958
854f030
12ae958
 
faf8b4b
12ae958
 
aca188e
12ae958
aca188e
12ae958
33632ae
df5841e
faf8b4b
 
12ae958
 
df5841e
e2813d1
 
5ff37df
e5bac18
e2813d1
12ae958
e5bac18
 
 
12ae958
 
d4c5cf6
ddac224
 
d4c5cf6
 
 
 
 
 
065bb16
12ae958
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
from datasets.arrow_dataset import InMemoryTable
import streamlit as st
from PIL import Image, ImageDraw

from streamlit_image_coordinates import streamlit_image_coordinates

import numpy as np

from datasets import load_dataset

ds = load_dataset("Circularmachines/batch_indexing_machine_test", split="test")

# Patch-grid geometry. image_size is the working resolution that donut() builds
# the patch grid on; patch_size is the side of one square patch in pixels.
patch_size = 32
image_size = 512  # an earlier configuration used 2304
# Patches per image side. Derived instead of hard-coded so it cannot drift out
# of sync with the two values above (512 // 32 == 16, the previous literal).
gridsize = image_size // patch_size

def donut(patch_size, img_size, lower_limit=0.40, upper_limit=1):
    """Select the patches that lie in an annulus ("donut") around the image centre.

    The image is treated as a square grid of ``2 * half`` x ``2 * half`` patches,
    where ``half = img_size // 2 // patch_size``. Each patch is represented by the
    offset of its centre from the image centre, measured in patch units.

    Args:
        patch_size: Side length of one square patch, in pixels.
        img_size: Side length of the square image, in pixels.
        lower_limit: Inner radius of the annulus, as a fraction of ``half``.
        upper_limit: Outer radius of the annulus, as a fraction of ``half``.

    Returns:
        coords: float array of shape ``(2*half, 2*half, 2)``; ``coords[row, col]``
            is ``(col + 0.5 - half, row + 0.5 - half)``, i.e. ``(x, y)``.
        keep: 1-D array of row-major flat indices of the patches strictly
            inside the annulus.
        keep_bool: boolean mask of shape ``(2*half, 2*half)`` marking the same
            patches.

    An image smaller than two patches yields empty arrays instead of raising.
    """
    half = img_size // 2 // patch_size

    # Patch-centre offsets from the image centre, in patch units.
    offsets = np.arange(-half, half) + 0.5
    xs, ys = np.meshgrid(offsets, offsets)  # xs varies along columns, ys along rows
    coords = np.stack([xs, ys], axis=-1)
    norm = np.linalg.norm(coords, axis=-1)

    # Only the ring where the parts sit is of interest; patches too close to
    # the centre or too far from it are discarded.
    keep_bool = (norm > half * lower_limit) & (norm < half * upper_limit)
    keep = np.flatnonzero(keep_bool)

    return coords, keep, keep_bool

# Patch grid at the working resolution: patch-centre coordinates, flat indices
# of the patches inside the annulus, and the equivalent boolean mask.
coords, keep, keep_bool = donut(patch_size, image_size)
n_patches = keep.size

# Precomputed model outputs, loaded from the working directory.
# NOTE(review): pred is not read anywhere in the visible code — confirm before removing.
pred = np.load('pred.npy')
# Flattened to one 64-value row per patch. Presumably ordered image-major with
# gridsize**2 rows per image, matching how find() indexes it — confirm.
pred_all = np.load('pred_all.npy').reshape(-1, 64)



#st.set_page_config(
#    page_title="Streamlit Image Coordinates: Image Update",
#    page_icon="🎯",
#    layout="wide",
#)

#"# :dart: Streamlit Image Coordinates: Image Update"



# Per-session defaults, written only on a session's first run so that later
# reruns keep the user's selection:
#   point - selected pixel position (replaced by a patch-snapped click position)
#   img   - index of the dataset image being shown
#   draw  - whether to outline the selected patch
for _key, _default in (("point", (200, 200)), ("img", 0), ("draw", False)):
    if _key not in st.session_state:
        st.session_state[_key] = _default
del _key, _default

def patch(ij):
    """Return the 3x3-patch neighbourhood around one patch of one dataset image.

    ``ij`` is a flat patch index over the whole dataset, laid out image-major:
    ``ij = image_index * gridsize**2 + row * gridsize + col``. The image is
    resized to ``image_size`` x ``image_size``. The crop runs from one patch
    above-left to one patch below-right of the selected patch, so it is a square
    of ``3 * patch_size`` pixels.
    """
    image_index, within = divmod(int(ij), gridsize ** 2)
    row, col = divmod(within, gridsize)

    # Use the shared resolution constant so cropping stays in step with the
    # grid that gridsize/donut() are built on.
    image = ds[image_index]['image'].resize(size=(image_size, image_size))

    # NOTE(review): for border patches the box extends past the image edge;
    # Pillow's crop is expected to pad that region — confirm.
    box = (
        (col - 1) * patch_size,
        (row - 1) * patch_size,
        (col + 2) * patch_size,
        (row + 2) * patch_size,
    )
    return image.crop(box)

# Side-panel placeholders shown until "Find similar parts" is first used:
# crops for flat patch indices 0-3 (the first four patches of image 0's top
# row) and the labels 0-3.
if "sideimg" not in st.session_state:
    st.session_state["sideimg"] = list(map(patch, range(4)))

if "sideix" not in st.session_state:
    st.session_state["sideix"] = list(range(4))

def button_click():
    """Switch the main view to a random image and hide the selection outline.

    The new index is drawn from 0..99 with NumPy's global RNG.
    NOTE(review): assumes the dataset (and pred_all) cover at least 100 images.
    """
    state = st.session_state
    state["img"] = np.random.randint(100)
    state["draw"] = False

def find():
    """Fill the side panel with the best-matching patch from each of four batches.

    Reads the selected pixel position ``st.session_state["point"]`` and the image
    index ``st.session_state["img"]``. Every row of ``pred_all`` is ranked by
    Euclidean distance to the selected patch's row. The ranking is then walked,
    keeping the first hit from each batch not seen yet, until four batches are
    collected. The crops (built with ``patch``) are written into
    ``st.session_state["sideimg"]`` and the batch ids into
    ``st.session_state["sideix"]``.

    The first hit is normally the selected patch itself (distance 0), so slot 0
    normally holds the query's own batch.
    """
    n_slots = 4
    # NOTE(review): 20 looks like the number of images per batch — confirm.
    images_per_batch = 20
    patches_per_image = gridsize ** 2

    point = st.session_state["point"]
    col = point[0] // patch_size
    row = point[1] // patch_size
    query = st.session_state["img"] * patches_per_image + row * gridsize + col

    # Distance from the query embedding to every patch embedding.
    diff = np.linalg.norm(pred_all[query] - pred_all, axis=-1)
    # Sort once up front; previously the full array was re-sorted twice on
    # every loop iteration.
    order = diff.argsort()

    batches = []
    for flat_index in order:
        batch = flat_index // patches_per_image // images_per_batch
        if batch in batches:
            continue
        st.session_state["sideimg"][len(batches)] = patch(flat_index)
        batches.append(batch)
        if len(batches) == n_slots:
            break

    st.session_state["sideix"] = batches


def get_ellipse_coords(point):
    """Return the ``(left, top, right, bottom)`` box of the patch whose top-left corner is ``point``.

    Despite the name, the box is used to draw a rectangle. Its side length is
    ``patch_size`` pixels.
    """
    left, top = point[0], point[1]
    return (left, top, left + patch_size, top + patch_size)


col1, col2 = st.columns([5, 1])

with col1:
    # Render the selected image at the working resolution shared with the grid.
    current_image = ds[st.session_state["img"]]['image'].resize(size=(image_size, image_size))
    draw = ImageDraw.Draw(current_image)

    if st.session_state["draw"]:
        # Outline the selected patch with a green square. The box is bound to
        # `rect`, not `coords`: this is module scope, where `coords` already holds
        # the patch-centre grid from donut(), and reusing the name clobbered it.
        point = st.session_state["point"]
        rect = get_ellipse_coords(point)
        draw.rectangle(rect, outline="green", width=2)

    # Show the image and read back the click position (may be None).
    value = streamlit_image_coordinates(current_image, key="pil")

    if value is not None:
        # Snap the click to the top-left corner of the patch it falls in.
        point = value["x"] // patch_size * patch_size, value["y"] // patch_size * patch_size

        if point != st.session_state["point"]:
            st.session_state["point"] = point
            st.session_state["draw"] = True
            st.experimental_rerun()

    st.button('Change Image', on_click=button_click)
    st.button('Find similar parts', on_click=find)

with col2:
    # Side panel: one matching crop per batch; slot 0 is labelled as the
    # current batch. zip() keeps this safe if fewer than four batches are set.
    side_items = zip(st.session_state["sideimg"], st.session_state["sideix"])
    for slot, (thumb, batch) in enumerate(side_items):
        st.image(thumb.resize((128, 128)))
        if slot == 0:
            st.write("current batch: " + str(batch))
        else:
            st.write("other batch: " + str(batch))