Saghir commited on
Commit
be8436c
·
1 Parent(s): 28de83f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -88,11 +88,12 @@ def generate_activation_maps(image, patch_size=16):
88
  img = np.array(image)
89
  # make the image divisible by the patch size
90
  w, h = img.shape[1] - img.shape[0] % patch_size, img.shape[1] - img.shape[1] % patch_size
91
- min_size = min(w, h)
 
92
  print("Image shape:", img.shape)
93
  preprocess = transforms.Compose([
94
  transforms.Resize((img.shape[0], img.shape[1])),
95
- transforms.CenterCrop(min_size),
96
  transforms.ToTensor(),
97
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize the tensors
98
  ])
@@ -105,27 +106,29 @@ def generate_activation_maps(image, patch_size=16):
105
 
106
  # Streamlit UI
107
  st.title("PathDino - Compact ViT for Histopathology Image Analysis")
108
- st.write("Upload a histology image with 512x512 dimension of 20X magnification to view the activation maps.")
109
 
110
  # uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
111
  uploaded_image = "images/HistRotate.png"
112
  uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
113
 
114
  if uploaded_image is not None:
115
- columns = st.columns(3)
116
- columns[1].image(uploaded_image, caption="Uploaded Image", width=400)
117
 
118
  # Load the image and apply preprocessing
119
  uploaded_image = Image.open(uploaded_image).convert('RGB')
120
  attention_list = generate_activation_maps(uploaded_image)
 
121
  print(len(attention_list))
122
  st.subheader(f"Attention Maps of the input image")
123
  columns = st.columns(2)
124
  columns2 = st.columns(2)
125
  columns3 = st.columns(2)
 
126
  for index, col in enumerate(columns):
127
  # Create a plot
128
- plt.plot(512, 512)
129
 
130
  # Remove x and y axis labels
131
  plt.xticks([]) # Hide x-axis ticks and labels
@@ -134,16 +137,21 @@ if uploaded_image is not None:
134
  # Alternatively, if you only want to hide the labels and keep the ticks:
135
  plt.gca().axes.get_xaxis().set_visible(False)
136
  plt.gca().axes.get_yaxis().set_visible(False)
 
 
 
137
 
138
  plt.imshow(attention_list[index])
139
  col.pyplot(plt)
 
 
140
  plt.close()
141
 
142
  for index, col in enumerate(columns2):
143
 
144
  index = index + 2
145
  # Create a plot
146
- plt.plot(512, 512)
147
 
148
  # Remove x and y axis labels
149
  plt.xticks([]) # Hide x-axis ticks and labels
@@ -161,7 +169,7 @@ if uploaded_image is not None:
161
 
162
  index = index + 4
163
  # Create a plot
164
- plt.plot(512, 512)
165
 
166
  # Remove x and y axis labels
167
  plt.xticks([]) # Hide x-axis ticks and labels
 
88
  img = np.array(image)
89
  # make the image divisible by the patch size
90
  w, h = img.shape[1] - img.shape[0] % patch_size, img.shape[1] - img.shape[1] % patch_size
91
+ print("w, h:", w, h)
92
+ # min_size = min(w, h)
93
  print("Image shape:", img.shape)
94
  preprocess = transforms.Compose([
95
  transforms.Resize((img.shape[0], img.shape[1])),
96
+ transforms.CenterCrop((w, h)),
97
  transforms.ToTensor(),
98
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize the tensors
99
  ])
 
106
 
107
  # Streamlit UI
108
  st.title("PathDino - Compact ViT for Histopathology Image Analysis")
109
+ st.write("Upload a histology image to view the activation maps.")
110
 
111
  # uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
112
  uploaded_image = "images/HistRotate.png"
113
  uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
114
 
115
  if uploaded_image is not None:
116
+ # columns = st.columns(3)
117
+ st.image(uploaded_image, caption="Uploaded Image", width=500)
118
 
119
  # Load the image and apply preprocessing
120
  uploaded_image = Image.open(uploaded_image).convert('RGB')
121
  attention_list = generate_activation_maps(uploaded_image)
122
+
123
  print(len(attention_list))
124
  st.subheader(f"Attention Maps of the input image")
125
  columns = st.columns(2)
126
  columns2 = st.columns(2)
127
  columns3 = st.columns(2)
128
+ # for index in range(6):
129
  for index, col in enumerate(columns):
130
  # Create a plot
131
+ plt.plot(600, 600)
132
 
133
  # Remove x and y axis labels
134
  plt.xticks([]) # Hide x-axis ticks and labels
 
137
  # Alternatively, if you only want to hide the labels and keep the ticks:
138
  plt.gca().axes.get_xaxis().set_visible(False)
139
  plt.gca().axes.get_yaxis().set_visible(False)
140
+
141
+ print(type(attention_list[index]))
142
+ print(attention_list[index].shape)
143
 
144
  plt.imshow(attention_list[index])
145
  col.pyplot(plt)
146
+ # col
147
+ # st.image(plt, caption=f"Head-{index+1}", width=display_w)
148
  plt.close()
149
 
150
  for index, col in enumerate(columns2):
151
 
152
  index = index + 2
153
  # Create a plot
154
+ plt.plot(600, 600)
155
 
156
  # Remove x and y axis labels
157
  plt.xticks([]) # Hide x-axis ticks and labels
 
169
 
170
  index = index + 4
171
  # Create a plot
172
+ plt.plot(600, 600)
173
 
174
  # Remove x and y axis labels
175
  plt.xticks([]) # Hide x-axis ticks and labels