File size: 5,487 Bytes
60af537
 
 
 
 
 
 
 
 
 
 
 
82b219b
53bfd89
60af537
 
 
 
c629158
c7e11d7
c629158
 
4795cf5
 
 
 
 
 
 
 
 
 
c629158
4795cf5
 
 
 
 
 
 
 
 
c629158
60af537
 
 
 
 
 
 
 
 
 
 
 
4ddc91d
 
d2d1a78
c8e3f60
4ddc91d
82b219b
c8e3f60
 
4ddc91d
 
 
 
 
 
 
136ee37
60af537
4ddc91d
60af537
 
 
 
 
 
 
 
 
f8576f8
4ddc91d
 
 
 
 
 
2cd21c1
 
 
ad21b0c
4ddc91d
c8e3f60
c629158
4ddc91d
60af537
4ddc91d
 
 
 
60af537
4ddc91d
 
 
 
 
 
60af537
f8576f8
4ddc91d
f8576f8
 
60af537
4ddc91d
60af537
 
 
4ddc91d
 
 
136ee37
0c25380
9acd672
4ddc91d
 
136ee37
 
c8e3f60
60af537
 
 
136ee37
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

import torch
import cv2
import numpy as np
import torchvision.transforms as transforms
from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
from PIL import Image
import gradio as gr
from ultralytics import YOLO

# Module-wide colour palette: an (80, 3) array of floats drawn uniformly from
# [0, 255) once at import time, indexed by class id when drawing detections
# (presumably one row per COCO class). Not seeded, so colours vary per run.
COLORS = np.random.uniform(low=0.0, high=255.0, size=(80, 3))

# Function to parse YOLO detections
def parse_detections(results, yolo_version, conf_threshold=0.2):
    """Turn raw model output into parallel lists of boxes, colours and labels.

    Args:
        results: Output of calling a model from ``load_yolo_model`` on a list
            containing one image. torch.hub YOLOv5 results expose
            ``.pandas()``; ultralytics results are indexable and their first
            element carries ``.boxes`` and ``.names``.
        yolo_version: Model identifier. ``"yolov5"``, or any result object
            that has a ``pandas`` method (e.g. the torch.hub model used as the
            "yolov10" placeholder), is parsed as torch.hub YOLOv5 output;
            everything else is parsed as ultralytics output.
        conf_threshold: Detections with confidence below this are dropped.

    Returns:
        ``(boxes, colors, names)`` with one entry per kept detection: each box
        is an ``(xmin, ymin, xmax, ymax)`` tuple of ints, each colour a row of
        ``COLORS`` and each name the class label.
    """
    boxes, colors, names = [], [], []

    def add(xyxy, class_id, name):
        # Integer pixel coordinates are required by OpenCV drawing and by the
        # array slicing done on the CAM in process_image.
        boxes.append(tuple(int(v) for v in xyxy))
        # Modulo keeps class ids beyond the palette size from raising IndexError.
        colors.append(COLORS[class_id % len(COLORS)])
        names.append(name)

    if yolo_version == "yolov5" or hasattr(results, "pandas"):
        for row in results.pandas().xyxy[0].to_dict(orient="records"):
            if row["confidence"] < conf_threshold:
                continue
            add(
                (row["xmin"], row["ymin"], row["xmax"], row["ymax"]),
                int(row["class"]),
                row["name"],
            )
    else:
        result = results[0]
        class_names = result.names  # class id -> label mapping from the model
        det = result.boxes
        # One box per detection, instead of appending the whole xyxy tensor
        # (and the whole names mapping) as a single item.
        for xyxy, conf, cls in zip(det.xyxy.tolist(), det.conf.tolist(), det.cls.tolist()):
            if conf < conf_threshold:
                continue
            class_id = int(cls)
            add(xyxy, class_id, class_names[class_id])

    return boxes, colors, names

