# NeuralVista / app.py
# Last commit: c3d8605 "Fix: gradio interface" (author: BhumikaMak)
import numpy as np
import cv2
import os
from PIL import Image
import torchvision.transforms as transforms
import gradio as gr
from yolov5 import xai_yolov5
from yolov8 import xai_yolov8s
def process_image(image, yolo_versions=None):
    """Run the selected YOLO explainability pipelines on a single image.

    Args:
        image: The input image; anything ``np.array`` accepts (e.g. a PIL
            image). ``None`` (no upload and no sample) yields an empty list.
        yolo_versions: Model names to run, in order. Defaults to
            ``["yolov5"]``. "yolov5" and "yolov8s" dispatch to their XAI
            helpers; any other name produces a placeholder entry.

    Returns:
        A list with one entry per requested model: whatever the matching
        ``xai_*`` helper returns, or an ``(image, caption)`` tuple for models
        that are not implemented.
    """
    # Build a fresh default on every call instead of sharing one mutable list.
    if yolo_versions is None:
        yolo_versions = ["yolov5"]
    if image is None:
        # Nothing to analyse; an empty list renders as an empty gallery.
        return []
    image = np.array(image)
    # Fixed 640x640 input; note this does not preserve the aspect ratio.
    image = cv2.resize(image, (640, 640))
    result_images = []
    for yolo_version in yolo_versions:
        if yolo_version == "yolov5":
            result_images.append(xai_yolov5(image))
        elif yolo_version == "yolov8s":
            result_images.append(xai_yolov8s(image))
        else:
            # Placeholder gallery entry: the resized input plus a caption.
            result_images.append((Image.fromarray(image), f"{yolo_version} not yet implemented."))
    return result_images
# Display name -> path of a bundled sample image. Paths are resolved against the
# current working directory, so the app must be started from the project root.
sample_images = {
    "Sample 1": os.path.join(os.getcwd(), "data/xai/sample1.jpeg"),
    "Sample 2": os.path.join(os.getcwd(), "data/xai/sample2.jpg"),
}
def load_sample(sample_name):
    """Load one of the bundled sample images by its display name.

    Args:
        sample_name: A key of ``sample_images`` (e.g. "Sample 1"), or a
            falsy value when no sample is selected.

    Returns:
        The image, fully loaded and converted to RGB, or ``None`` when
        ``sample_name`` is empty or not a known sample.
    """
    if not sample_name:
        return None
    path = sample_images.get(sample_name)
    if path is None:
        return None
    # Image.open is lazy and keeps the file open; convert() reads the pixels
    # into a new image, so the file can be closed right away. RGB also
    # guarantees 3 channels for the downstream cv2/YOLO code regardless of
    # how the sample file is stored.
    with Image.open(path) as img:
        return img.convert("RGB")
def main_logic(uploaded_image, selected_models, sample_selection):
    """Gradio callback: pick the input image and run the selected models.

    Args:
        uploaded_image: PIL image from the upload widget, or None.
        selected_models: List of model names ticked in the checkbox group.
        sample_selection: Display name of the chosen sample image, or None.

    Returns:
        The gallery entries produced by ``process_image``.

    Raises:
        gr.Error: If neither a sample nor an uploaded image is available.
    """
    # A selected sample takes precedence over an uploaded image.
    image = load_sample(sample_selection) if sample_selection else uploaded_image
    if image is None:
        raise gr.Error("Please upload an image or select a sample image.")
    return process_image(image, selected_models)


# main_logic is defined above because gr.Interface needs the function object
# here; its three parameters map positionally onto the three input components.
interface = gr.Interface(
    fn=main_logic,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.CheckboxGroup(
            # Must contain the default value and match process_image's dispatch.
            choices=["yolov5", "yolov8s"],
            value=["yolov5"],  # YOLOv5 is checked by default
            label="Select Model(s)",
        ),
        gr.Dropdown(
            choices=list(sample_images.keys()),
            # Explicitly no preselected sample: some Gradio versions default a
            # Dropdown to its first choice, which would always override uploads.
            value=None,
            label="Select a Sample Image",
            type="value",
            interactive=True,
        ),
    ],
    outputs=gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
    title="Visualising the key image features that drive decisions with our explainable AI tool.",
    description="XAI: Upload an image or select a sample to visualize object detection of your models.",
)
interface.launch()