File size: 3,044 Bytes
e8c4882
1d2e742
 
 
e8c4882
fef10af
6fbd926
e8c4882
 
 
1d2e742
fef10af
 
 
 
 
 
e8c4882
fef10af
 
ed5c9d7
fef10af
e8c4882
 
fef10af
 
e8c4882
1d2e742
e8c4882
1d2e742
 
 
 
 
 
fef10af
ed5c9d7
1d2e742
 
fef10af
 
 
1d2e742
fef10af
 
 
1d2e742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8c4882
1d2e742
 
 
ed5c9d7
e8c4882
1d2e742
 
 
ed5c9d7
 
 
1d2e742
 
 
ed5c9d7
 
 
 
 
1d2e742
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os

# Disable oneDNN custom ops to avoid floating-point round-off differences.
# TensorFlow consults this flag when it is loaded, so it must be set BEFORE
# `import tensorflow` (and before mrcnn, which imports TensorFlow itself);
# setting it after the import has no effect.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
import cv2  # noqa: E402
from mrcnn.config import Config  # noqa: E402
from mrcnn import model as modellib  # noqa: E402

# Inference-time settings for Mask R-CNN
class InferenceConfig(Config):
    """Mask R-CNN configuration used when running detection in this app.

    Only the attributes below are overridden; every other setting keeps the
    value defined by the base ``Config`` class.
    """

    # One image per GPU on one GPU, i.e. a single image per detection call.
    IMAGES_PER_GPU = 1
    GPU_COUNT = 1

    # Identifier for this configuration.
    NAME = "mask_rcnn"

    # 80 dataset classes plus one extra (presumably the background class);
    # update according to the dataset the weights were trained on.
    NUM_CLASSES = 80 + 1

# Location of the trained weights file, relative to the working directory
model_path = os.path.join('toolkit', 'condmodel_100.h5')

# Build the Mask R-CNN network in inference mode; the current working
# directory is handed over as the model directory
config = InferenceConfig()
model = modellib.MaskRCNN(
    mode="inference",
    config=config,
    model_dir=os.getcwd(),
)

# Fill the network with the saved weights, matching layers by name
model.load_weights(model_path, by_name=True)
print("Mask R-CNN model loaded successfully with weights.")

# Function to apply Mask R-CNN for image segmentation
def apply_mask_rcnn(image):
    """Run Mask R-CNN on an image and tint every detected object green.

    Args:
        image: Image as a NumPy array: 2-D grayscale, or 3-D with 1, 3 (RGB)
            or 4 (RGBA) channels. ``None`` is accepted (nothing uploaded).

    Returns:
        The RGB image blended with a half-weight green overlay on every pixel
        covered by a detected mask. ``None`` is returned for a ``None`` input.
        If any step fails, the error is printed and the image is returned
        without an overlay (best effort, so the UI still shows something).
    """
    if image is None:
        # Submit was pressed without an upload; there is nothing to segment.
        return None

    try:
        # Normalise to 3-channel RGB so the model input and the overlay
        # (which assigns 3-element colours) agree on shape.
        if image.ndim == 2 or image.shape[2] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        height, width = image.shape[:2]

        # Resize the image to match the model input size (for inference)
        resized_image = cv2.resize(image, (1024, 1024))  # Adjust based on model input requirements

        # detect() takes a batch of images; this is a batch of one.
        r = model.detect([resized_image], verbose=0)[0]

        # Collapse the per-object masks (last axis) into one map.
        # Accumulate in float32: summing boolean masks with the default dtype
        # yields an integer array that cv2.resize cannot process, which used
        # to make every call fall through to the error path.
        mask = np.sum(r['masks'], axis=-1, dtype=np.float32)

        # Resize mask back to the original image size (cv2 takes (width, height))
        mask = cv2.resize(mask, (width, height))

        # Paint green wherever the resized mask indicates an object
        mask_overlay = np.zeros_like(image)
        mask_overlay[mask > 0.5] = [0, 255, 0]  # Green mask

        # Original image at full weight plus the overlay at half weight
        return cv2.addWeighted(image, 1, mask_overlay, 0.5, 0)

    except Exception as e:
        # Deliberate best effort at the UI boundary: report and fall back.
        print(f"Error in segmentation: {e}")
        return image  # Return the image without overlay if segmentation fails

# Gradio interface definition
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Image Segmentation with Mask R-CNN</h1>")
    gr.Markdown("Upload an image to see segmentation results using the Mask R-CNN model.")

    # Input and output layout: upload controls on the left, result on the right
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Upload an Image")
            input_image = gr.Image(source="upload", tool="editor", type="numpy", label="Upload an image")
            submit_btn = gr.Button("Submit")
            clear_btn = gr.Button("Clear")

        with gr.Column():
            gr.Markdown("### Segmented Image Output")
            output_image = gr.Image(type="numpy", label="Segmented Image")

    # Set up button functionality
    submit_btn.click(fn=apply_mask_rcnn, inputs=input_image, outputs=output_image)
    # A click handler without outputs changes nothing in the UI. Route one
    # None to each image component so "Clear" empties both the upload and
    # the result.
    clear_btn.click(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[input_image, output_image],
    )

# Launch the Gradio app
demo.launch()