Spaces:
Sleeping
Sleeping
import gradio as gr | |
import netron | |
import os | |
import threading | |
import time | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
from yolov5 import xai_yolov5 | |
from yolov8 import xai_yolov8s | |
import requests | |
""" | |
# Sample images directory | |
sample_images = { | |
"Sample 1": os.path.join(os.getcwd(), "data/xai/sample1.jpeg"), | |
"Sample 2": os.path.join(os.getcwd(), "data/xai/sample2.jpg"), | |
} | |
# Preloaded model file path | |
preloaded_model_file = os.path.join(os.getcwd(), "weight_files/yolov5.onnx") # Example path | |
# Look up a bundled sample by its display name and open it as a PIL image.
def load_sample_image(sample_name):
    # sample_name is a key of sample_images; unknown names and paths that do
    # not exist on disk both yield None instead of raising.
    path = sample_images.get(sample_name)
    if not path or not os.path.exists(path):
        return None
    return Image.open(path)
# Run the selected explainers on the uploaded image, or on the chosen sample.
def process_image(sample_choice, uploaded_image, yolo_versions):
    # Parameters:
    #   sample_choice  -- key into sample_images; used only when nothing was uploaded
    #   uploaded_image -- image supplied by the user, or None
    #   yolo_versions  -- iterable of model names to run
    # Returns a list with one entry per requested model, in request order: the
    # value produced by the matching xai_* function, or an (image, caption)
    # pair for a model name that has no explainer.
    # Raises gr.Error when neither the upload nor the sample yields an image.

    # An upload always takes precedence over the sample selection.
    if uploaded_image is not None:
        image = uploaded_image
    else:
        image = load_sample_image(sample_choice)

    # Without this guard a missing sample reaches np.array/cv2.resize as None
    # and fails there with an error that says nothing about the real cause.
    if image is None:
        raise gr.Error("No image available: upload an image or select a valid sample.")

    # Every explainer receives the same 640x640 array.
    image = cv2.resize(np.array(image), (640, 640))

    explainers = {
        "yolov5": xai_yolov5,
        "yolov8s": xai_yolov8s,
    }

    result_images = []
    for yolo_version in yolo_versions:
        explain = explainers.get(yolo_version)
        if explain is not None:
            result_images.append(explain(image))
        else:
            result_images.append((Image.fromarray(image), f"{yolo_version} not implemented."))
    return result_images
# Start a Netron server for model_file on port 8080 and wait until it answers.
def start_netron_backend(model_file):
    # Returns True once http://localhost:8080/ responds with HTTP 200, or
    # False if that has not happened within 15 seconds.

    def serve_netron():
        netron.start(model_file, address=("0.0.0.0", 8080), browse=False)

    # Daemon thread: the server must not keep the process alive on shutdown.
    threading.Thread(target=serve_netron, daemon=True).start()

    def wait_for_netron(url, timeout=10):
        # Poll url until it returns 200 or `timeout` seconds have elapsed.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                # Per-request timeout so one stalled connection cannot
                # outlive the overall deadline.
                if requests.get(url, timeout=1).status_code == 200:
                    return True
            except requests.RequestException:
                # Refused or timed-out connections are expected while the
                # server is still starting; retry until the deadline.
                pass
            # Sleep after every attempt, including non-200 answers, so the
            # loop never spins at full speed.
            time.sleep(0.5)
        return False

    # Report readiness to the caller instead of discarding it.
    return wait_for_netron("http://localhost:8080/", timeout=15)
# Build the HTML that embeds the Netron viewer for the preloaded model.
def view_netron_model():
    # Always returns a plain string: either an error message or iframe markup.
    # The caller passes the result to gr.HTML(...), so returning a gr.HTML
    # component here would nest one component inside another and make the
    # two return paths disagree on type.
    if not os.path.exists(preloaded_model_file):
        return "Model file not found."

    start_netron_backend(preloaded_model_file)

    # NOTE(review): the iframe targets localhost, which the visitor's browser
    # resolves on its own machine -- confirm this is reachable when the app
    # is served remotely.
    return '<iframe src="http://localhost:8080/" width="100%" height="600px"></iframe>'
# Custom CSS for styling (optional) | |
custom_css = """ | |
#run_button { | |
# background-color: purple; | |
# color: white; | |
# width: 120px; | |
# border-radius: 5px; | |
# font-size: 14px; | |
#} | |
""" | |
# Gradio UI | |
with gr.Blocks(css=custom_css) as interface: | |
gr.Markdown("# XAI: Visualize Object Detection of Your Models") | |
# Default sample | |
default_sample = "Sample 1" | |
with gr.Row(): | |
# Left: Select sample or upload image | |
with gr.Column(): | |
sample_selection = gr.Radio( | |
choices=list(sample_images.keys()), | |
label="Select a Sample Image", | |
type="value", | |
value=default_sample, | |
) | |
upload_image = gr.Image(label="Upload an Image", type="pil") | |
selected_models = gr.CheckboxGroup( | |
choices=["yolov5", "yolov8s"], | |
value=["yolov5"], | |
label="Select Model(s)", | |
) | |
run_button = gr.Button("Run", elem_id="run_button") | |
# Right: Display sample image | |
with gr.Column(): | |
sample_display = gr.Image( | |
value=load_sample_image(default_sample), | |
label="Selected Sample Image", | |
) | |
# Results and Netron | |
with gr.Row(): | |
result_gallery = gr.Gallery( | |
label="Results", | |
elem_id="gallery", | |
rows=1, | |
height=500, | |
) | |
# Display Netron iframe | |
netron_display = gr.HTML(view_netron_model()) | |
# Sample selection update | |
sample_selection.change( | |
fn=load_sample_image, | |
inputs=sample_selection, | |
outputs=sample_display, | |
) | |
# Process image | |
run_button.click( | |
fn=process_image, | |
inputs=[sample_selection, upload_image, selected_models], | |
outputs=[result_gallery], | |
) | |
# Launch Gradio app | |
if __name__ == "__main__": | |
interface.launch(share=True) | |
""" | |
import gradio as gr | |
import torch | |
from ultralytics import YOLO | |
import os | |
# Location of the YOLOv5 weights. Resolved against the current working
# directory, so the app must be started from the directory that contains
# `weight_files/`.
model_path = os.path.join(os.getcwd(), "weight_files/yolov5.onnx")


