# Seg / app.py
# Author: saba000farahani — commit fef10af "Update app.py" (verified), 3.1 kB
import os
import gradio as gr
import numpy as np
import tensorflow as tf
import cv2
from mrcnn import model as modellib
from mrcnn.config import Config
# Turn off TensorFlow's oneDNN custom operations. With them enabled, TF warns
# that results may differ slightly due to floating-point round-off from a
# different computation order.
# NOTE(review): tensorflow is imported earlier in this file, and TF appears to
# read this variable when it is first loaded, so this assignment likely has no
# effect here. Move it above `import tensorflow` — TODO confirm.
os.environ.update(TF_ENABLE_ONEDNN_OPTS='0')
# Define Mask R-CNN configuration
class InferenceConfig(Config):
NAME = "mask_rcnn"
NUM_CLASSES = 1 + 80 # Update according to your dataset
GPU_COUNT = 1
IMAGES_PER_GPU = 1
# Build the inference-mode Mask R-CNN once at import time and load the
# trained weights, so every request reuses the same model object.
config = InferenceConfig()

# Weights file, resolved relative to the current working directory.
model_path = os.path.join('toolkit', 'condmodel_100.h5')

model = modellib.MaskRCNN(
    mode="inference",
    config=config,
    # Directory argument required by the constructor; presumably only used
    # for training logs/checkpoints — TODO confirm.
    model_dir=os.getcwd(),
)
# by_name=True matches weights to layers by layer name.
model.load_weights(model_path, by_name=True)
print("Mask R-CNN model loaded successfully with weights.")
# Function to apply Mask R-CNN for image segmentation
def apply_mask_rcnn(image):
    """Segment *image* with the global Mask R-CNN model and tint detections green.

    Args:
        image: NumPy image from the Gradio input. It may be grayscale
            ``(H, W)`` or ``(H, W, 1)``, RGB ``(H, W, 3)`` or RGBA
            ``(H, W, 4)``; presumably uint8 (TODO confirm). It is ``None``
            when Submit is clicked with nothing uploaded.

    Returns:
        ``None`` if *image* is ``None``.

        Otherwise, an RGB array with the same height and width as *image*.
        Every pixel covered by at least one predicted instance mask is
        blended with green at weight 0.5.

        If segmentation raises, the error and traceback are printed and the
        (channel-converted) image is returned without an overlay, so the UI
        still shows something.
    """
    if image is None:
        # Nothing uploaded: leave the output component empty.
        return None

    try:
        # The model expects 3-channel RGB input.
        if image.ndim == 2 or image.shape[2] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        # MaskRCNN.detect() (matterport API) takes a list of images and
        # resizes/pads each one internally. It returns masks at the original
        # resolution of the image passed in, so no manual resize is needed.
        # A fixed square resize would distort the aspect ratio.
        r = model.detect([image], verbose=0)[0]

        # Union of all instance masks. r['masks'] is (H, W, N) and N may be 0,
        # in which case the union is all False.
        combined_mask = np.any(r['masks'], axis=-1)

        # Create a segmentation overlay on the image.
        mask_overlay = np.zeros_like(image)
        mask_overlay[combined_mask] = [0, 255, 0]  # Green mask

        # Combine the image with the mask.
        return cv2.addWeighted(image, 1, mask_overlay, 0.5, 0)
    except Exception as e:
        # Best-effort: keep the UI responsive, but keep the full traceback.
        import traceback

        print(f"Error in segmentation: {e}")
        traceback.print_exc()
        return image  # Return original image if segmentation fails
# Gradio components, created here and placed into the layout later via
# .render().
# NOTE(review): `source=` and `tool=` are Gradio 3.x gr.Image arguments. Later
# Gradio releases replaced/removed them (e.g. `sources=[...]`), so confirm the
# pinned Gradio version before upgrading.
inputs = gr.Image(
    label="Upload an image",
    type="numpy",  # the click handler receives a NumPy array
    source="upload",
    tool="editor",
)
outputs = gr.Image(label="Segmented Image", type="numpy")
# Gradio app layout: a title, then an upload column (with Submit/Clear) next
# to an output column.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Image Segmentation with Mask R-CNN</h1>")
    gr.Markdown("Upload an image to see segmentation results using the Mask R-CNN model.")
    # Input and output side by side
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Upload an Image")
            inputs.render()  # Render the input (image upload)
            # Submit runs segmentation on the uploaded image.
            gr.Button("Submit").click(fn=apply_mask_rcnn, inputs=inputs, outputs=outputs)
            # Clear resets both image components. The handler must list them
            # as outputs; with no outputs the click would change nothing.
            gr.Button("Clear").click(fn=lambda: (None, None), outputs=[inputs, outputs])
        with gr.Column():
            gr.Markdown("### Segmented Image Output")
            outputs.render()  # Render the output (segmented image)
# Launch the Gradio app
demo.launch()