# Draw bounding boxes and labels
def draw_detections(boxes, colors, names, img):
    """Draw each box and its label onto *img* in place and return it.

    Args:
        boxes: ``(xmin, ymin, xmax, ymax)`` pixel coordinates, one per detection.
        colors: One 3-element colour per detection (rows of ``COLORS``).
        names: One label per detection.
        img: Image array, modified in place by the OpenCV drawing calls.

    Returns:
        The same ``img`` object with rectangles and labels drawn on it.
    """
    for box, color, name in zip(boxes, colors, names):
        xmin, ymin, xmax, ymax = (int(v) for v in box)
        # OpenCV expects a plain tuple of Python numbers for the colour.
        color_tuple = tuple(float(c) for c in color)
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color_tuple, 2)
        # For boxes touching the top edge, ymin - 5 would put the label above
        # the canvas; clamp the baseline so the text stays visible.
        text_y = max(ymin - 5, 20)
        cv2.putText(img, str(name), (xmin, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color_tuple, 2,
                    lineType=cv2.LINE_AA)
    return img

# Load the appropriate YOLO model based on the version
def load_yolo_model(version="yolov5"):
    """Load a pretrained detector for *version*, in eval mode on the CPU.

    Args:
        version: ``"yolov5"`` (torch.hub ``yolov5s``), ``"yolov8"``
            (ultralytics ``yolov8n.pt``) or ``"yolov10"``. ``"yolov10"`` is
            currently a placeholder that loads torch.hub ``yolov5m``.

    Returns:
        The loaded model. ``"yolov5"`` and ``"yolov10"`` return the torch.hub
        YOLOv5 model; ``"yolov8"`` returns an ``ultralytics.YOLO`` instance.

    Raises:
        ValueError: If *version* is not one of the values above.
    """
    if version == "yolov5":
        model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
    elif version == "yolov8":
        # YOLOv8 ships in the separate `ultralytics` package, not in the
        # ultralytics/yolov5 torch.hub repository.
        model = YOLO("yolov8n.pt")
    elif version == "yolov10":
        # NOTE(review): placeholder -- this is YOLOv5m, not a YOLOv10 model.
        # Swap in a real YOLOv10 checkpoint once one is chosen.
        model = torch.hub.load('ultralytics/yolov5', 'yolov5m', pretrained=True)
    else:
        raise ValueError(f"Unsupported YOLO version: {version}")

    model.eval()  # inference mode for layers such as BatchNorm/Dropout
    model.cpu()
    return model

def _renormalize_cam_in_boxes(grayscale_cam, boxes):
    """Rescale *grayscale_cam* separately inside each box and zero it elsewhere."""
    height, width = grayscale_cam.shape
    renormalized = np.zeros(grayscale_cam.shape, dtype=np.float32)
    for x1, y1, x2, y2 in boxes:
        # Clamp to the CAM so negative coordinates cannot wrap around via
        # Python slicing, and skip boxes that are empty after clamping
        # (scale_cam_image fails on an empty slice).
        x1, x2 = max(0, x1), min(width, x2)
        y1, y2 = max(0, y1), min(height, y2)
        if x2 <= x1 or y2 <= y1:
            continue
        renormalized[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
    return scale_cam_image(renormalized)


# Main function for Grad-CAM visualization
def process_image(image, yolo_versions=None):
    """Run each selected YOLO model on *image* and build EigenCAM visualizations.

    Args:
        image: Input image. ``gr.Image(type="pil")`` passes a PIL image; a
            numpy array also works. ``None`` (nothing uploaded) yields an
            empty list.
        yolo_versions: Model identifiers understood by ``load_yolo_model``.
            Defaults to ``["yolov5"]``. Each model is (re)loaded on every call.

    Returns:
        A list of ``(PIL.Image, caption)`` tuples, one per model. Each image
        has three panels side by side at 640x640: detections drawn on the
        input, the EigenCAM heat map, and the heat map renormalized inside the
        detected boxes.
    """
    if image is None:
        return []
    if yolo_versions is None:
        yolo_versions = ["yolov5"]

    # Uploaded PNGs may be RGBA or greyscale; every panel below assumes RGB.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    rgb_img = cv2.resize(np.array(image), (640, 640))
    img_float = np.float32(rgb_img) / 255
    tensor = transforms.ToTensor()(img_float).unsqueeze(0)

    result_images = []
    for yolo_version in yolo_versions:
        model = load_yolo_model(yolo_version)

        if isinstance(model, YOLO):
            # The final ultralytics module is the Detect head, whose eval-mode
            # output is a tuple that the CAM activation hook cannot store; hook
            # the feature layer just before it, as the torch.hub branch does.
            target_layers = [model.model.model[-2]]
            # ultralytics interprets numpy inputs as BGR (OpenCV order).
            results = model([cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR)])
            parser_version = yolo_version
        else:
            # torch.hub YOLOv5 model (also used by the "yolov10" placeholder),
            # which takes RGB numpy input.
            target_layers = [model.model.model.model[-2]]
            results = model([rgb_img])
            parser_version = "yolov5"

        boxes, colors, names = parse_detections(results, parser_version)
        detections_img = draw_detections(boxes, colors, names, rgb_img.copy())

        # EigenCAM visualization over the whole image.
        cam = EigenCAM(model, target_layers)
        grayscale_cam = cam(tensor)[0, :, :]
        cam_image = show_cam_on_image(img_float, grayscale_cam, use_rgb=True)

        # Same heat map, renormalized so each detected box uses its full range.
        renormalized_cam = _renormalize_cam_in_boxes(grayscale_cam, boxes)
        renormalized_cam_image = show_cam_on_image(img_float, renormalized_cam, use_rgb=True)

        # First panel shows the drawn detections (previously computed but unused).
        final_image = np.hstack((detections_img, cam_image, renormalized_cam_image))
        result_images.append((Image.fromarray(final_image), f"Results using {yolo_version}"))

    return result_images

# Gradio UI: one image input plus a model checklist, results shown in a gallery.
# NOTE: the "yolov10" choice is served by a placeholder model (see load_yolo_model).
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.CheckboxGroup(
            choices=["yolov5",  "yolov8", "yolov10"],
            value=["yolov5"],  # YOLOv5 is checked by default
            label="Select Model(s)",
        ),
    ],
    outputs=gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
    title="Visualising the key image features that drive decisions with our explainable AI tool.",
    description="XAI: Upload an image to visualize object detection of your models.",
)

if __name__ == "__main__":
    interface.launch()