File size: 3,117 Bytes
dbd2a18
9b2c5e1
aed0d09
dbd2a18
 
d00769c
 
bcf0395
aed0d09
24f4b49
f7b8e0e
dbd2a18
c3d8605
dbd2a18
c21abf8
aca98af
90ff42e
 
 
aca98af
 
c21abf8
dbd2a18
aed0d09
fa09b4a
9d244f3
fa09b4a
9d244f3
dbd2a18
aed0d09
d00769c
 
d3127bb
dbd2a18
d3127bb
 
aed0d09
d00769c
 
d3127bb
aed0d09
dbd2a18
d3127bb
f504910
dbd2a18
 
bcf0395
 
 
 
 
 
 
dbd2a18
 
aed0d09
dbd2a18
aed0d09
9d244f3
aed0d09
7838123
dbd2a18
 
aed0d09
408a665
 
 
 
 
dbd2a18
b30ea65
dbd2a18
aed0d09
dbd2a18
 
d00769c
dbd2a18
 
 
 
 
 
aed0d09
dbd2a18
 
aed0d09
dbd2a18
 
 
e28f68c
dbd2a18
 
 
 
 
 
20ca536
dbd2a18
aed0d09
408a665
 
 
 
 
7991981
aed0d09
dbd2a18
408a665
dbd2a18
 
408a665
1a11002
aed0d09
dbd2a18
f4731f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
import os
from PIL import Image
import cv2
import numpy as np
from yolov5 import xai_yolov5
from yolov8 import xai_yolov8s

# Bundled sample images: display label -> path under the current working
# directory (resolved once, when this module is loaded).
sample_images = {
    label: os.path.join(os.getcwd(), relative_path)
    for label, relative_path in (
        ("Sample 1", "data/xai/sample1.jpeg"),
        ("Sample 2", "data/xai/sample2.jpg"),
    )
}

# Function to load sample image
def load_sample_image(sample_name):
    """Open the bundled sample registered under ``sample_name``.

    Returns the opened PIL image, or ``None`` when the name is not a key of
    ``sample_images`` or the mapped file does not exist on disk.
    """
    path = sample_images.get(sample_name)
    # Guard clause: unknown label (or empty path) and missing file both
    # yield None rather than raising.
    if not path or not os.path.exists(path):
        return None
    return Image.open(path)

# Function to process the image
def process_image(sample_choice, uploaded_image, yolo_versions):
    """Run the selected YOLO explainers on one image and collect the results.

    Args:
        sample_choice: Key into ``sample_images``; only consulted when no
            image was uploaded.
        uploaded_image: Image supplied by the user, or ``None``. Takes
            priority over ``sample_choice``.
        yolo_versions: Model names to run. ``"yolov5"`` and ``"yolov8s"``
            are dispatched to their explainers; ``None`` is treated as an
            empty selection.

    Returns:
        A list with one entry per requested model, in request order: the
        return value of ``xai_yolov5`` / ``xai_yolov8s`` for known names, or
        an ``(image, caption)`` tuple stating the model is not implemented.

    Raises:
        gr.Error: If no image was uploaded and the chosen sample could not
            be loaded (unknown name or missing file).
    """
    # Use uploaded or sample image
    if uploaded_image is not None:
        image = uploaded_image
    else:
        image = load_sample_image(sample_choice)

    # Without this check a missing sample would reach cv2.resize as
    # np.array(None) and fail with an opaque OpenCV error.
    if image is None:
        raise gr.Error(
            "No image available: upload an image or select a sample that exists."
        )

    # Resize and process the image
    image = cv2.resize(np.array(image), (640, 640))

    explainers = {
        "yolov5": xai_yolov5,
        "yolov8s": xai_yolov8s,
    }

    result_images = []
    for yolo_version in yolo_versions or []:
        explainer = explainers.get(yolo_version)
        if explainer is not None:
            result_images.append(explainer(image))
        else:
            # Unknown model: show the (resized) input with an explanatory caption.
            result_images.append(
                (Image.fromarray(image), f"{yolo_version} not implemented.")
            )

    return result_images

# Custom CSS for styling (optional)
# The rule must be an id selector ("#run_button") to match the Run button,
# which is created with elem_id="run_button"; a bare "run_button" would be a
# tag selector and match nothing.
custom_css = """
#run_button {
    background-color: purple;
    color: white;
    width: 120px;
    border-radius: 5px;
    font-size: 14px;
}
"""

# Gradio UI
# Layout: a top row with input controls (left) and a preview of the selected
# sample (right), followed by a row holding the results gallery. Components
# are created inside nested context managers, so statement order here is the
# on-screen order.
with gr.Blocks(css=custom_css) as interface:
    gr.Markdown("# XAI: Visualize Object Detection of Your Models")

    # Default sample: pre-selects the radio button and seeds the preview image.
    default_sample = "Sample 1"

    with gr.Row():
        # Left: Select sample or upload image
        with gr.Column():
            sample_selection = gr.Radio(
                choices=list(sample_images.keys()),
                label="Select a Sample Image",
                type="value",
                value=default_sample,
            )

            # An uploaded image takes priority over the selected sample in
            # process_image.
            upload_image = gr.Image(label="Upload an Image", type="pil")

            # These choice strings are the names process_image dispatches on.
            selected_models = gr.CheckboxGroup(
                choices=["yolov5", "yolov8s"],
                value=["yolov5"],
                label="Select Model(s)",
            )

            # elem_id is the hook the Run-button rule in custom_css is meant
            # to match.
            run_button = gr.Button("Run", elem_id="run_button")

        # Right: Display sample image
        # This preview follows the radio selection only; it is not updated
        # when an image is uploaded.
        with gr.Column():
            sample_display = gr.Image(
                value=load_sample_image(default_sample),
                label="Selected Sample Image",
            )

    # Results
    with gr.Row():
        result_gallery = gr.Gallery(
            label="Results",
            elem_id="gallery",
            rows=1,
            height=500,
        )

    # Sample selection update: reload the preview whenever the radio changes.
    sample_selection.change(
        fn=load_sample_image,
        inputs=sample_selection,
        outputs=sample_display,
    )

    # Process image: the inputs are passed positionally, matching
    # process_image(sample_choice, uploaded_image, yolo_versions); the
    # returned list populates the gallery.
    run_button.click(
        fn=process_image,
        inputs=[sample_selection, upload_image, selected_models],
        outputs=[result_gallery],
    )

# Launch Gradio app
# Only runs when executed as a script; share=True asks Gradio to also create
# a public share link (see the Gradio launch() documentation).
if __name__ == "__main__":
    interface.launch(share=True)