Spaces:
Running
Running
Commit
·
136ee37
1
Parent(s):
86915f5
Fix: Revert -12.
Browse files
app.py
CHANGED
@@ -10,7 +10,6 @@ from pytorch_grad_cam import EigenCAM
|
|
10 |
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
|
11 |
from PIL import Image
|
12 |
import gradio as gr
|
13 |
-
from ultralytics import YOLO
|
14 |
|
15 |
# Global Color Palette
|
16 |
COLORS = np.random.uniform(0, 255, size=(80, 3))
|
@@ -46,7 +45,7 @@ def load_yolo_model(version="yolov5"):
|
|
46 |
if version == "yolov5":
|
47 |
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
|
48 |
elif version == "yolov8":
|
49 |
-
model =
|
50 |
elif version == "yolov10":
|
51 |
model = torch.hub.load('ultralytics/yolov5', 'yolov5m', pretrained=True) # Placeholder for YOLOv10 (use an appropriate version if available)
|
52 |
else:
|
@@ -56,6 +55,7 @@ def load_yolo_model(version="yolov5"):
|
|
56 |
model.cpu()
|
57 |
return model
|
58 |
|
|
|
59 |
# Main function for Grad-CAM visualization
|
60 |
def process_image(image, yolo_versions=["yolov5"]):
|
61 |
image = np.array(image)
|
@@ -74,12 +74,7 @@ def process_image(image, yolo_versions=["yolov5"]):
|
|
74 |
for yolo_version in yolo_versions:
|
75 |
# Load the model based on YOLO version
|
76 |
model = load_yolo_model(yolo_version)
|
77 |
-
|
78 |
-
# YOLOv8: Extract last layer by model.model[-1] would not work; use the following:
|
79 |
-
if isinstance(model.model, torch.nn.Sequential):
|
80 |
-
target_layers = [model.model[-1]] # This assumes model layers are in a Sequential container
|
81 |
-
else:
|
82 |
-
target_layers = [model.model[-2]] # Use an appropriate layer
|
83 |
|
84 |
# Run YOLO detection
|
85 |
results = model([rgb_img])
|
@@ -105,21 +100,20 @@ def process_image(image, yolo_versions=["yolov5"]):
|
|
105 |
|
106 |
return result_images
|
107 |
|
108 |
-
# Gradio interface
|
109 |
interface = gr.Interface(
|
110 |
fn=process_image,
|
111 |
inputs=[
|
112 |
gr.Image(type="pil", label="Upload an Image"),
|
113 |
gr.CheckboxGroup(
|
114 |
-
choices=["yolov5",
|
115 |
value=["yolov5"], # Set the default value (YOLOv5 checked by default)
|
116 |
label="Select Model(s)",
|
117 |
)
|
118 |
],
|
119 |
-
outputs=gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
|
120 |
-
title="
|
121 |
description="XAI: Upload an image to visualize object detection of your models.."
|
122 |
)
|
123 |
|
124 |
if __name__ == "__main__":
|
125 |
-
interface.launch()
|
|
|
10 |
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
|
11 |
from PIL import Image
|
12 |
import gradio as gr
|
|
|
13 |
|
14 |
# Global Color Palette
|
15 |
COLORS = np.random.uniform(0, 255, size=(80, 3))
|
|
|
45 |
if version == "yolov5":
|
46 |
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
|
47 |
elif version == "yolov8":
|
48 |
+
model = torch.hub.load('ultralytics/yolov5:v7.0', 'yolov5', pretrained=True) # YOLOv8 is part of the yolov5 repo starting from v7.0
|
49 |
elif version == "yolov10":
|
50 |
model = torch.hub.load('ultralytics/yolov5', 'yolov5m', pretrained=True) # Placeholder for YOLOv10 (use an appropriate version if available)
|
51 |
else:
|
|
|
55 |
model.cpu()
|
56 |
return model
|
57 |
|
58 |
+
# Main function for Grad-CAM visualization
|
59 |
# Main function for Grad-CAM visualization
|
60 |
def process_image(image, yolo_versions=["yolov5"]):
|
61 |
image = np.array(image)
|
|
|
74 |
for yolo_version in yolo_versions:
|
75 |
# Load the model based on YOLO version
|
76 |
model = load_yolo_model(yolo_version)
|
77 |
+
target_layers = [model.model.model.model[-2]] # Assumes last layer is used for Grad-CAM
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
# Run YOLO detection
|
80 |
results = model([rgb_img])
|
|
|
100 |
|
101 |
return result_images
|
102 |
|
|
|
103 |
interface = gr.Interface(
|
104 |
fn=process_image,
|
105 |
inputs=[
|
106 |
gr.Image(type="pil", label="Upload an Image"),
|
107 |
gr.CheckboxGroup(
|
108 |
+
choices=["yolov5", "yolov8", "yolov10"],
|
109 |
value=["yolov5"], # Set the default value (YOLOv5 checked by default)
|
110 |
label="Select Model(s)",
|
111 |
)
|
112 |
],
|
113 |
+
outputs = gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
|
114 |
+
title="Visualising the key image features that drive decisions with our explainable AI tool.",
|
115 |
description="XAI: Upload an image to visualize object detection of your models.."
|
116 |
)
|
117 |
|
118 |
if __name__ == "__main__":
|
119 |
+
interface.launch()
|