# Silence library warnings so the Gradio console output stays readable
import warnings
warnings.filterwarnings('ignore')

import torch
import cv2
import numpy as np
import torchvision.transforms as transforms
from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
from PIL import Image
import gradio as gr

# Global color palette: one random color per COCO class (YOLOv5s predicts 80)
COLORS = np.random.uniform(0, 255, size=(80, 3))

# Load YOLOv5 once at import time so every Gradio request reuses the same model
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
model.eval()
model.cpu()
target_layers = [model.model.model.model[-2]]
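
# Note on the target layer (a choice carried over from the pytorch-grad-cam
# YOLOv5 tutorial): model.model.model.model[-2] is the block just before the
# Detect head. EigenCAM takes the first principal component of that layer's
# activations, so it needs no gradients and works with the hub model's
# post-processed output.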

# Parse YOLOv5 results into boxes, per-class colors, and labels
def parse_detections(results):
    detections = results.pandas().xyxy[0].to_dict()
    boxes, colors, names = [], [], []
    for i in range(len(detections["xmin"])):
        confidence = detections["confidence"][i]
        if confidence < 0.2:  # skip low-confidence detections
            continue
        xmin, ymin = int(detections["xmin"][i]), int(detections["ymin"][i])
        xmax, ymax = int(detections["xmax"][i]), int(detections["ymax"][i])
        name, category = detections["name"][i], int(detections["class"][i])
        boxes.append((xmin, ymin, xmax, ymax))
        colors.append(COLORS[category])
        names.append(name)
    return boxes, colors, names
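
# Worked example (hypothetical values): a single confident "dog" detection
# would come back as
#   boxes  -> [(50, 40, 200, 180)]
#   colors -> [COLORS[16]]   # 16 is the COCO class index for "dog"
#   names  -> ["dog"]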

# Draw bounding boxes and class labels on the image
def draw_detections(boxes, colors, names, img):
    for box, color, name in zip(boxes, colors, names):
        xmin, ymin, xmax, ymax = box
        color = tuple(int(c) for c in color)  # cv2 wants plain ints, not numpy floats
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 2)
        cv2.putText(img, name, (xmin, ymin - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2,
                    lineType=cv2.LINE_AA)
    return img

# Run YOLOv5 detection and build the EigenCAM visualizations for one image
def process_image(image):
    image = np.array(image.convert("RGB"))  # force 3 channels; PNG uploads may be RGBA
    image = cv2.resize(image, (640, 640))
    rgb_img = image.copy()
    img_float = np.float32(image) / 255
    
    # Image transformation
    transform = transforms.ToTensor()
    tensor = transform(img_float).unsqueeze(0)
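    # ToTensor permutes the HWC float image to a CHW tensor (no rescaling for
    # float input); unsqueeze(0) adds the batch dimension the CAM expects.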

    # Run YOLO detection
    results = model([rgb_img])
    boxes, colors, names = parse_detections(results)
    detections_img = draw_detections(boxes, colors, names, rgb_img.copy())

    # Grad-CAM visualization
    cam = EigenCAM(model, target_layers)
    grayscale_cam = cam(tensor)[0, :, :]
    cam_image = show_cam_on_image(img_float, grayscale_cam, use_rgb=True)
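    # Shapes at this point, for orientation (assuming default EigenCAM settings):
    #   tensor        -> (1, 3, 640, 640) float batch
    #   grayscale_cam -> (640, 640) saliency map in [0, 1]
    #   cam_image     -> (640, 640, 3) uint8 heatmap overlay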

    # Renormalize Grad-CAM inside bounding boxes
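    # (Rescaling each box to its own [0, 1] range keeps activations on small or
    # faint objects visible instead of washing them out against the image-wide
    # maximum.)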
    renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32)
    for x1, y1, x2, y2 in boxes:
        renormalized_cam[y1:y2, x1:x2] = scale_cam_image(grayscale_cam[y1:y2, x1:x2].copy())
    renormalized_cam = scale_cam_image(renormalized_cam)
    renormalized_cam_image = show_cam_on_image(img_float, renormalized_cam, use_rgb=True)

    # Concatenate detections, CAM overlay, and renormalized CAM side by side
    final_image = np.hstack((detections_img, cam_image, renormalized_cam_image))

    return Image.fromarray(final_image)

# Gradio Interface
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Image(type="pil", label="Result"),
    title="Visualising the key image features that drive decisions with our explainable AI tool.",
    description="XAI: Upload an image to visualize object detection of your models."
)

if __name__ == "__main__":
    interface.launch()
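
# Headless usage sketch (assumes a local image file "sample.jpg"; hypothetical,
# not part of the app):
#
#   from PIL import Image
#   out = process_image(Image.open("sample.jpg"))
#   out.save("sample_cam.jpg")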