BhumikaMak commited on
Commit
04535ac
·
verified ·
1 Parent(s): f72bf07

update: dff_nmf depreciation of 'scores'

Browse files
Files changed (1) hide show
  1. yolov8.py +7 -7
yolov8.py CHANGED
@@ -160,7 +160,6 @@ class DeepFeatureFactorization:
160
  f"An exception occurred in ActivationSummary with block: {exc_type}. Message: {exc_value}")
161
  return True
162
 
163
-
164
  def dff_nmf(image, target_lyr, n_components):
165
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
166
  mean = [0.485, 0.456, 0.406] # Mean for RGB channels
@@ -185,24 +184,26 @@ def dff_nmf(image, target_lyr, n_components):
185
 
186
  # Access detection results
187
  boxes = detections.boxes.xyxy.cpu().numpy() # Bounding box coordinates (xyxy)
188
- scores = detections.scores.cpu().numpy() # Confidence scores
189
  classes = detections.classes.cpu().numpy() # Class IDs
190
 
191
  # Filter detections with confidence score > threshold (e.g., 0.5)
192
- high_conf_boxes = boxes[scores > 0.5]
193
- high_conf_classes = classes[scores > 0.5]
 
 
194
 
195
  # Example visualization and output processing
196
  fig, ax = plt.subplots(1, figsize=(8, 8))
197
  ax.axis("off")
198
  ax.imshow(rgb_img_float)
199
 
200
- for box, cls in zip(high_conf_boxes, high_conf_classes):
201
  x1, y1, x2, y2 = box
202
  rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
203
  linewidth=2, edgecolor='r', facecolor='none')
204
  ax.add_patch(rect)
205
- ax.text(x1, y1, f"Class {cls}", color='r', fontsize=12, verticalalignment='top')
206
 
207
  plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
208
  fig.canvas.draw()
@@ -217,7 +218,6 @@ def dff_nmf(image, target_lyr, n_components):
217
  return rgb_img_float, batch_explanations, visualization
218
 
219
 
220
-
221
  def visualize_batch_explanations(rgb_img_float, batch_explanations, image_weight=0.7):
222
  for i, explanation in enumerate(batch_explanations):
223
  # Create visualization for each explanation
 
160
  f"An exception occurred in ActivationSummary with block: {exc_type}. Message: {exc_value}")
161
  return True
162
 
 
163
  def dff_nmf(image, target_lyr, n_components):
164
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
165
  mean = [0.485, 0.456, 0.406] # Mean for RGB channels
 
184
 
185
  # Access detection results
186
  boxes = detections.boxes.xyxy.cpu().numpy() # Bounding box coordinates (xyxy)
187
+ probs = detections.probs.cpu().numpy() # Confidence scores (probabilities)
188
  classes = detections.classes.cpu().numpy() # Class IDs
189
 
190
  # Filter detections with confidence score > threshold (e.g., 0.5)
191
+ high_conf_indices = probs > 0.5
192
+ high_conf_boxes = boxes[high_conf_indices]
193
+ high_conf_classes = classes[high_conf_indices]
194
+ high_conf_probs = probs[high_conf_indices]
195
 
196
  # Example visualization and output processing
197
  fig, ax = plt.subplots(1, figsize=(8, 8))
198
  ax.axis("off")
199
  ax.imshow(rgb_img_float)
200
 
201
+ for box, cls, prob in zip(high_conf_boxes, high_conf_classes, high_conf_probs):
202
  x1, y1, x2, y2 = box
203
  rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
204
  linewidth=2, edgecolor='r', facecolor='none')
205
  ax.add_patch(rect)
206
+ ax.text(x1, y1, f"Class {cls}, Prob {prob:.2f}", color='r', fontsize=12, verticalalignment='top')
207
 
208
  plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
209
  fig.canvas.draw()
 
218
  return rgb_img_float, batch_explanations, visualization
219
 
220
 
 
221
  def visualize_batch_explanations(rgb_img_float, batch_explanations, image_weight=0.7):
222
  for i, explanation in enumerate(batch_explanations):
223
  # Create visualization for each explanation