# Hugging Face Space (status: Sleeping)
import os | |
import gradio as gr | |
from ultralytics import YOLO | |
from PIL import Image | |
import numpy as np | |
import cv2 | |
# Locate the weights relative to this file (not the current working directory),
# so the app finds them regardless of where it is launched from.
script_dir = os.path.dirname(os.path.abspath(__file__))
yolo_weights_path = os.path.join(script_dir, "toolkit", "ALL_best.pt")

# Build the YOLO model once at import time; every request below reuses it.
model = YOLO(yolo_weights_path)
model.fuse()  # optional layer fusion to speed up inference
print("YOLO model loaded successfully with weights.")
def apply_yolo_segmentation(image):
    """Run the YOLO model on an uploaded image and return the annotated result.

    Args:
        image: The value of the ``gr.Image(type="numpy")`` input. It is expected to
            be an RGB ``numpy.ndarray`` (Gradio's default image mode), or ``None``
            when the user pressed Submit without uploading anything.

    Returns:
        An RGB ``numpy.ndarray`` with YOLO's detections/masks drawn on it, ``None``
        when no image was supplied, or the unmodified input image if inference
        fails (best-effort fallback so the UI never shows an error).
    """
    # Without this guard, predict(source=None) does not fail: Ultralytics falls
    # back to its bundled sample images and the user sees unrelated results.
    if image is None:
        return None
    try:
        # Gradio hands us RGB, but Ultralytics treats numpy sources as OpenCV-style
        # BGR. Convert first so the model sees the right colours and plot() draws
        # on a BGR canvas.
        if image.ndim == 2:
            bgr_image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        else:
            bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        results = model.predict(source=bgr_image, save=False)  # keep results in memory
        annotated_bgr = results[0].plot()  # annotated image, BGR channel order
        # Back to RGB for display in Gradio.
        return cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
    except Exception as e:
        # Top-level UI boundary: report the failure and show the original upload.
        print(f"Error in YOLO segmentation: {e}")
        return image
# Build the Gradio UI: upload and buttons on the left, annotated result on the right.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Image Segmentation with YOLO</h1>")
    gr.Markdown("Upload an image to see segmentation results using the YOLO model.")

    # Input and output layout
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Upload an Image")
            # type="numpy" makes Gradio pass the handler a numpy array.
            input_image = gr.Image(type="numpy", label="Upload an image")
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            gr.Markdown("### Segmented Image Output")
            output_image = gr.Image(type="numpy", label="Segmented Image")

    # Set up button functionality
    submit_btn.click(fn=apply_yolo_segmentation, inputs=input_image, outputs=output_image)
    # Clear resets both the upload and the result, so no stale input image is left
    # behind to be reprocessed by the next Submit.
    clear_btn.click(fn=lambda: (None, None), outputs=[input_image, output_image])
# Launch the Gradio app only when this file is run as a script (e.g. `python app.py`);
# importing the module just builds `demo` without starting a server.
if __name__ == "__main__":
    demo.launch()