BhumikaMak commited on
Commit
34678a5
·
1 Parent(s): 71b8b5d

Add: support for yolov8

Browse files
Files changed (1) hide show
  1. yolov8.py +10 -7
yolov8.py CHANGED
@@ -35,23 +35,26 @@ def draw_detections(boxes, colors, names, img):
35
 
36
 
37
  def generate_cam_image(model, target_layers, tensor, rgb_img, boxes):
38
- cam = EigenCAM(model, target_layers)
 
 
 
 
 
 
 
 
 
39
  grayscale_cam = cam(tensor)[0, :, :]
40
  img_float = np.float32(rgb_img) / 255
41
-
42
- # Generate Grad-CAM
43
  cam_image = show_cam_on_image(img_float, grayscale_cam, use_rgb=True)
44
-
45
- # Renormalize Grad-CAM inside bounding boxes
46
  renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
47
  for x1, y1, x2, y2 in boxes:
48
  renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
49
  renormalized_cam = scale_cam_image(renormalized_cam)
50
  renormalized_cam_image = show_cam_on_image(img_float, renormalized_cam, use_rgb=True)
51
-
52
  return cam_image, renormalized_cam_image
53
 
54
-
55
  def xai_yolov8(image):
56
  # Load YOLOv8 model
57
  model = YOLO('yolov8n.pt') # Load YOLOv8 nano model
 
35
 
36
 
37
  def generate_cam_image(model, target_layers, tensor, rgb_img, boxes):
38
+ class YOLOWrapper(torch.nn.Module):
39
+ def __init__(self, model):
40
+ super(YOLOWrapper, self).__init__()
41
+ self.model = model
42
+
43
+ def forward(self, x):
44
+ return self.model.model.forward_once(x) # Ensure correct layer is called
45
+
46
+ wrapped_model = YOLOWrapper(model)
47
+ cam = EigenCAM(wrapped_model, target_layers)
48
  grayscale_cam = cam(tensor)[0, :, :]
49
  img_float = np.float32(rgb_img) / 255
 
 
50
  cam_image = show_cam_on_image(img_float, grayscale_cam, use_rgb=True)
 
 
51
  renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
52
  for x1, y1, x2, y2 in boxes:
53
  renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
54
  renormalized_cam = scale_cam_image(renormalized_cam)
55
  renormalized_cam_image = show_cam_on_image(img_float, renormalized_cam, use_rgb=True)
 
56
  return cam_image, renormalized_cam_image
57
 
 
58
  def xai_yolov8(image):
59
  # Load YOLOv8 model
60
  model = YOLO('yolov8n.pt') # Load YOLOv8 nano model