Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,70 +1,41 @@
|
|
1 |
import os
|
2 |
import gradio as gr
|
|
|
|
|
3 |
import numpy as np
|
4 |
-
import tensorflow as tf
|
5 |
import cv2
|
6 |
-
from mrcnn.config import Config
|
7 |
-
from mrcnn import model as modellib
|
8 |
|
9 |
-
#
|
10 |
-
os.
|
|
|
11 |
|
12 |
-
#
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
GPU_COUNT = 1
|
17 |
-
IMAGES_PER_GPU = 1
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
# Initialize the Mask R-CNN model
|
22 |
-
model = modellib.MaskRCNN(mode="inference", config=config, model_dir=os.getcwd())
|
23 |
-
|
24 |
-
# Load the Mask R-CNN model weights
|
25 |
-
model_path = os.path.join('toolkit', 'condmodel_100.h5')
|
26 |
-
model.load_weights(model_path, by_name=True)
|
27 |
-
print("Mask R-CNN model loaded successfully with weights.")
|
28 |
-
|
29 |
-
# Function to apply Mask R-CNN for image segmentation
|
30 |
-
def apply_mask_rcnn(image):
|
31 |
try:
|
32 |
-
#
|
33 |
-
|
34 |
-
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
35 |
-
|
36 |
-
# Resize the image to match the model input size (for inference)
|
37 |
-
resized_image = cv2.resize(image, (1024, 1024)) # Adjust based on model input requirements
|
38 |
-
input_image = np.expand_dims(resized_image, axis=0)
|
39 |
-
|
40 |
-
# Use Mask R-CNN to predict
|
41 |
-
result = model.detect(input_image, verbose=0)
|
42 |
-
r = result[0]
|
43 |
-
|
44 |
-
# Create a mask for the detected objects
|
45 |
-
mask = r['masks']
|
46 |
-
mask = np.sum(mask, axis=-1) # Combine masks for all objects
|
47 |
-
|
48 |
-
# Resize mask back to the original image size
|
49 |
-
mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
|
50 |
-
|
51 |
-
# Create a segmentation overlay on the original image
|
52 |
-
mask_overlay = np.zeros_like(image)
|
53 |
-
mask_overlay[mask > 0.5] = [0, 255, 0] # Green mask
|
54 |
-
|
55 |
-
# Combine the original image with the mask
|
56 |
-
segmented_image = cv2.addWeighted(image, 1, mask_overlay, 0.5, 0)
|
57 |
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
except Exception as e:
|
61 |
-
print(f"Error in segmentation: {e}")
|
62 |
-
return image # Return original image if segmentation fails
|
63 |
|
64 |
-
# Gradio interface
|
65 |
with gr.Blocks() as demo:
|
66 |
-
gr.Markdown("<h1 style='text-align: center;'>Image Segmentation with
|
67 |
-
gr.Markdown("Upload an image to see segmentation results using the
|
68 |
|
69 |
# Input and output layout
|
70 |
with gr.Row():
|
@@ -79,8 +50,8 @@ with gr.Blocks() as demo:
|
|
79 |
output_image = gr.Image(type="numpy", label="Segmented Image")
|
80 |
|
81 |
# Set up button functionality
|
82 |
-
submit_btn.click(fn=
|
83 |
-
clear_btn.click(fn=lambda: None)
|
84 |
|
85 |
# Launch the Gradio app
|
86 |
demo.launch()
|
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
+
from ultralytics import YOLO
|
4 |
+
from PIL import Image
|
5 |
import numpy as np
|
|
|
6 |
import cv2
|
|
|
|
|
7 |
|
8 |
+
# Define the directory paths for the model and assets based on the current script location
|
9 |
+
script_dir = os.path.dirname(os.path.abspath(__file__)) # Current script directory
|
10 |
+
yolo_weights_path = os.path.join(script_dir, 'toolkit', 'ALL_best.pt') # Path to YOLO weights
|
11 |
|
12 |
+
# Load the YOLO model
|
13 |
+
model = YOLO(yolo_weights_path)
|
14 |
+
model.fuse() # Optional for optimization
|
15 |
+
print("YOLO model loaded successfully with weights.")
|
|
|
|
|
16 |
|
17 |
+
# Function to perform detection and segmentation with YOLO
|
18 |
+
def apply_yolo_segmentation(image):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
try:
|
20 |
+
# Run YOLO on the image and get the results
|
21 |
+
results = model.predict(source=image, save=False) # Use save=False to keep results in memory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
# Retrieve the annotated image from YOLO results
|
24 |
+
result_image = results[0].plot() # `plot()` returns the image with annotations
|
25 |
+
|
26 |
+
# Convert to RGB for Gradio output
|
27 |
+
result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
|
28 |
+
|
29 |
+
return result_image_rgb
|
30 |
|
31 |
except Exception as e:
|
32 |
+
print(f"Error in YOLO segmentation: {e}")
|
33 |
+
return image # Return the original image if segmentation fails
|
34 |
|
35 |
+
# Define the Gradio interface
|
36 |
with gr.Blocks() as demo:
|
37 |
+
gr.Markdown("<h1 style='text-align: center;'>Image Segmentation with YOLO</h1>")
|
38 |
+
gr.Markdown("Upload an image to see segmentation results using the YOLO model.")
|
39 |
|
40 |
# Input and output layout
|
41 |
with gr.Row():
|
|
|
50 |
output_image = gr.Image(type="numpy", label="Segmented Image")
|
51 |
|
52 |
# Set up button functionality
|
53 |
+
submit_btn.click(fn=apply_yolo_segmentation, inputs=input_image, outputs=output_image)
|
54 |
+
clear_btn.click(fn=lambda: None, outputs=output_image)
|
55 |
|
56 |
# Launch the Gradio app
|
57 |
demo.launch()
|