Spaces:
Sleeping
Sleeping
import os

# oneDNN custom operations can produce slightly different numerical results
# because of floating-point round-off, so they are switched off here.
# The flag must be in the environment before TensorFlow is loaded, which is
# why this assignment sits ahead of the remaining imports (including mrcnn,
# which pulls in TensorFlow itself).
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
import cv2  # noqa: E402
from mrcnn.config import Config  # noqa: E402
from mrcnn import model as modellib  # noqa: E402
# Mask R-CNN settings used when running inference
class InferenceConfig(Config):
    """Configuration overrides for running Mask R-CNN in inference mode."""

    # One GPU, one image per GPU
    IMAGES_PER_GPU = 1
    GPU_COUNT = 1

    NAME = "mask_rcnn"

    # Presumably background + 80 object classes -- update according to
    # your dataset (must agree with the weights that get loaded).
    NUM_CLASSES = 1 + 80


config = InferenceConfig()
# Weights file, resolved relative to the current working directory
model_path = os.path.join('toolkit', 'condmodel_100.h5')

# Build the network in inference mode, then fill it from the weights file
# (layers are matched by name)
model = modellib.MaskRCNN(
    mode="inference",
    model_dir=os.getcwd(),
    config=config,
)
model.load_weights(model_path, by_name=True)
print("Mask R-CNN model loaded successfully with weights.")
# Function to apply Mask R-CNN for image segmentation
def apply_mask_rcnn(image):
    """Tint every pixel covered by a Mask R-CNN detection green.

    Args:
        image: Image as a NumPy array, either 2-D grayscale or 3-D with
            1, 3 (RGB) or 4 (RGBA) channels. Presumably uint8, as delivered
            by the Gradio image component -- confirm if called from
            elsewhere. ``None`` (nothing uploaded) is returned as-is.

    Returns:
        A 3-channel image with the same height and width as the input, with
        detected regions blended with green. If segmentation fails for any
        reason the error is printed and the (RGB-converted, where that step
        succeeded) input is returned without an overlay.
    """
    if image is None:
        # Submit was pressed without an upload; nothing to segment.
        return None

    try:
        # Normalise to 3-channel RGB so the overlay colour below always fits.
        if image.ndim == 2 or image.shape[2] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        # Resize the image to match the model input size (for inference)
        resized_image = cv2.resize(image, (1024, 1024))  # Adjust based on model input requirements

        # detect() takes a batch (list) of images; this batch holds one.
        detections = model.detect([resized_image], verbose=0)[0]

        # Collapse the per-instance masks (instances along the last axis) into
        # one binary mask. uint8 is used because cv2.resize rejects the
        # integer array that np.sum over boolean masks would produce.
        combined = np.any(detections['masks'], axis=-1).astype(np.uint8)

        # Bring the mask back to the original image size; nearest-neighbour
        # keeps it strictly 0/1.
        height, width = image.shape[:2]
        combined = cv2.resize(combined, (width, height), interpolation=cv2.INTER_NEAREST)

        # Paint detected pixels green on an otherwise black layer ...
        mask_overlay = np.zeros_like(image)
        mask_overlay[combined > 0] = [0, 255, 0]  # Green mask

        # ... and blend that layer onto the original at half strength.
        return cv2.addWeighted(image, 1, mask_overlay, 0.5, 0)
    except Exception as e:
        # Deliberate best-effort: the UI still shows the upload on failure.
        print(f"Error in segmentation: {e}")
        return image  # Return original image if segmentation fails
# Gradio interface definition: upload on the left, segmented result on the right
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Image Segmentation with Mask R-CNN</h1>")
    gr.Markdown("Upload an image to see segmentation results using the Mask R-CNN model.")

    # Input and output layout
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Upload an Image")
            # NOTE(review): `source=` and `tool=` look like Gradio 3.x keyword
            # arguments -- confirm against the pinned Gradio version.
            input_image = gr.Image(source="upload", tool="editor", type="numpy", label="Upload an image")
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            gr.Markdown("### Segmented Image Output")
            output_image = gr.Image(type="numpy", label="Segmented Image")

    # Set up button functionality
    submit_btn.click(fn=apply_mask_rcnn, inputs=input_image, outputs=output_image)
    # Clear resets both image components; returning None empties each one.
    clear_btn.click(fn=lambda: (None, None), inputs=None, outputs=[input_image, output_image])

# Launch the Gradio app
demo.launch()