BhumikaMak committed on
Commit
9101eba
·
verified ·
1 Parent(s): e33c04c

update: dff_nmf

Browse files
Files changed (1) hide show
  1. yolov8.py +39 -65
yolov8.py CHANGED
@@ -162,7 +162,6 @@ class DeepFeatureFactorization:
162
 
163
 
164
 
165
-
166
  def dff_nmf(image, target_lyr, n_components):
167
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
168
  mean = [0.485, 0.456, 0.406] # Mean for RGB channels
@@ -171,77 +170,52 @@ def dff_nmf(image, target_lyr, n_components):
171
  rgb_img_float = np.float32(img) / 255.0
172
  input_tensor = torch.from_numpy(rgb_img_float).permute(2, 0, 1).unsqueeze(0).to(device)
173
 
174
- model = YOLO('yolov8s.pt')
175
- dff= DeepFeatureFactorization(model=model,
176
- target_layer=model.model.model[int(target_lyr)],
177
- computation_on_concepts=None)
178
 
179
  concepts, batch_explanations, explanations = dff(input_tensor, model, n_components)
180
 
 
 
 
 
 
 
 
 
 
181
 
182
- #yolov5_categories_url = \
183
- # "https://github.com/ultralytics/yolov5/raw/master/data/coco128.yaml" # URL to the YOLOv5 categories file
184
- #yaml_data = requests.get(yolov5_categories_url).text
185
- # labels = yaml.safe_load(yaml_data)['names'] # Parse the YAML file to get class names
186
- num_classes = model.model.model[-1].nc
187
- results = []
188
- for indx in range(explanations[0].shape[0]):
189
- upsampled_input = explanations[0][indx]
190
- upsampled_input = torch.tensor(upsampled_input)
191
- device = next(model.parameters()).device
192
- input_tensor = upsampled_input.unsqueeze(0)
193
- input_tensor = input_tensor.unsqueeze(1).repeat(1, 128, 1, 1)
194
- detection_lyr = model.model.model[-1]
195
- output1 = detection_lyr.m[0](input_tensor.to(device))
196
- objectness = output1[..., 4] # Objectness score (index 4)
197
- class_scores = output1[..., 5:] # Class scores (from index 5 onwards, representing 80 classes)
198
- objectness = torch.sigmoid(objectness)
199
- class_scores = torch.sigmoid(class_scores)
200
- confidence_mask = objectness > 0.5
201
- objectness = objectness[confidence_mask]
202
- class_scores = class_scores[confidence_mask]
203
- scores, class_ids = class_scores.max(dim=-1) # Get max class score per cell
204
- scores = scores * objectness # Adjust scores by objectness
205
- boxes = output1[..., :4] # First 4 values are x1, y1, x2, y2
206
- boxes = boxes[confidence_mask] # Filter boxes by confidence mask
207
- fig, ax = plt.subplots(1, figsize=(8, 8))
208
- ax.axis("off")
209
- ax.imshow(torch.tensor(batch_explanations[0][indx]).cpu().numpy(), cmap="plasma") # Display image
210
- top_score_idx = scores.argmax(dim=0) # Get the index of the max score
211
- top_score = scores[top_score_idx].item()
212
- top_class_id = class_ids[top_score_idx].item()
213
- top_box = boxes[top_score_idx].cpu().numpy()
214
- scale_factor = 16
215
- x1, y1, x2, y2 = top_box
216
- x1, y1, x2, y2 = x1 * scale_factor, y1 * scale_factor, x2 * scale_factor, y2 * scale_factor
217
- rect = patches.Rectangle(
218
- (x1, y1), x2 - x1, y2 - y1,
219
- linewidth=2, edgecolor='r', facecolor='none')
220
  ax.add_patch(rect)
 
 
 
 
 
 
 
 
 
221
 