def visualize_yolov5():
    """Launch Netron on the YOLOv5 model and report the outcome as text.

    The model at ``model_path`` is loaded through ``YOLO``. If the wrapped
    model exposes ``state_dict()``, its weights are saved to
    ``temp_model/pytorch_model.bin`` and Netron is pointed at that file;
    otherwise Netron is pointed at ``model_path`` itself.

    Returns:
        str: A status line for the Gradio text output. Failures are reported
        as ``"Error visualizing model: ..."`` (and printed) rather than
        raised, so the UI always receives a message.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import list.
    import subprocess

    if not os.path.isfile(model_path):
        message = f"Error visualizing model: file not found: {model_path}"
        print(message)
        return message

    try:
        # Load the YOLOv5 model
        model = YOLO(model_path)
        # Extract the underlying model object
        pytorch_model = model.model

        # NOTE(review): for exported formats such as ONNX, `model.model` is
        # presumably not a torch module -- confirm against the installed
        # ultralytics version. Fall back to the original file in that case.
        state_dict = getattr(pytorch_model, "state_dict", None)
        if callable(state_dict):
            temp_model_path = "temp_model"
            # torch.save does not create missing parent directories.
            os.makedirs(temp_model_path, exist_ok=True)
            netron_target = os.path.join(temp_model_path, "pytorch_model.bin")
            torch.save(state_dict(), netron_target)
        else:
            netron_target = model_path

        # Start Netron without a shell and without waiting for it to exit:
        # a blocking call would stall this handler for as long as the viewer
        # keeps running. Each call spawns a new Netron process.
        subprocess.Popen(["netron", netron_target])
    except Exception as e:
        # Top-level boundary for the UI callback: surface the failure as text.
        message = f"Error visualizing model: {e}"
        print(message)
        return message

    return f"Netron launched for {netron_target}"
# Assemble the Gradio front end: one input-less action whose return value is
# shown as text.
_interface_settings = {
    "inputs": None,  # No input required
    "outputs": "text",
    "title": "Netron YOLOv5 Model Visualization",
    "description": "Visualize the YOLOv5 model.",
}
iface = gr.Interface(fn=visualize_yolov5, **_interface_settings)

# Serve the app; share=True additionally asks Gradio for a public share link.
iface.launch(share=True)