Saghir commited on
Commit
28de83f
·
1 Parent(s): 7cc389b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -10
app.py CHANGED
@@ -4,6 +4,7 @@ import matplotlib.pyplot as plt
4
  from torchvision import transforms
5
  from PIL import Image
6
  import torch.nn as nn
 
7
 
8
  from PathDino import get_pathDino_model
9
 
@@ -82,13 +83,19 @@ def visualize_attention_ViT(model, img, patch_size=16):
82
  return attention_list
83
 
84
  # Define the function to generate activation maps
85
- def generate_activation_maps(image):
 
 
 
 
 
 
86
  preprocess = transforms.Compose([
87
- transforms.Resize((512, 512)),
88
- transforms.CenterCrop(512),
89
- transforms.ToTensor(),
90
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize the tensors
91
- ])
92
  image_tensor = preprocess(image)
93
  img = image_tensor.unsqueeze(0).to(device)
94
  # Generate activation maps
@@ -106,15 +113,16 @@ uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"]
106
 
107
  if uploaded_image is not None:
108
  columns = st.columns(3)
109
- columns[1].image(uploaded_image, caption="Uploaded Image", width=300)
110
 
111
  # Load the image and apply preprocessing
112
  uploaded_image = Image.open(uploaded_image).convert('RGB')
113
  attention_list = generate_activation_maps(uploaded_image)
114
  print(len(attention_list))
115
  st.subheader(f"Attention Maps of the input image")
116
- columns = st.columns(len(attention_list)//2)
117
- columns2 = st.columns(len(attention_list)//2)
 
118
  for index, col in enumerate(columns):
119
  # Create a plot
120
  plt.plot(512, 512)
@@ -133,7 +141,25 @@ if uploaded_image is not None:
133
 
134
  for index, col in enumerate(columns2):
135
 
136
- index = index + len(attention_list)//2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  # Create a plot
138
  plt.plot(512, 512)
139
 
 
4
  from torchvision import transforms
5
  from PIL import Image
6
  import torch.nn as nn
7
+ import numpy as np
8
 
9
  from PathDino import get_pathDino_model
10
 
 
83
  return attention_list
84
 
85
  # Define the function to generate activation maps
86
+ def generate_activation_maps(image, patch_size=16):
87
+ # Convert the image to a NumPy array
88
+ img = np.array(image)
89
+ # make the image divisible by the patch size
90
+ w, h = img.shape[1] - img.shape[0] % patch_size, img.shape[1] - img.shape[1] % patch_size
91
+ min_size = min(w, h)
92
+ print("Image shape:", img.shape)
93
  preprocess = transforms.Compose([
94
+ transforms.Resize((img.shape[0], img.shape[1])),
95
+ transforms.CenterCrop(min_size),
96
+ transforms.ToTensor(),
97
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize the tensors
98
+ ])
99
  image_tensor = preprocess(image)
100
  img = image_tensor.unsqueeze(0).to(device)
101
  # Generate activation maps
 
113
 
114
  if uploaded_image is not None:
115
  columns = st.columns(3)
116
+ columns[1].image(uploaded_image, caption="Uploaded Image", width=400)
117
 
118
  # Load the image and apply preprocessing
119
  uploaded_image = Image.open(uploaded_image).convert('RGB')
120
  attention_list = generate_activation_maps(uploaded_image)
121
  print(len(attention_list))
122
  st.subheader(f"Attention Maps of the input image")
123
+ columns = st.columns(2)
124
+ columns2 = st.columns(2)
125
+ columns3 = st.columns(2)
126
  for index, col in enumerate(columns):
127
  # Create a plot
128
  plt.plot(512, 512)
 
141
 
142
  for index, col in enumerate(columns2):
143
 
144
+ index = index + 2
145
+ # Create a plot
146
+ plt.plot(512, 512)
147
+
148
+ # Remove x and y axis labels
149
+ plt.xticks([]) # Hide x-axis ticks and labels
150
+ plt.yticks([]) # Hide y-axis ticks and labels
151
+
152
+ # Alternatively, if you only want to hide the labels and keep the ticks:
153
+ plt.gca().axes.get_xaxis().set_visible(False)
154
+ plt.gca().axes.get_yaxis().set_visible(False)
155
+
156
+ plt.imshow(attention_list[index])
157
+ col.pyplot(plt)
158
+ plt.close()
159
+
160
+ for index, col in enumerate(columns3):
161
+
162
+ index = index + 4
163
  # Create a plot
164
  plt.plot(512, 512)
165