Spaces:
Sleeping
Sleeping
Commit
·
34678a5
1
Parent(s):
71b8b5d
Add: support for yolov8
Browse files
yolov8.py
CHANGED
@@ -35,23 +35,26 @@ def draw_detections(boxes, colors, names, img):
|
|
35 |
|
36 |
|
37 |
def generate_cam_image(model, target_layers, tensor, rgb_img, boxes):
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
grayscale_cam = cam(tensor)[0, :, :]
|
40 |
img_float = np.float32(rgb_img) / 255
|
41 |
-
|
42 |
-
# Generate Grad-CAM
|
43 |
cam_image = show_cam_on_image(img_float, grayscale_cam, use_rgb=True)
|
44 |
-
|
45 |
-
# Renormalize Grad-CAM inside bounding boxes
|
46 |
renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
|
47 |
for x1, y1, x2, y2 in boxes:
|
48 |
renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
|
49 |
renormalized_cam = scale_cam_image(renormalized_cam)
|
50 |
renormalized_cam_image = show_cam_on_image(img_float, renormalized_cam, use_rgb=True)
|
51 |
-
|
52 |
return cam_image, renormalized_cam_image
|
53 |
|
54 |
-
|
55 |
def xai_yolov8(image):
|
56 |
# Load YOLOv8 model
|
57 |
model = YOLO('yolov8n.pt') # Load YOLOv8 nano model
|
|
|
35 |
|
36 |
|
37 |
def generate_cam_image(model, target_layers, tensor, rgb_img, boxes):
|
38 |
+
class YOLOWrapper(torch.nn.Module):
|
39 |
+
def __init__(self, model):
|
40 |
+
super(YOLOWrapper, self).__init__()
|
41 |
+
self.model = model
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
return self.model.model.forward_once(x) # Ensure correct layer is called
|
45 |
+
|
46 |
+
wrapped_model = YOLOWrapper(model)
|
47 |
+
cam = EigenCAM(wrapped_model, target_layers)
|
48 |
grayscale_cam = cam(tensor)[0, :, :]
|
49 |
img_float = np.float32(rgb_img) / 255
|
|
|
|
|
50 |
cam_image = show_cam_on_image(img_float, grayscale_cam, use_rgb=True)
|
|
|
|
|
51 |
renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
|
52 |
for x1, y1, x2, y2 in boxes:
|
53 |
renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
|
54 |
renormalized_cam = scale_cam_image(renormalized_cam)
|
55 |
renormalized_cam_image = show_cam_on_image(img_float, renormalized_cam, use_rgb=True)
|
|
|
56 |
return cam_image, renormalized_cam_image
|
57 |
|
|
|
58 |
def xai_yolov8(image):
|
59 |
# Load YOLOv8 model
|
60 |
model = YOLO('yolov8n.pt') # Load YOLOv8 nano model
|