BhumikaMak commited on
Commit
b63af6d
·
1 Parent(s): fd2244b

Update: support for yolo-explainer

Browse files
Files changed (1) hide show
  1. app.py +37 -8
app.py CHANGED
@@ -10,6 +10,7 @@ from pytorch_grad_cam import EigenCAM
10
  from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
11
  from PIL import Image
12
  import gradio as gr
 
13
 
14
  # Global Color Palette
15
  COLORS = np.random.uniform(0, 255, size=(80, 3))
@@ -42,8 +43,16 @@ def draw_detections(boxes, colors, names, img):
42
 
43
  # Load the appropriate YOLO model based on the version
44
  def load_yolo_model(version="yolov5"):
45
- if version == "yolov5":
 
 
46
  model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
 
 
 
 
 
 
47
  else:
48
  raise ValueError(f"Unsupported YOLO version: {version}")
49
 
@@ -51,7 +60,8 @@ def load_yolo_model(version="yolov5"):
51
  model.cpu()
52
  return model
53
 
54
- def process_image(image, yolo_versions=["yolov5"]):
 
55
  image = np.array(image)
56
  image = cv2.resize(image, (640, 640))
57
  rgb_img = image.copy()
@@ -66,6 +76,23 @@ def process_image(image, yolo_versions=["yolov5"]):
66
 
67
  # Process each selected YOLO model
68
  for yolo_version in yolo_versions:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  # Load the model based on YOLO version
70
  model = load_yolo_model(yolo_version)
71
  target_layers = [model.model.model.model[-2]] # Assumes last layer is used for Grad-CAM
@@ -94,20 +121,22 @@ def process_image(image, yolo_versions=["yolov5"]):
94
 
95
  return result_images
96
 
 
97
  interface = gr.Interface(
98
  fn=process_image,
99
  inputs=[
100
  gr.Image(type="pil", label="Upload an Image"),
101
  gr.CheckboxGroup(
102
- choices=["yolov5"],
103
- value=["yolov5"], # Set the default value (YOLOv5 checked by default)
104
  label="Select Model(s)",
105
- )
 
106
  ],
107
- outputs = gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
108
  title="Visualising the key image features that drive decisions with our explainable AI tool.",
109
- description="XAI: Upload an image to visualize object detection of your models.."
110
  )
111
 
112
  if __name__ == "__main__":
113
- interface.launch()
 
10
  from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
11
  from PIL import Image
12
  import gradio as gr
13
+ from YOLOv8_Explainer import yolov8_heatmap, display_images # Import Explainer
14
 
15
  # Global Color Palette
16
  COLORS = np.random.uniform(0, 255, size=(80, 3))
 
43
 
44
  # Load the appropriate YOLO model based on the version
45
  def load_yolo_model(version="yolov5"):
46
+ if version == "yolov3":
47
+ model = torch.hub.load('ultralytics/yolov3', 'yolov3', pretrained=True)
48
+ elif version == "yolov5":
49
  model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
50
+ elif version == "yolov7":
51
+ model = torch.hub.load('WongKinYiu/yolov7', 'yolov7', pretrained=True)
52
+ elif version == "yolov8":
53
+ model = torch.hub.load('ultralytics/yolov5:v7.0', 'yolov5', pretrained=True)
54
+ elif version == "yolov10":
55
+ model = torch.hub.load('ultralytics/yolov5', 'yolov5m', pretrained=True)
56
  else:
57
  raise ValueError(f"Unsupported YOLO version: {version}")
58
 
 
60
  model.cpu()
61
  return model
62
 
63
+ # Main function for Grad-CAM visualization
64
+ def process_image(image, yolo_versions=["yolov5"], use_explainer=False):
65
  image = np.array(image)
66
  image = cv2.resize(image, (640, 640))
67
  rgb_img = image.copy()
 
76
 
77
  # Process each selected YOLO model
78
  for yolo_version in yolo_versions:
79
+ if use_explainer and yolo_version == "yolov8":
80
+ # Use YOLOv8 Explainer for EigenCAM heatmap
81
+ explainer_model = yolov8_heatmap(
82
+ weight="yolov8n.pt",
83
+ conf_threshold=0.4,
84
+ device="cpu",
85
+ method="EigenCAM",
86
+ layer=[10, 12, 14, 16, 18, -3],
87
+ backward_type="all",
88
+ ratio=0.02,
89
+ show_box=True,
90
+ renormalize=False,
91
+ )
92
+ imagelist = explainer_model(img_path=image)
93
+ display_images(imagelist)
94
+ continue # Skip Grad-CAM for this case
95
+
96
  # Load the model based on YOLO version
97
  model = load_yolo_model(yolo_version)
98
  target_layers = [model.model.model.model[-2]] # Assumes last layer is used for Grad-CAM
 
121
 
122
  return result_images
123
 
124
+ # Gradio Interface
125
  interface = gr.Interface(
126
  fn=process_image,
127
  inputs=[
128
  gr.Image(type="pil", label="Upload an Image"),
129
  gr.CheckboxGroup(
130
+ choices=["yolov3", "yolov5", "yolov7", "yolov8", "yolov10"],
131
+ value=["yolov5"], # Default to YOLOv5
132
  label="Select Model(s)",
133
+ ),
134
+ gr.Checkbox(label="Use YOLOv8 Explainer?", value=False)
135
  ],
136
+ outputs=gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
137
  title="Visualising the key image features that drive decisions with our explainable AI tool.",
138
+ description="XAI: Upload an image to visualize object detection of your models."
139
  )
140
 
141
  if __name__ == "__main__":
142
+ interface.launch()