BhumikaMak commited on
Commit
46964fa
·
verified ·
1 Parent(s): a4de2b4

debug: dff_nmf incorrect factorisation output

Browse files
Files changed (1) hide show
  1. yolov8.py +29 -42
yolov8.py CHANGED
@@ -174,48 +174,35 @@ def dff_nmf(image, target_lyr, n_components):
174
  computation_on_concepts=None)
175
 
176
  concepts, batch_explanations, explanations = dff(input_tensor, model, n_components)
177
-
178
- # Getting predictions directly from YOLO
179
- with torch.no_grad():
180
- results = model(input_tensor)
181
-
182
- # Assuming results is a list, extract the first element
183
- detections = results[0] # The first element should contain the detection data
184
-
185
- # Access detection results
186
- #boxes = detections.boxes.xyxy.cpu().numpy() # Bounding box coordinates (xyxy)
187
- #probs = detections.probs.cpu().numpy() # Confidence scores (probabilities)
188
- #classes = detections.classes.cpu().numpy() # Class IDs
189
-
190
- # Filter detections with confidence score > threshold (e.g., 0.5)
191
- #high_conf_indices = probs > 0.5
192
- #high_conf_boxes = boxes[high_conf_indices]
193
- #high_conf_classes = classes[high_conf_indices]
194
- #high_conf_probs = probs[high_conf_indices]
195
-
196
- # Example visualization and output processing
197
- fig, ax = plt.subplots(1, figsize=(8, 8))
198
- ax.axis("off")
199
- ax.imshow(rgb_img_float)
200
-
201
- #for box, cls, prob in zip(high_conf_boxes, high_conf_classes, high_conf_probs):
202
- # x1, y1, x2, y2 = box
203
- # rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
204
- # linewidth=2, edgecolor='r', facecolor='none')
205
- # ax.add_patch(rect)
206
- # ax.text(x1, y1, f"Class {cls}, Prob {prob:.2f}", color='r', fontsize=12, verticalalignment='top')
207
-
208
- plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
209
- fig.canvas.draw()
210
- image_array = np.array(fig.canvas.renderer.buffer_rgba())
211
- image_resized = cv2.resize(image_array, (640, 640))
212
- rgba_channels = cv2.split(image_resized)
213
- alpha_channel = rgba_channels[3]
214
- rgb_channels = np.stack(rgba_channels[:3], axis=-1)
215
-
216
- visualization = show_factorization_on_image(rgb_img_float, np.transpose(rgb_channels, (2, 0, 1)), image_weight=0.3)
217
-
218
- return rgb_img_float, batch_explanations, visualization
219
 
220
 
221
  def visualize_batch_explanations(rgb_img_float, batch_explanations, image_weight=0.7):
 
174
  computation_on_concepts=None)
175
 
176
  concepts, batch_explanations, explanations = dff(input_tensor, model, n_components)
177
+ print("#################shapes###############")
178
+ print(concepts.shape, batch_explanations.shape, explanations.shape)
179
+ results = []
180
+ for indx in range(explanations[0].shape[0]):
181
+ upsampled_input = explanations[0][indx]
182
+ upsampled_input = torch.tensor(upsampled_input)
183
+ device = next(model.parameters()).device
184
+ input_tensor = upsampled_input.unsqueeze(0)
185
+ input_tensor = input_tensor.unsqueeze(1).repeat(1, 128, 1, 1)
186
+ fig, ax = plt.subplots(1, figsize=(8, 8))
187
+ ax.axis("off")
188
+ ax.imshow(torch.tensor(batch_explanations[0][indx]).cpu().numpy(), cmap="plasma") # Display i
189
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
190
+ fig.canvas.draw() # Draw the canvas to make sure the image is rendered
191
+ image_array = np.array(fig.canvas.renderer.buffer_rgba()) # Convert to numpy array
192
+ print("____________image_arrya", image_array.shape)
193
+ image_resized = cv2.resize(image_array, (640, 640))
194
+ rgba_channels = cv2.split(image_resized)
195
+ alpha_channel = rgba_channels[3]
196
+ rgb_channels = np.stack(rgba_channels[:3], axis=-1)
197
+ #overlay_img = (alpha_channel[..., None] * image) + ((1 - alpha_channel[..., None]) * rgb_channels)
198
+
199
+ #temp = image_array.reshape((rgb_img_float.shape[0],rgb_img_float.shape[1]) )
200
+ #visualization = show_factorization_on_image(rgb_img_float, image_array.resize((rgb_img_float.shape)) , image_weight=0.3)
201
+ visualization = show_factorization_on_image(rgb_img_float, np.transpose(rgb_channels, (2, 0, 1)), image_weight=0.3)
202
+ results.append(visualization)
203
+ plt.clf()
204
+
205
+ return rgb_img_float, batch_explanations, results
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
 
208
  def visualize_batch_explanations(rgb_img_float, batch_explanations, image_weight=0.7):