File size: 3,662 Bytes
e8c4882
1d2e742
 
 
e8c4882
 
 
 
1d2e742
e8c4882
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d2e742
e8c4882
1d2e742
 
 
 
 
 
e8c4882
 
1d2e742
 
 
e8c4882
1d2e742
e8c4882
 
1d2e742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8c4882
1d2e742
 
 
e8c4882
1d2e742
 
 
 
e8c4882
1d2e742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os

# TensorFlow reads TF_ENABLE_ONEDNN_OPTS when it is first imported, so the
# variable must be set *before* `import tensorflow`; assigning it afterwards
# has no effect. "0" disables the oneDNN custom operations, which can give
# slightly different floating-point results because of a different order of
# operations.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
import cv2  # noqa: E402

# Define the segmentation model architecture
def build_mask_rcnn_model(input_shape=(224, 224, 3)):
    """Build the (uncompiled) segmentation network whose weights are loaded below.

    Despite the name, this is not a Mask R-CNN. It is a small fully
    convolutional encoder-decoder:

    - two Conv2D + 2x2 max-pool stages;
    - a 256-filter Conv2D bottleneck;
    - two 2x upsample + Conv2D stages;
    - a final 1x1 Conv2D with a sigmoid, giving one per-pixel mask channel
      with values in (0, 1).

    The layer stack presumably has to match the network that produced the
    weights file for ``model.load_weights`` to succeed. Adjust it to match
    your actual trained model.

    Args:
        input_shape: ``(height, width, channels)`` of the input tensor.
            Defaults to ``(224, 224, 3)``, the size ``apply_mask_rcnn``
            resizes images to. Height and width should be divisible by 4 so
            the two 2x upsamplings restore the input resolution.

    Returns:
        A ``tf.keras.Model`` mapping an image batch to a 1-channel mask batch.
    """
    input_layer = tf.keras.layers.Input(shape=input_shape)

    # Encoder: downsample by 4 in total.
    x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same')(input_layer)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    x = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)

    # Bottleneck.
    x = tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same')(x)

    # Decoder: upsample back by 4 in total.
    x = tf.keras.layers.UpSampling2D((2, 2))(x)
    x = tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same')(x)
    x = tf.keras.layers.UpSampling2D((2, 2))(x)

    # Per-pixel foreground probability.
    output_layer = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(x)

    model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
    return model

# Location of the trained weights, relative to the current working directory.
# NOTE(review): update this if the weights file lives elsewhere.
model_path = os.path.join('toolkit', 'condmodel_100.h5')

# Instantiate the network, then restore its trained parameters from disk.
# load_weights is expected to raise if the file is missing or its layers do
# not line up with build_mask_rcnn_model().
model = build_mask_rcnn_model()
model.load_weights(model_path)
print("Mask R-CNN model loaded successfully with weights.")

# Function to apply the segmentation model to an image
def apply_mask_rcnn(image):
    """Overlay the model's predicted mask on ``image`` in green.

    Processing steps:

    1. Resize the image to the model's 224x224 input.
    2. Run it through the module-level ``model``.
    3. Resize the single-channel prediction back to the original resolution.
    4. Tint pixels whose score exceeds 0.5 green, blending with
       ``cv2.addWeighted`` (original at weight 1, overlay at 0.5).

    Args:
        image: NumPy image array as delivered by the Gradio ``numpy`` image
            component, or ``None`` when no image was uploaded. Accepted
            shapes are grayscale ``(H, W)``, single-channel ``(H, W, 1)``,
            RGB ``(H, W, 3)`` and RGBA ``(H, W, 4)``.

    Returns:
        The 3-channel image with the mask overlay. Returns ``None`` if
        ``image`` is ``None``. If segmentation fails, returns the input
        (already converted to RGB where conversion succeeded).
    """
    if image is None:
        # Submit was clicked without an upload: nothing to segment.
        return None

    try:
        # Normalise to 3-channel RGB. The model input and the 3-element
        # green overlay below both assume exactly three channels.
        if image.ndim == 2 or image.shape[2] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        # Resize the image to match the model input size and add a batch axis.
        resized_image = cv2.resize(image, (224, 224))  # Adjust based on the input shape of your model
        input_image = np.expand_dims(resized_image, axis=0)

        # Predict the mask for the single-image batch.
        prediction = model.predict(input_image)

        # Take the first (only) batch element and drop singleton axes -> (224, 224).
        mask = np.squeeze(prediction[0])

        # Resize mask back to the original image size (cv2 takes (width, height)).
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]))

        # Paint confident pixels green on an otherwise black overlay.
        mask_overlay = np.zeros_like(image)
        mask_overlay[mask > 0.5] = [0, 255, 0]  # Green mask

        # Blend: keep the original at full weight, add the overlay at half weight.
        segmented_image = cv2.addWeighted(image, 1, mask_overlay, 0.5, 0)

        return segmented_image

    except Exception as e:
        # Best-effort UI behaviour: report the error and show the input unchanged.
        print(f"Error in segmentation: {e}")
        return image  # Return original image if segmentation fails

# Gradio interface definition.
# Gradio 4 removed the `source`/`tool` arguments of gr.Image and rejects
# them with a TypeError. Uploads are now selected with `sources=[...]`, and
# image editing moved to gr.ImageEditor. Fall back to the newer signature so
# the app starts on both major versions.
try:
    inputs = gr.Image(source="upload", tool="editor", type="numpy", label="Upload an image")
except TypeError:
    inputs = gr.Image(sources=["upload"], type="numpy", label="Upload an image")
outputs = gr.Image(type="numpy", label="Segmented Image")

# Gradio app layout: two columns, upload + buttons on the left, result on the right.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Image Segmentation with Mask R-CNN</h1>")
    gr.Markdown("Upload an image to see segmentation results using the Mask R-CNN model.")

    # Input and output layout
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Upload an Image")
            inputs.render()  # Place the pre-built upload component here

            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear")

            # Run segmentation on the uploaded image and show the result.
            submit_button.click(fn=apply_mask_rcnn, inputs=inputs, outputs=outputs)

            # Reset both the upload and the result panes. An event handler
            # without `outputs` would update nothing, so both components are
            # listed and the function returns one `None` per component.
            clear_button.click(fn=lambda: (None, None), inputs=None, outputs=[inputs, outputs])

        with gr.Column():
            gr.Markdown("### Segmented Image Output")
            outputs.render()  # Place the pre-built result component here

# Launch the Gradio app
demo.launch()