NeuralVista / app.py
BhumikaMak's picture
Fix: argument mismatch
e3b086b
raw
history blame
1.78 kB
import numpy as np
import cv2
import os
from PIL import Image
import torchvision.transforms as transforms
import gradio as gr
from yolov5 import xai_yolov5
from yolov8 import xai_yolov8s
def process_image(image, yolo_versions=None):
    """Run the selected YOLO explainability pipelines on one image.

    Args:
        image: The image coming from the ``gr.Image(type="pil")`` input, or
            ``None`` when that input is empty or has been cleared.
        yolo_versions: Model identifiers to run, in order. ``None`` falls
            back to ``["yolov5"]``. An empty list yields no results.

    Returns:
        A list with one entry per requested model, in request order, for the
        ``gr.Gallery`` output. For ``"yolov5"`` / ``"yolov8s"`` the entry is
        whatever ``xai_yolov5`` / ``xai_yolov8s`` returns for the resized
        image (presumably a gallery-compatible item; not visible here). For
        any other identifier the entry is an ``(image, caption)`` tuple with
        the resized input and a "not yet implemented" caption. An empty list
        is returned when no image is provided.
    """
    # None sentinel instead of a list literal, so calls never share one
    # mutable default object.
    if yolo_versions is None:
        yolo_versions = ["yolov5"]

    # With live=True the interface re-runs on every input change, including
    # when the image is cleared; np.array(None) would otherwise reach
    # cv2.resize and raise.
    if image is None:
        return []

    # Convert image from PIL to NumPy array, then resize to 640x640.
    image = np.array(image)
    image = cv2.resize(image, (640, 640))

    handlers = {"yolov5": xai_yolov5, "yolov8s": xai_yolov8s}
    result_images = []
    for yolo_version in yolo_versions:
        handler = handlers.get(yolo_version)
        if handler is not None:
            result_images.append(handler(image))
        else:
            result_images.append(
                (Image.fromarray(image), f"{yolo_version} not yet implemented.")
            )
    return result_images
# Directory holding the sample images offered as clickable examples. It is
# resolved from this file's location rather than the process's working
# directory, so the examples still load when the app is started from another
# directory (assumes data/xai lives alongside this script - the original
# cwd-relative paths imply the repo root).
_EXAMPLE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "xai")

# Gradio UI: an image upload plus a model multi-select feed process_image,
# whose list of results is shown in a gallery. live=True re-runs the function
# whenever an input changes.
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(
            type="pil",  # deliver the upload to fn as a PIL image
            label="Upload an Image",
            interactive=True,
            value=None,  # start with an empty image input
        ),
        gr.CheckboxGroup(
            choices=["yolov5", "yolov8s"],
            value=["yolov5"],  # YOLOv5 preselected
            label="Select Model(s)",
            # pass the selected choice strings (not their indices) to fn
            type="value",
        ),
    ],
    outputs=gr.Gallery(label="Results", elem_id="gallery", rows=2, height=500),
    title="Explainable AI for YOLO Models",
    description="Upload an image or select a sample to visualize YOLO object detection with Grad-CAM.",
    # Each example row only fills the image input; the model selection keeps
    # its current value.
    examples=[
        [os.path.join(_EXAMPLE_DIR, "sample1.jpeg")],
        [os.path.join(_EXAMPLE_DIR, "sample2.jpg")],
    ],
    live=True,
)
# No post-construction override of the image input is needed: it is already
# created with value=None. (Calling a component's `.update(value=None)` here
# could not change it - in Gradio 3 that only returns an update dict, and
# Gradio 4 removed the method entirely.)
if __name__ == "__main__":
    # Start the Gradio server for the interface defined above.
    interface.launch()