BhumikaMak committed on
Commit a2cda19 · verified · 1 Parent(s): a2ea702

update: added dff support

Files changed (1)
  1. yolov8.py +173 -1
yolov8.py CHANGED
@@ -63,6 +63,178 @@ def xai_yolov8s(image):
     tensor = transform(img_float).unsqueeze(0)
     target_layers = [model.model.model[-2]]  # adjust to the YOLOv8 architecture
     cam_image, renormalized_cam_image = generate_cam_image(model.model, target_layers, tensor, image, boxes)
+    rgb_img_float, batch_explanations, result = dff_nmf(image, target_lyr=-5, n_components=8)
+
     final_image = np.hstack((image, detections_img, renormalized_cam_image))
     caption = "Results using YOLOv8"
-    return Image.fromarray(final_image), caption
+    return Image.fromarray(final_image), caption, result
+
+
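+# --- Deep Feature Factorization (DFF) support ---
+# dff_l factorises a layer's activation matrix (channels x spatial positions) as
+# W @ H: the k rows of H act as spatial "concept" heatmaps, W holds the
+# per-channel weights of each concept.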
+def dff_l(activations, model, n_components):
+    batch_size, channels, h, w = activations.shape
+    # Flatten (batch, channels, h, w) -> (channels, batch*h*w) so NMF factorises
+    # channels over spatial positions.
+    reshaped_activations = activations.transpose((1, 0, 2, 3))
+    reshaped_activations[np.isnan(reshaped_activations)] = 0
+    reshaped_activations = reshaped_activations.reshape(
+        reshaped_activations.shape[0], -1)
+    # Shift activations to be non-negative, as NMF requires.
+    offset = reshaped_activations.min(axis=-1)
+    reshaped_activations = reshaped_activations - offset[:, None]
+    nmf_model = NMF(n_components=n_components, init='random', random_state=0)  # renamed to avoid shadowing `model`
+    W = nmf_model.fit_transform(reshaped_activations)
+    H = nmf_model.components_
+    concepts = W + offset[:, None]
+    explanations = H.reshape(n_components, batch_size, h, w)
+    explanations = explanations.transpose((1, 0, 2, 3))
+    return concepts, explanations
+
+
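+# DeepFeatureFactorization hooks the target layer, captures its activations on a
+# forward pass, and returns concepts plus per-image explanation maps via dff_l.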
+class DeepFeatureFactorization:
+    def __init__(self,
+                 model: torch.nn.Module,
+                 target_layer: torch.nn.Module,
+                 reshape_transform: Callable = None,
+                 computation_on_concepts=None):
+        self.model = model
+        self.computation_on_concepts = computation_on_concepts
+        self.activations_and_grads = ActivationsAndGradients(
+            self.model, [target_layer], reshape_transform)
+
+    def __call__(self,
+                 input_tensor: torch.Tensor,
+                 model: torch.nn.Module,
+                 n_components: int = 16):
+        if isinstance(input_tensor, np.ndarray):
+            input_tensor = torch.from_numpy(input_tensor)
+
+        batch_size, channels, h, w = input_tensor.size()
+        _ = self.activations_and_grads(input_tensor)
+
+        with torch.no_grad():
+            activations = self.activations_and_grads.activations[0].cpu().numpy()
+
+        concepts, explanations = dff_l(activations, model, n_components=n_components)
+        processed_explanations = []
+        for batch in explanations:
+            processed_explanations.append(scale_cam_image(batch, (w, h)))
+
+        if self.computation_on_concepts:
+            with torch.no_grad():
+                concept_tensors = torch.from_numpy(
+                    np.float32(concepts).transpose((1, 0)))
+                concept_outputs = self.computation_on_concepts(
+                    concept_tensors).cpu().numpy()
+            return concepts, processed_explanations, concept_outputs
+        else:
+            return concepts, processed_explanations, explanations
+
+    def __del__(self):
+        self.activations_and_grads.release()
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        self.activations_and_grads.release()
+        if isinstance(exc_value, IndexError):
+            # Report and suppress IndexError raised inside the context.
+            print(
+                f"An exception occurred in DeepFeatureFactorization with block: {exc_type}. Message: {exc_value}")
+            return True
+
+
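+# dff_nmf: end-to-end DFF helper -- preprocess the image, factorise a backbone
+# layer of a YOLOv5s model, then score each concept map with the detection head.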
+def dff_nmf(image, target_lyr, n_components):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    mean = [0.485, 0.456, 0.406]  # ImageNet mean for RGB channels (kept for reference; not applied below)
+    std = [0.229, 0.224, 0.225]   # ImageNet std for RGB channels (kept for reference; not applied below)
+    img = cv2.resize(image, (640, 640))
+    rgb_img_float = np.float32(img) / 255.0
+    input_tensor = torch.from_numpy(rgb_img_float).permute(2, 0, 1).unsqueeze(0).to(device)
+
+    model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).to(device)
+    dff = DeepFeatureFactorization(model=model,
+                                   target_layer=model.model.model[int(target_lyr)],
+                                   computation_on_concepts=None)
+    concepts, batch_explanations, explanations = dff(input_tensor, model, n_components)
+
+    yolov5_categories_url = \
+        "https://github.com/ultralytics/yolov5/raw/master/data/coco128.yaml"  # URL to the YOLOv5 categories file
+    yaml_data = requests.get(yolov5_categories_url).text
+    labels = yaml.safe_load(yaml_data)['names']  # parse the YAML file to get class names
+    num_classes = model.model.model[-1].nc
+    results = []
+    for indx in range(explanations[0].shape[0]):
+        upsampled_input = torch.tensor(explanations[0][indx])
+        device = next(model.parameters()).device
+        # Tile the single-channel concept map across 128 channels so it matches
+        # the input expected by the first detection-head conv.
+        input_tensor = upsampled_input.unsqueeze(0).unsqueeze(1).repeat(1, 128, 1, 1)
+        detection_lyr = model.model.model[-1]
+        output1 = detection_lyr.m[0](input_tensor.to(device))
+        objectness = output1[..., 4]     # objectness score (index 4)
+        class_scores = output1[..., 5:]  # class scores (index 5 onwards, 80 COCO classes)
+        objectness = torch.sigmoid(objectness)
+        class_scores = torch.sigmoid(class_scores)
+        confidence_mask = objectness > 0.5
+        objectness = objectness[confidence_mask]
+        class_scores = class_scores[confidence_mask]
+        scores, class_ids = class_scores.max(dim=-1)  # max class score per cell
+        scores = scores * objectness                  # weight class scores by objectness
+        boxes = output1[..., :4]                      # first 4 values are x1, y1, x2, y2
+        boxes = boxes[confidence_mask]                # keep boxes that pass the mask
+        fig, ax = plt.subplots(1, figsize=(8, 8))
+        ax.axis("off")
+        ax.imshow(batch_explanations[0][indx], cmap="plasma")  # concept heatmap
+        top_score_idx = scores.argmax(dim=0)          # index of the max score
+        top_score = scores[top_score_idx].item()
+        top_class_id = class_ids[top_score_idx].item()
+        top_box = boxes[top_score_idx].cpu().numpy()
+        scale_factor = 16  # head stride: map grid coordinates back to pixels
+        x1, y1, x2, y2 = top_box
+        x1, y1, x2, y2 = x1 * scale_factor, y1 * scale_factor, x2 * scale_factor, y2 * scale_factor
+        rect = patches.Rectangle(
+            (x1, y1), x2 - x1, y2 - y1,
+            linewidth=2, edgecolor='r', facecolor='none')
+        ax.add_patch(rect)
+
+        predicted_label = labels[top_class_id]  # map class ID to label
+        ax.text(x1, y1, f"{predicted_label}: {top_score:.2f}",
+                color='r', fontsize=12, verticalalignment='top')
+        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+
+        fig.canvas.draw()  # render the canvas so its buffer can be read
+        image_array = np.array(fig.canvas.renderer.buffer_rgba())  # RGBA numpy array
+        image_resized = cv2.resize(image_array, (640, 640))
+        rgba_channels = cv2.split(image_resized)
+        rgb_channels = np.stack(rgba_channels[:3], axis=-1)  # drop the alpha channel
+        visualization = show_factorization_on_image(
+            rgb_img_float, np.transpose(rgb_channels, (2, 0, 1)), image_weight=0.3)
+        results.append(visualization)
+        plt.close(fig)  # close the figure; clf() alone leaks figures in a loop
+
+    return rgb_img_float, batch_explanations, results
+
+
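+# Debugging helper: render every concept explanation over the input image.
+# It is not called from xai_yolov8s in this change.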
+def visualize_batch_explanations(rgb_img_float, batch_explanations, image_weight=0.7):
+    for i, explanation in enumerate(batch_explanations):
+        # Create a visualization for each explanation
+        visualization = show_factorization_on_image(rgb_img_float, explanation, image_weight=image_weight)
+        plt.figure()
+        plt.imshow(visualization)
+        plt.title(f'Explanation {i + 1}')  # title for each plot
+        plt.axis('off')                    # hide axes
+        plt.savefig("test_w.png")          # save before show(), which clears the figure
+        plt.show()
+    return visualization  # the visualization of the last explanation
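The new functions rely on imports that are assumed to already exist near the top of yolov8.py. A minimal sketch of what this hunk needs (the CAM helper paths follow the pytorch-grad-cam package layout; treat the exact module paths as assumptions):

    import cv2
    import numpy as np
    import requests
    import torch
    import yaml
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    from typing import Callable
    from PIL import Image
    from sklearn.decomposition import NMF
    from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
    from pytorch_grad_cam.utils.image import scale_cam_image, show_factorization_on_image

Since xai_yolov8s now returns three values instead of two, callers need a matching update. A hedged usage sketch (the input path is a placeholder):

    import numpy as np
    from PIL import Image

    image = np.array(Image.open("sample.jpg").convert("RGB"))  # hypothetical input
    overlay, caption, dff_maps = xai_yolov8s(image)  # previously: overlay, caption
    overlay.save("yolov8_xai.png")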