Seg / app.py
saba000farahani's picture
Create app.py
1d2e742 verified
raw
history blame
2.79 kB
import gradio as gr
import numpy as np
import cv2
import os
import tensorflow as tf
from tensorflow.keras.models import load_model
# Location of the serialized Keras model (HDF5), relative to the working directory.
_MODEL_DIR = "toolkit"
_MODEL_FILE = "condmodel_100.h5"
model_path = os.path.join(_MODEL_DIR, _MODEL_FILE)

# Load the model once at import time; apply_mask_rcnn() uses this shared instance.
mask_rcnn_model = load_model(model_path)
def apply_mask_rcnn(image):
    """
    Apply the loaded segmentation model and overlay the predicted mask in green.

    The image is converted to three RGB channels if needed, resized to 224x224
    for the model, and the predicted mask is resized back to the image size.
    Pixels whose mask value exceeds 0.5 are tinted green.

    :param image: Input image as a numpy array (grayscale, RGB or RGBA), or
        None when no image was provided
    :return: Image with the segmentation mask overlaid. If segmentation fails,
        the input image (possibly already converted to RGB) is returned
        instead. Returns None when ``image`` is None.
    """
    if image is None:
        # Nothing was uploaded, so there is nothing to segment.
        return None

    try:
        # Normalise to 3-channel RGB so the overlay below can assign RGB triples.
        if image.ndim == 2:  # Grayscale (H, W) has no channel axis
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:  # Drop the alpha channel
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        # Resize the image to the input size of the model
        resized_image = cv2.resize(image, (224, 224))  # Adjust according to model input size
        # Add a leading batch axis of size 1
        input_image = np.expand_dims(resized_image, axis=0)

        # Run the model on the single-image batch
        prediction = mask_rcnn_model.predict(input_image)

        # Assuming the first output is the mask; adjust to the model's structure.
        # NOTE(review): the overlay below presumably needs a 2-D mask after
        # squeezing - confirm against the model's output shape.
        mask = np.squeeze(prediction[0])  # Remove any unnecessary dimensions

        # Resize mask back to the original image size (cv2 expects (width, height))
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]))

        # Paint every pixel above the 0.5 threshold green on a black canvas
        mask_overlay = np.zeros_like(image)
        mask_overlay[mask > 0.5] = [0, 255, 0]  # Green mask

        # Blend: original at full weight plus the overlay at half weight
        segmented_image = cv2.addWeighted(image, 1, mask_overlay, 0.5, 0)
        return segmented_image
    except Exception as e:
        # Best effort: log the failure and fall back to showing the input image
        print(f"Error in segmentation: {e}")
        return image  # Return original image if segmentation fails
# Image input/output components. They are created up front and placed into the
# layout below via .render().
# NOTE(review): `source=` and `tool=` appear to be Gradio 3.x keyword arguments;
# confirm against the Gradio version this app is pinned to.
inputs = gr.Image(source="upload", tool="editor", type="numpy", label="Upload an image")
outputs = gr.Image(type="numpy", label="Segmented Image")

# Gradio interface definition
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Image Segmentation with Mask R-CNN</h1>")
    gr.Markdown("Upload an image to see segmentation results using the Mask R-CNN model.")

    # Input and output components, side by side
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Upload an Image")
            inputs.render()  # Render the input (image upload)

            # Submit runs the segmentation on the uploaded image
            gr.Button("Submit").click(fn=apply_mask_rcnn, inputs=inputs, outputs=outputs)

            # Clear resets both the uploaded image and the segmentation result.
            # Returning None for an image component empties it.
            gr.Button("Clear").click(
                fn=lambda: (None, None),
                inputs=None,
                outputs=[inputs, outputs],
            )

        with gr.Column():
            gr.Markdown("### Segmented Image Output")
            outputs.render()  # Render the output (segmented image)

# Launch the Gradio app
demo.launch()