File size: 4,132 Bytes
dbd2a18
 
9b2c5e1
dbd2a18
 
57419d8
dbd2a18
 
 
57419d8
dbd2a18
24f4b49
f7b8e0e
dbd2a18
c3d8605
dbd2a18
 
 
 
aca98af
dbd2a18
90ff42e
 
 
aca98af
 
dbd2a18
 
fa09b4a
 
 
dbd2a18
 
d3127bb
 
 
dbd2a18
d3127bb
 
 
 
 
 
 
dbd2a18
d3127bb
f504910
dbd2a18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e7c847
dbd2a18
7838123
dbd2a18
 
408a665
 
 
 
 
 
dbd2a18
b30ea65
dbd2a18
408a665
 
dbd2a18
408a665
dbd2a18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20ca536
dbd2a18
 
 
408a665
 
 
 
 
7991981
dbd2a18
408a665
dbd2a18
 
408a665
1a11002
dbd2a18
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import gradio as gr
import netron
import os
import threading
import time
from PIL import Image
import cv2
import numpy as np
import torch

# Directory the app was launched from; the asset paths below are resolved against it.
_launch_dir = os.getcwd()

# Display name -> path of each bundled sample image offered in the UI.
sample_images = {
    "Sample 1": os.path.join(_launch_dir, "data/xai/sample1.jpeg"),
    "Sample 2": os.path.join(_launch_dir, "data/xai/sample2.jpg"),
}

# Model file rendered in the Netron panel (update this path as needed).
preloaded_model_file = os.path.join(_launch_dir, "weight_files/yolov5.onnx")

def load_sample_image(sample_name):
    """Open the sample image registered under *sample_name*.

    Args:
        sample_name: A key of ``sample_images`` (e.g. the radio selection).

    Returns:
        A PIL image, or None if the name is unknown or its file does not exist.
    """
    path = sample_images.get(sample_name)
    if not path or not os.path.exists(path):
        return None
    return Image.open(path)

def process_image(sample_choice, uploaded_image, yolo_versions):
    """Run each selected YOLO model on the chosen image.

    Args:
        sample_choice: Key into ``sample_images``; used only when nothing was uploaded.
        uploaded_image: Image from the upload widget (PIL), or None.
        yolo_versions: Model names selected in the checkbox group.

    Returns:
        A list with one gallery entry per name in *yolo_versions*, in order.
        Unimplemented models yield ``(PIL image, "<name> not yet implemented.")``.

    Raises:
        gr.Error: If neither an upload nor a loadable sample image is available.
    """
    # An explicit upload takes precedence over the sample radio selection.
    if uploaded_image is not None:
        image = uploaded_image
    else:
        image = load_sample_image(sample_choice)

    if image is None:
        # Without this guard np.array(None) reaches cv2.resize and fails opaquely.
        raise gr.Error("No image available: upload one or pick an existing sample.")

    if isinstance(image, Image.Image):
        # Normalise RGBA / greyscale / palette inputs to 3-channel RGB.
        image = image.convert("RGB")

    image = np.array(image)
    image = cv2.resize(image, (640, 640))

    result_images = []
    for yolo_version in yolo_versions or []:
        # NOTE(review): xai_yolov5 / xai_yolov8s are neither defined nor imported
        # in this module as shown; confirm where they are expected to come from.
        if yolo_version == "yolov5":
            result_images.append(xai_yolov5(image))
        elif yolo_version == "yolov8s":
            result_images.append(xai_yolov8s(image))
        else:
            result_images.append((Image.fromarray(image), f"{yolo_version} not yet implemented."))

    return result_images

def serve_netron(model_file):
    """Start a Netron server for *model_file* on a background daemon thread.

    Args:
        model_file: Path of the model file Netron should display.

    Returns:
        The URL the embedding iframe should load. This is the hard-coded
        Netron default ``http://localhost:8080``; it is not read back from Netron.
    """
    # browse=False: the model is embedded via an iframe, so don't also open
    # a browser tab on the machine hosting the app (netron.start defaults to True).
    threading.Thread(
        target=netron.start,
        args=(model_file,),
        kwargs={"browse": False},
        daemon=True,
    ).start()
    time.sleep(1)  # Give the server a moment to bind before the iframe loads.
    # NOTE(review): assumes Netron bound port 8080; if that port is busy it may
    # have chosen another one, and this URL would then be wrong.
    return "http://localhost:8080"  # Default Netron URL


def view_model():
    """Build the HTML for the Netron panel from ``preloaded_model_file``.

    Returns:
        An ``<iframe>`` pointing at the Netron server when the model file
        exists, otherwise the plain message "Model file not found.".
    """
    if os.path.exists(preloaded_model_file):
        url = serve_netron(preloaded_model_file)
        return f'<iframe src="{url}" width="100%" height="600px"></iframe>'
    return "Model file not found."

# Optional stylesheet passed to gr.Blocks(css=...). The #run_button rule styles
# the button created with elem_id="run_button" (purple fill, white text, fixed width).
custom_css = """
#run_button {
    background-color: purple;
    color: white;
    width: 120px;
    border-radius: 5px;
    font-size: 14px;
}
"""

with gr.Blocks(css=custom_css) as interface:
    gr.Markdown("# XAI: Visualize Object Detection of Your Models")

    default_sample = "Sample 1"

    with gr.Row():
        # Left column: image source (sample or upload), model choice, run button.
        with gr.Column():
            sample_selection = gr.Radio(
                choices=list(sample_images.keys()),
                label="Select a Sample Image",
                type="value",
                value=default_sample,
            )

            # An uploaded image overrides the sample selection in process_image.
            upload_image = gr.Image(
                label="Upload an Image",
                type="pil",
            )

            selected_models = gr.CheckboxGroup(
                choices=["yolov5", "yolov8s"],
                value=["yolov5"],
                label="Select Model(s)",
            )

            run_button = gr.Button("Run", elem_id="run_button")

        # Right column: preview of the currently selected sample image.
        with gr.Column():
            sample_display = gr.Image(
                value=load_sample_image(default_sample),
                label="Selected Sample Image",
            )

    # Below the inputs: detection results and model architecture side by side.
    with gr.Row():
        result_gallery = gr.Gallery(
            label="Results",
            elem_id="gallery",
            rows=1,
            height=500,
        )

        # The Netron iframe is built once at startup and supplied through the
        # component's initial value instead of mutating .value after creation,
        # which depends on Gradio's config internals and skips postprocessing.
        netron_display = gr.HTML(
            value=view_model(),
            label="Netron Visualization",
        )

    # Keep the preview in sync with the radio selection.
    sample_selection.change(
        fn=load_sample_image,
        inputs=sample_selection,
        outputs=sample_display,
    )

    run_button.click(
        fn=process_image,
        inputs=[sample_selection, upload_image, selected_models],
        outputs=[result_gallery],
    )

# Script entry point: launch the Gradio app. The Netron panel was already
# populated while the Blocks layout above was built.
# NOTE(review): share=True requests Gradio's public share link, but the Netron
# iframe targets http://localhost:8080, which remote viewers resolve against
# their own machine — confirm the architecture panel is meant to work only locally.
if __name__ == "__main__":
    interface.launch(share=True)