222
- #predicted_label = labels[top_class_id] # Map ID to label
223
- #ax.text(x1, y1, f"{predicted_label}: {top_score:.2f}",
224
- # color='r', fontsize=12, verticalalignment='top')
225
- plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
226
-
227
- fig.canvas.draw() # Draw the canvas to make sure the image is rendered
228
- image_array = np.array(fig.canvas.renderer.buffer_rgba()) # Convert to numpy array
229
- print("____________image_arrya", image_array.shape)
230
- image_resized = cv2.resize(image_array, (640, 640))
231
- rgba_channels = cv2.split(image_resized)
232
- alpha_channel = rgba_channels[3]
233
- rgb_channels = np.stack(rgba_channels[:3], axis=-1)
234
- #overlay_img = (alpha_channel[..., None] * image) + ((1 - alpha_channel[..., None]) * rgb_channels)
235
-
236
- #temp = image_array.reshape((rgb_img_float.shape[0],rgb_img_float.shape[1]) )
237
- #visualization = show_factorization_on_image(rgb_img_float, image_array.resize((rgb_img_float.shape)) , image_weight=0.3)
238
- visualization = show_factorization_on_image(rgb_img_float, np.transpose(rgb_channels, (2, 0, 1)), image_weight=0.3)
239
- results.append(visualization)
240
- plt.clf()
241
- #return image_array
242
-
243
 
244
- return rgb_img_float, batch_explanations, results
245
 
246
 
247
  def visualize_batch_explanations(rgb_img_float, batch_explanations, image_weight=0.7):
 
162
 
163
 
164
 
 
165
  def dff_nmf(image, target_lyr, n_components):
166
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
167
  mean = [0.485, 0.456, 0.406] # Mean for RGB channels
 
170
  rgb_img_float = np.float32(img) / 255.0
171
  input_tensor = torch.from_numpy(rgb_img_float).permute(2, 0, 1).unsqueeze(0).to(device)
172
 
173
+ model = YOLO('yolov8s.pt') # Ensure the model is loaded correctly
174
+ dff = DeepFeatureFactorization(model=model,
175
+ target_layer=model.model.model[int(target_lyr)],
176
+ computation_on_concepts=None)
177
 
178
  concepts, batch_explanations, explanations = dff(input_tensor, model, n_components)
179
 
180
+ # Getting predictions directly from YOLO
181
+ with torch.no_grad():
182
+ results = model(input_tensor)
183
+
184
+ # Post-processing to extract detections
185
+ boxes, scores, classes = results.xywh[0][:, :4], results.xywh[0][:, 4], results.xywh[0][:, 5]
186
+ boxes = boxes.cpu().numpy()
187
+ scores = scores.cpu().numpy()
188
+ classes = classes.cpu().numpy()
189
 
190
+ # Filter detections with confidence score > threshold (e.g., 0.5)
191
+ high_conf_boxes = boxes[scores > 0.5]
192
+ high_conf_classes = classes[scores > 0.5]
193
+
194
+ # Use the processed detections for visualization and further tasks
195
+ # Example visualization and output processing
196
+ fig, ax = plt.subplots(1, figsize=(8, 8))
197
+ ax.axis("off")
198
+ ax.imshow(rgb_img_float)
199
+
200
+ for box, cls in zip(high_conf_boxes, high_conf_classes):
201
+ x1, y1, x2, y2 = box
202
+ rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
203
+ linewidth=2, edgecolor='r', facecolor='none')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  ax.add_patch(rect)
205
+ ax.text(x1, y1, f"Class {cls}", color='r', fontsize=12, verticalalignment='top')
206
+
207
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
208
+ fig.canvas.draw()
209
+ image_array = np.array(fig.canvas.renderer.buffer_rgba())
210
+ image_resized = cv2.resize(image_array, (640, 640))
211
+ rgba_channels = cv2.split(image_resized)
212
+ alpha_channel = rgba_channels[3]
213
+ rgb_channels = np.stack(rgba_channels[:3], axis=-1)
214
 
215
+ visualization = show_factorization_on_image(rgb_img_float, np.transpose(rgb_channels, (2, 0, 1)), image_weight=0.3)
216
+
217
+ return rgb_img_float, batch_explanations, visualization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
 
 
219
 
220
 
221
  def visualize_batch_explanations(rgb_img_float, batch_explanations, image_weight=0.7):