BhumikaMak commited on
Commit
4ddc91d
·
1 Parent(s): a33d7fc

Update: add support for other yolo variants

Browse files
Files changed (2) hide show
  1. app.py +56 -26
  2. requirements.txt +11 -7
app.py CHANGED
@@ -40,8 +40,27 @@ def draw_detections(boxes, colors, names, img):
40
  lineType=cv2.LINE_AA)
41
  return img
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # Main function for Grad-CAM visualization
44
- def process_image(image):
45
  image = np.array(image)
46
  image = cv2.resize(image, (640, 640))
47
  rgb_img = image.copy()
@@ -51,41 +70,52 @@ def process_image(image):
51
  transform = transforms.ToTensor()
52
  tensor = transform(img_float).unsqueeze(0)
53
 
54
- # Load YOLOv5 model
55
- model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
56
- model.eval()
57
- model.cpu()
58
- target_layers = [model.model.model.model[-2]]
 
 
 
59
 
60
- # Run YOLO detection
61
- results = model([rgb_img])
62
- boxes, colors, names = parse_detections(results)
63
- detections_img = draw_detections(boxes, colors, names, rgb_img.copy())
64
 
65
- # Grad-CAM visualization
66
- cam = EigenCAM(model, target_layers)
67
- grayscale_cam = cam(tensor)[0, :, :]
68
- cam_image = show_cam_on_image(img_float, grayscale_cam, use_rgb=True)
69
 
70
- # Renormalize Grad-CAM inside bounding boxes
71
- renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
72
- for x1, y1, x2, y2 in boxes:
73
- renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
74
- renormalized_cam = scale_cam_image(renormalized_cam)
75
- renormalized_cam_image = show_cam_on_image(img_float, renormalized_cam, use_rgb=True)
76
 
77
- # Concatenate images
78
- final_image = np.hstack((rgb_img, cam_image, renormalized_cam_image))
 
79
 
80
- return Image.fromarray(final_image)
81
 
82
  # Gradio Interface
83
  interface = gr.Interface(
84
  fn=process_image,
85
- inputs=gr.Image(type="pil", label="Upload an Image"),
86
- outputs=gr.Image(type="pil", label="Result"),
 
 
 
 
 
 
 
87
  title="Visualising the key image features that drive decisions with our explainable AI tool.",
88
- description="XAI: Upload an image to visualize object detection of your models."
89
  )
90
 
91
  if __name__ == "__main__":
 
40
  lineType=cv2.LINE_AA)
41
  return img
42
 
43
+ # Load the appropriate YOLO model based on the version
44
+ def load_yolo_model(version="yolov5"):
45
+ if version == "yolov3":
46
+ model = torch.hub.load('ultralytics/yolov3', 'yolov3', pretrained=True)
47
+ elif version == "yolov5":
48
+ model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
49
+ elif version == "yolov7":
50
+ model = torch.hub.load('WongKinYiu/yolov7', 'yolov7', pretrained=True)
51
+ elif version == "yolov8":
52
+ model = torch.hub.load('ultralytics/yolov5:v7.0', 'yolov5', pretrained=True) # YOLOv8 is part of the yolov5 repo starting from v7.0
53
+ elif version == "yolov10":
54
+ model = torch.hub.load('ultralytics/yolov5', 'yolov5m', pretrained=True) # Placeholder for YOLOv10 (use an appropriate version if available)
55
+ else:
56
+ raise ValueError(f"Unsupported YOLO version: {version}")
57
+
58
+ model.eval() # Set to evaluation mode
59
+ model.cpu()
60
+ return model
61
+
62
  # Main function for Grad-CAM visualization
63
+ def process_image(image, yolo_versions=["yolov5"]):
64
  image = np.array(image)
65
  image = cv2.resize(image, (640, 640))
66
  rgb_img = image.copy()
 
70
  transform = transforms.ToTensor()
71
  tensor = transform(img_float).unsqueeze(0)
72
 
73
+ # Initialize list to store result images
74
+ result_images = []
75
+
76
+ # Process each selected YOLO model
77
+ for yolo_version in yolo_versions:
78
+ # Load the model based on YOLO version
79
+ model = load_yolo_model(yolo_version)
80
+ target_layers = [model.model.model.model[-2]] # Assumes last layer is used for Grad-CAM
81
 
82
+ # Run YOLO detection
83
+ results = model([rgb_img])
84
+ boxes, colors, names = parse_detections(results)
85
+ detections_img = draw_detections(boxes, colors, names, rgb_img.copy())
86
 
87
+ # Grad-CAM visualization
88
+ cam = EigenCAM(model, target_layers)
89
+ grayscale_cam = cam(tensor)[0, :, :]
90
+ cam_image = show_cam_on_image(img_float, grayscale_cam, use_rgb=True)
91
 
92
+ # Renormalize Grad-CAM inside bounding boxes
93
+ renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
94
+ for x1, y1, x2, y2 in boxes:
95
+ renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
96
+ renormalized_cam = scale_cam_image(renormalized_cam)
97
+ renormalized_cam_image = show_cam_on_image(img_float, renormalized_cam, use_rgb=True)
98
 
99
+ # Concatenate images
100
+ final_image = np.hstack((rgb_img, cam_image, renormalized_cam_image))
101
+ result_images.append((yolo_version, Image.fromarray(final_image)))
102
 
103
+ return result_images
104
 
105
  # Gradio Interface
106
  interface = gr.Interface(
107
  fn=process_image,
108
+ inputs=[
109
+ gr.Image(type="pil", label="Upload an Image"),
110
+ gr.CheckboxGroup(
111
+ choices=["yolov3", "yolov5", "yolov7", "yolov8", "yolov10"],
112
+ label="Select YOLO Models",
113
+ default=["yolov5"]
114
+ )
115
+ ],
116
+ outputs=gr.Gallery(label="Results", elem_id="gallery").style(grid=[2], height=500),
117
  title="Visualising the key image features that drive decisions with our explainable AI tool.",
118
+ description="XAI: Upload an image to visualize object detection of your models.."
119
  )
120
 
121
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,8 +1,12 @@
1
- torch
2
- torchvision
3
- torchaudio
4
- numpy
5
- pillow
6
- opencv-python
7
  git+https://github.com/jacobgil/pytorch-grad-cam.git
8
- gradio
 
 
 
 
 
1
+ torch==2.1.0
2
+ torchvision==0.15.0
3
+ torchaudio==2.1.0
4
+ numpy==1.23.4
5
+ pillow==9.3.0
6
+ opencv-python==4.6.0.66
7
  git+https://github.com/jacobgil/pytorch-grad-cam.git
8
+ gradio==3.28.2
9
+ git+https://github.com/ultralytics/yolov5.git # For YOLOv5
10
+ git+https://github.com/WongKinYiu/yolov7.git # For YOLOv7
11
+ git+https://github.com/ultralytics/ultralytics.git # For YOLOv8
12
+ git+https://github.com/saeedanwar/yolov10.git # For YOLOv10