Circularmachines commited on
Commit
66b9c9d
·
1 Parent(s): 26d4876
Files changed (1) hide show
  1. app.py +16 -39
app.py CHANGED
@@ -6,54 +6,31 @@ from streamlit_image_coordinates import streamlit_image_coordinates
6
 
7
  import numpy as np
8
 
9
- from datasets import load_dataset
10
-
11
- ds = load_dataset("Circularmachines/batch_indexing_machine_test", split="test")
12
 
13
- patch_size=32
14
- #image_size=2304
15
- image_size=512
16
- gridsize=16
17
 
18
- def donut(patch_size, img_size, lower_limit=0.40, upper_limit=1):
19
-
20
- gridsize=img_size//2//patch_size
21
 
22
- #create a grid of patch coordinates relative to center of image, and calculate distance from center
23
- coords=np.array([[(i+0.5,j+0.5) for i in range(-gridsize,gridsize)] for j in range(-gridsize,gridsize)])
24
- norm=np.linalg.norm(coords,axis=2)
25
-
26
- #we are only interested in the "donut" where the parts are, anything close to the center and far from the center is disregarded
27
- keep_bool=((norm>(gridsize*lower_limit))*(norm<(gridsize*upper_limit)))
28
- keep=np.where(keep_bool.flatten())[0]
29
 
30
- return coords,keep,keep_bool
31
 
32
- coords,keep,keep_bool=donut(patch_size,image_size)
33
- #coords_valid=coords.reshape(-1,2)[keep]
34
- n_patches=len(keep)
35
 
36
- #angle_sort=(-np.arctan2(coords_valid[:,0],coords_valid[:,1])).argsort()
37
- #keep_a=keep[angle_sort]
38
 
39
- #keep_i=np.zeros(gridsize**2)
40
 
41
- #keep_i[keep]=keep_a
42
 
43
- pred=np.load('pred.npy')
44
- pred_all=np.load('pred_all.npy').reshape(-1,64)
45
 
46
- random_i=np.load('random.npy')
47
 
 
 
 
 
 
48
 
49
- #st.set_page_config(
50
- # page_title="Streamlit Image Coordinates: Image Update",
51
- # page_icon="🎯",
52
- # layout="wide",
53
- #)
54
 
55
- #"# :dart: Streamlit Image Coordinates: Image Update"
 
56
 
 
57
 
58
 
59
  if "point" not in st.session_state:
@@ -67,10 +44,10 @@ if "draw" not in st.session_state:
67
 
68
  def patch(ij):
69
  #st.write(ij)
70
- immg=ij//(gridsize**2)
71
- p=ij%(gridsize**2)
72
 
73
- imm=ds[int(immg)]['image'].resize(size=(512,512))
74
 
75
  y=p//gridsize
76
  x=p%gridsize
@@ -94,7 +71,7 @@ def find():
94
  batches=[]
95
  while ix<4:
96
 
97
- batch=diff.argsort()[i]//(gridsize**2)//20
98
 
99
  if batch not in batches:
100
 
@@ -158,7 +135,7 @@ with col1:
158
  value = streamlit_image_coordinates(current_image, key="pil")
159
 
160
  if value is not None:
161
- point = value["x"]//patch_size*patch_size, value["y"]//patch_size*patch_size
162
 
163
  if point != st.session_state["point"]:
164
  st.session_state["point"]=point
 
6
 
7
  import numpy as np
8
 
 
 
 
9
 
10
+ from datasets import load_dataset
 
 
 
11
 
 
 
 
12
 
 
 
 
 
 
 
 
13
 
 
14
 
 
 
 
15
 
16
+ ds = load_dataset("Circularmachines/batch_indexing_machine_test", split="test")
 
17
 
 
18
 
 
19
 
 
 
20
 
 
21
 
22
+ patch_size=32
23
+ stride=16
24
+ #image_size=2304
25
+ image_size=512
26
+ gridsize=31
27
 
28
+ n_patches=961
 
 
 
 
29
 
30
+ pred=np.load('pred.npy')
31
+ pred_all=np.load('pred_all.npy')#.reshape(-1,64)
32
 
33
+ random_i=np.load('random.npy')
34
 
35
 
36
  if "point" not in st.session_state:
 
44
 
45
  def patch(ij):
46
  #st.write(ij)
47
+ immg=ij//n_patches
48
+ p=ij%n_patches
49
 
50
+ imm=ds[int(immg)]['image']#.resize(size=(512,512))
51
 
52
  y=p//gridsize
53
  x=p%gridsize
 
71
  batches=[]
72
  while ix<4:
73
 
74
+ batch=diff.argsort()[i]//n_patches//20
75
 
76
  if batch not in batches:
77
 
 
135
  value = streamlit_image_coordinates(current_image, key="pil")
136
 
137
  if value is not None:
138
+ point = value["x"]//32, value["y"]//32
139
 
140
  if point != st.session_state["point"]:
141
  st.session_state["point"]=point