Spaces:
Sleeping
Sleeping
Commit
·
c8e3f60
1
Parent(s):
ef20834
Fix: Revert -5.
Browse files
app.py
CHANGED
@@ -45,9 +45,13 @@ def load_yolo_model(version="yolov5"):
|
|
45 |
if version == "yolov3":
|
46 |
model = torch.hub.load('ultralytics/yolov3', 'yolov3', pretrained=True)
|
47 |
elif version == "yolov5":
|
48 |
-
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
|
|
|
|
|
49 |
elif version == "yolov8":
|
50 |
-
model = torch.hub.load('ultralytics/yolov5', '
|
|
|
|
|
51 |
else:
|
52 |
raise ValueError(f"Unsupported YOLO version: {version}")
|
53 |
|
@@ -55,6 +59,7 @@ def load_yolo_model(version="yolov5"):
|
|
55 |
model.cpu()
|
56 |
return model
|
57 |
|
|
|
58 |
# Main function for Grad-CAM visualization
|
59 |
def process_image(image, yolo_versions=["yolov5"]):
|
60 |
image = np.array(image)
|
@@ -73,12 +78,10 @@ def process_image(image, yolo_versions=["yolov5"]):
|
|
73 |
for yolo_version in yolo_versions:
|
74 |
# Load the model based on YOLO version
|
75 |
model = load_yolo_model(yolo_version)
|
76 |
-
|
77 |
-
|
78 |
-
target_layers = [model.model.model[-2]] # Assuming last layer of the model is a Conv layer
|
79 |
-
|
80 |
# Run YOLO detection
|
81 |
-
results = model(rgb_img)
|
82 |
boxes, colors, names = parse_detections(results)
|
83 |
detections_img = draw_detections(boxes, colors, names, rgb_img.copy())
|
84 |
|
@@ -101,21 +104,20 @@ def process_image(image, yolo_versions=["yolov5"]):
|
|
101 |
|
102 |
return result_images
|
103 |
|
104 |
-
# Define the Gradio interface
|
105 |
interface = gr.Interface(
|
106 |
fn=process_image,
|
107 |
inputs=[
|
108 |
gr.Image(type="pil", label="Upload an Image"),
|
109 |
gr.CheckboxGroup(
|
110 |
-
choices=["yolov3", "yolov5", "yolov8"],
|
111 |
value=["yolov5"], # Set the default value (YOLOv5 checked by default)
|
112 |
label="Select Model(s)",
|
113 |
)
|
114 |
],
|
115 |
outputs = gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
|
116 |
title="Visualising the key image features that drive decisions with our explainable AI tool.",
|
117 |
-
description="XAI: Upload an image to visualize object detection of your models
|
118 |
)
|
119 |
|
120 |
if __name__ == "__main__":
|
121 |
-
interface.launch()
|
|
|
45 |
if version == "yolov3":
|
46 |
model = torch.hub.load('ultralytics/yolov3', 'yolov3', pretrained=True)
|
47 |
elif version == "yolov5":
|
48 |
+
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
|
49 |
+
elif version == "yolov7":
|
50 |
+
model = torch.hub.load('WongKinYiu/yolov7', 'yolov7', pretrained=True)
|
51 |
elif version == "yolov8":
|
52 |
+
model = torch.hub.load('ultralytics/yolov5:v7.0', 'yolov5', pretrained=True) # YOLOv8 is part of the yolov5 repo starting from v7.0
|
53 |
+
elif version == "yolov10":
|
54 |
+
model = torch.hub.load('ultralytics/yolov5', 'yolov5m', pretrained=True) # Placeholder for YOLOv10 (use an appropriate version if available)
|
55 |
else:
|
56 |
raise ValueError(f"Unsupported YOLO version: {version}")
|
57 |
|
|
|
59 |
model.cpu()
|
60 |
return model
|
61 |
|
62 |
+
# Main function for Grad-CAM visualization
|
63 |
# Main function for Grad-CAM visualization
|
64 |
def process_image(image, yolo_versions=["yolov5"]):
|
65 |
image = np.array(image)
|
|
|
78 |
for yolo_version in yolo_versions:
|
79 |
# Load the model based on YOLO version
|
80 |
model = load_yolo_model(yolo_version)
|
81 |
+
target_layers = [model.model.model.model[-2]] # Assumes last layer is used for Grad-CAM
|
82 |
+
|
|
|
|
|
83 |
# Run YOLO detection
|
84 |
+
results = model([rgb_img])
|
85 |
boxes, colors, names = parse_detections(results)
|
86 |
detections_img = draw_detections(boxes, colors, names, rgb_img.copy())
|
87 |
|
|
|
104 |
|
105 |
return result_images
|
106 |
|
|
|
107 |
interface = gr.Interface(
|
108 |
fn=process_image,
|
109 |
inputs=[
|
110 |
gr.Image(type="pil", label="Upload an Image"),
|
111 |
gr.CheckboxGroup(
|
112 |
+
choices=["yolov3", "yolov5", "yolov7", "yolov8", "yolov10"],
|
113 |
value=["yolov5"], # Set the default value (YOLOv5 checked by default)
|
114 |
label="Select Model(s)",
|
115 |
)
|
116 |
],
|
117 |
outputs = gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
|
118 |
title="Visualising the key image features that drive decisions with our explainable AI tool.",
|
119 |
+
description="XAI: Upload an image to visualize object detection of your models.."
|
120 |
)
|
121 |
|
122 |
if __name__ == "__main__":
|
123 |
+
interface.launch()
|