# Seg / app.py
# Last commit: "Update app.py" (ed5c9d7, verified) by saba000farahani — 3.04 kB.
import os

# Must be set BEFORE TensorFlow is imported (directly, or indirectly via mrcnn):
# TensorFlow reads this flag while it initialises, so assigning it after the
# import - as this file previously did - has no effect.
# '0' turns off oneDNN custom ops, whose different computation order can give
# slightly different floating-point (round-off) results.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import gradio as gr
import numpy as np
import tensorflow as tf  # noqa: F401 - backend for mrcnn; must come after the env var above
import cv2
from mrcnn.config import Config
from mrcnn import model as modellib
# Define Mask R-CNN configuration
class InferenceConfig(Config):
NAME = "mask_rcnn"
NUM_CLASSES = 1 + 80 # Update according to your dataset
GPU_COUNT = 1
IMAGES_PER_GPU = 1
# Build the inference configuration and the network, then load trained weights.
config = InferenceConfig()

# Path is relative to the current working directory (not to this file), so the
# app must be started from the directory that contains ``toolkit/``.
model_path = os.path.join('toolkit', 'condmodel_100.h5')

# NOTE(review): if the installed mrcnn/Keras uses TF1-style graphs, calling
# model.detect() from Gradio worker threads may require the graph/session
# captured here - verify against the deployed TensorFlow version.
model = modellib.MaskRCNN(
    mode="inference",
    model_dir=os.getcwd(),  # working directory handed to mrcnn as its model/log dir
    config=config,
)

# by_name=True: match weights to layers by layer name rather than by order.
model.load_weights(model_path, by_name=True)
print("Mask R-CNN model loaded successfully with weights.")
# Function to apply Mask R-CNN for image segmentation
def apply_mask_rcnn(image):
    """Run Mask R-CNN on ``image`` and return it with detected objects tinted green.

    Args:
        image: Array from the Gradio ``gr.Image(type="numpy")`` input.
            Presumably uint8 with shape (H, W, 3) or (H, W, 4) - confirm.
            2-D grayscale and single-channel (H, W, 1) arrays are converted
            to RGB. ``None`` (Submit clicked with nothing uploaded) is
            returned unchanged.

    Returns:
        A 3-channel array with the same height and width as ``image``. Every
        pixel covered by at least one detected instance mask gets its green
        channel boosted. If segmentation fails, the error and traceback are
        printed and the (possibly RGB-converted) input image is returned.
    """
    if image is None:
        return None
    try:
        # Normalise to 3-channel RGB (grayscale and RGBA inputs).
        if image.ndim == 2 or image.shape[2] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        # model.detect() does its own aspect-preserving resize/padding and
        # returns masks at the input image's resolution. Pass the image as-is:
        # a manual resize to 1024x1024 would stretch non-square images.
        r = model.detect([image], verbose=0)[0]

        # r['masks'] holds one channel per detected instance along the last
        # axis. Union them with np.any: np.sum would produce an int64 array,
        # which OpenCV functions such as cv2.resize reject. With zero
        # detections np.any yields an all-False mask.
        mask = np.any(r['masks'], axis=-1)

        if mask.shape != image.shape[:2]:
            # Defensive: only needed if the backend returns masks at another
            # resolution. Nearest-neighbour interpolation keeps the mask binary.
            mask = cv2.resize(
                mask.astype(np.uint8),
                (image.shape[1], image.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            ).astype(bool)

        # Green overlay on masked pixels, blended onto the original image.
        mask_overlay = np.zeros_like(image)
        mask_overlay[mask] = [0, 255, 0]
        return cv2.addWeighted(image, 1, mask_overlay, 0.5, 0)
    except Exception as e:
        # Best-effort demo: keep the UI alive, but log the full traceback so
        # failures are diagnosable instead of silently showing the input.
        import traceback

        print(f"Error in segmentation: {e}")
        traceback.print_exc()
        return image  # Return original image if segmentation fails
# Gradio interface definition
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Image Segmentation with Mask R-CNN</h1>")
    gr.Markdown("Upload an image to see segmentation results using the Mask R-CNN model.")

    # Input and output layout: upload controls on the left, result on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Upload an Image")
            # `source=` / `tool=` were removed from gr.Image in Gradio 4 (passing
            # them raises TypeError). In Gradio 3.x "upload" and "editor" were
            # already the defaults, so omitting them keeps the 3.x behavior.
            input_image = gr.Image(type="numpy", label="Upload an image")
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")
        with gr.Column():
            gr.Markdown("### Segmented Image Output")
            output_image = gr.Image(type="numpy", label="Segmented Image")

    # Set up button functionality
    submit_btn.click(fn=apply_mask_rcnn, inputs=input_image, outputs=output_image)
    # Reset both images. Without `outputs` the handler's return value is
    # discarded, which made the Clear button a no-op.
    clear_btn.click(fn=lambda: (None, None), outputs=[input_image, output_image])

# Launch the Gradio app
demo.launch()