Spaces:
Sleeping
Sleeping
Commit
·
e8f8ad5
1
Parent(s):
589c79e
Fix: gradio interface
Browse files
app.py
CHANGED
@@ -8,7 +8,6 @@ from yolov5 import xai_yolov5
|
|
8 |
from yolov8 import xai_yolov8s
|
9 |
|
10 |
def process_image(image, yolo_versions=["yolov5"]):
|
11 |
-
# Convert image from PIL to NumPy array
|
12 |
image = np.array(image)
|
13 |
image = cv2.resize(image, (640, 640))
|
14 |
|
@@ -28,36 +27,40 @@ sample_images = [
|
|
28 |
os.path.join(os.getcwd(), "data/xai/sample2.jpg")
|
29 |
]
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
8 |
from yolov8 import xai_yolov8s
|
9 |
|
10 |
def process_image(image, yolo_versions=["yolov5"]):
|
|
|
11 |
image = np.array(image)
|
12 |
image = cv2.resize(image, (640, 640))
|
13 |
|
|
|
27 |
os.path.join(os.getcwd(), "data/xai/sample2.jpg")
|
28 |
]
|
29 |
|
30 |
+
def load_sample(sample_name):
|
31 |
+
if sample_name and sample_name in sample_images:
|
32 |
+
return Image.open(sample_images[sample_name])
|
33 |
+
return None
|
34 |
+
|
35 |
+
interface = gr.Interface(
|
36 |
+
fn=process_image,
|
37 |
+
inputs=[
|
38 |
+
gr.Image(type="pil", label="Upload an Image"),
|
39 |
+
gr.CheckboxGroup(
|
40 |
+
choices=["yolov3", "yolov8s"],
|
41 |
+
value=["yolov5"], # Set the default value (YOLOv5 checked by default)
|
42 |
+
label="Select Model(s)",
|
43 |
+
),
|
44 |
+
gr.Dropdown(
|
45 |
+
choices=list(sample_images.keys()),
|
46 |
+
label="Select a Sample Image",
|
47 |
+
type="value",
|
48 |
+
interactive=True,
|
49 |
+
),
|
50 |
+
],
|
51 |
+
outputs=gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
|
52 |
+
title="Visualising the key image features that drive decisions with our explainable AI tool.",
|
53 |
+
description="XAI: Upload an image or select a sample to visualize object detection of your models.",
|
54 |
+
)
|
55 |
+
|
56 |
+
def main_logic(uploaded_image, selected_models, sample_selection):
|
57 |
+
# If the user selects a sample image, use that instead of the uploaded one
|
58 |
+
if sample_selection:
|
59 |
+
image = load_sample(sample_selection)
|
60 |
+
else:
|
61 |
+
image = uploaded_image
|
62 |
+
|
63 |
+
# Call the processing function
|
64 |
+
return process_image(image, selected_models)
|
65 |
+
|
66 |
+
interface.launch()
|