# app.py — Gradio UI for satellite image segmentation.
# Source page header: author "fradinho", commit 9f12d03 ("Update app.py"), 22.9 kB.
import gradio as gr
from PIL import Image
from patchify import patchify, unpatchify
import numpy as np
from skimage.io import imshow, imsave
import tensorflow
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.utils import normalize, to_categorical
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, Dropout, Lambda
from tensorflow.keras import layers
# Inference geometry: predict_2 resizes every input image to `size` x `size`
# pixels and cuts it into non-overlapping `pach_size` x `pach_size` tiles.
pach_size = 256
size = 4 * pach_size
def jacard(y_true, y_pred):
    """Return the smoothed Jaccard index (IoU) of two tensors.

    Both tensors are flattened and compared element-wise. A constant 1.0 is
    added to both numerator and denominator, so two all-zero inputs score 1.0
    instead of dividing by zero.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    union = K.sum(truth) + K.sum(guess) - overlap
    return (overlap + 1.0) / (union + 1.0)
def bce_dice(y_true, y_pred):
    """Return binary cross-entropy minus log(Jaccard index).

    Despite the name, the overlap term is the Jaccard index computed by
    ``jacard`` rather than a Dice coefficient.
    """
    bce_term = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)
    overlap_term = K.log(jacard(y_true, y_pred))
    return bce_term - overlap_term
def upsample(X, X_side):
    """Upsample ``X`` 2x with a transposed conv and concatenate ``X_side`` to it.

    Args:
        X: feature map to upsample (assumes channels-last layout — TODO confirm).
        X_side: side-path (skip) tensor whose spatial size matches the
            upsampled ``X``.

    Returns:
        The upsampled tensor concatenated with ``X_side`` along the last axis.
    """
    # NOTE(review): the filter count is taken from X.shape[1], which for a
    # channels-last tensor is the height, not the channel count. It is kept
    # because the weights loaded later via model.load_weights() must match
    # the kernel shapes produced by exactly this configuration.
    n_filters = int(X.shape[1] / 2)
    upsampled = Conv2DTranspose(n_filters, (3, 3), strides=(2, 2), padding='same')(X)
    # Alternative that was tried instead of the transposed conv:
    # upsampled = tf.keras.layers.UpSampling2D((2, 2))(X)
    return tf.keras.layers.Concatenate()([upsampled, X_side])
def gating_signal(input, out_size, batch_norm=False):
    """Build the gating signal for an attention block.

    Projects ``input`` to ``out_size`` channels with a 1x1 convolution (the
    spatial size is unchanged). It then optionally applies batch
    normalisation, followed by a ReLU.

    Returns:
        The gating feature map with ``out_size`` channels.
    """
    gated = layers.Conv2D(out_size, (1, 1), padding='same')(input)
    gated = layers.BatchNormalization()(gated) if batch_norm else gated
    return layers.Activation('relu')(gated)
def attention_block(x, gating, inter_shape):
    """Additive attention gate applied to the skip tensor ``x``.

    Args:
        x: skip-connection feature map to be re-weighted.
        gating: gating signal. Its spatial size must divide the size of ``x``
            after the stride-2 projection below.
        inter_shape: number of filters used for the intermediate projections.

    Returns:
        ``x`` multiplied by a learned sigmoid mask, followed by a 1x1 conv
        with the same channel count as ``x`` and batch normalisation.
    """
    x_dims = K.int_shape(x)
    g_dims = K.int_shape(gating)

    # Bring x down to the gating signal's resolution with a stride-2 conv.
    x_proj = layers.Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same')(x)
    x_proj_dims = K.int_shape(x_proj)

    # Give the gate inter_shape filters, then (transposed-conv) upsample it to
    # x_proj's spatial size. The stride is 1 when the sizes already match.
    g_proj = layers.Conv2D(inter_shape, (1, 1), padding='same')(gating)
    up_strides = (x_proj_dims[1] // g_dims[1], x_proj_dims[2] // g_dims[2])
    g_up = layers.Conv2DTranspose(inter_shape, (3, 3), strides=up_strides,
                                  padding='same')(g_proj)

    # relu(g + x) -> single channel -> sigmoid attention coefficients.
    summed = layers.add([g_up, x_proj])
    rectified = layers.Activation('relu')(summed)
    psi = layers.Conv2D(1, (1, 1), padding='same')(rectified)
    coeffs = layers.Activation('sigmoid')(psi)
    c_dims = K.int_shape(coeffs)

    # Scale the 1-channel mask back to x's resolution and tile it over x's channels.
    mask = layers.UpSampling2D(size=(x_dims[1] // c_dims[1], x_dims[2] // c_dims[2]))(coeffs)
    mask = repeat_elem(mask, x_dims[3])

    gated = layers.multiply([mask, x])
    mixed = layers.Conv2D(x_dims[3], (1, 1), padding='same')(gated)
    return layers.BatchNormalization()(mixed)
def repeat_elem(tensor, rep):
    """Repeat every element of ``tensor`` ``rep`` times along axis 3.

    ``K.repeat_elements`` behaves like ``np.repeat``. For example, a tensor of
    shape (None, 256, 256, 3) with rep=2 becomes (None, 256, 256, 6). The op
    is wrapped in a Lambda layer so it can be used inside a functional model.
    """
    def _repeat_along_channels(t, repnum):
        return K.repeat_elements(t, repnum, axis=3)

    repeat_layer = layers.Lambda(_repeat_along_channels, arguments={'repnum': rep})
    return repeat_layer(tensor)
# Width hyper-parameters used by the model builders below.
filters = 64        # base number of conv filters per encoder/decoder stage
FILTER_NUM = 16     # base width of the attention-gate projections in decoder()

# Not referenced by the model-building code visible in this file; kept for
# compatibility with any external users.
activation_funtion = 'relu'
act_func = 'relu'
recurrent_repeats = 8
axis = 3
def encoder(inputs, input_tensor):
    """Build a plain four-level U-Net contraction path (not called in this file).

    Each level applies conv -> BN -> dropout(0.1) -> conv -> BN and then
    2x2 max-pooling. A conv -> BN -> dropout(0.1) bottleneck follows the
    last pool.

    Args:
        inputs: tensor the layers are applied to.
        input_tensor: tensor used as the Model's input.

    Returns:
        A Model whose outputs are [bottleneck, level4, level3, level2, level1],
        i.e. deepest first.
    """
    def conv_block(t, n_filters):
        t = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(t)
        t = BatchNormalization()(t)
        t = Dropout(0.1)(t)
        t = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(t)
        return BatchNormalization()(t)

    skips = []
    x = inputs
    for multiplier in (1, 2, 4, 8):
        x = conv_block(x, multiplier * filters)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    bottleneck = Conv2D(16 * filters, (3, 3), activation='relu', padding='same')(x)
    bottleneck = BatchNormalization()(bottleneck)
    bottleneck = Dropout(0.1)(bottleneck)
    return Model(inputs=[input_tensor], outputs=[bottleneck] + skips[::-1])
def encoder_unet(inputs):
    """Build the five-level contraction path used by unet_2.

    Every conv is 3x3, ReLU, 'same' padding, he_normal init, followed by
    batch normalisation and (where flagged below) Dropout(0.2). Each level
    ends with 2x2 max-pooling, and a single conv+BN bottleneck follows the
    last pool.

    The layer creation order here must stay exactly as written: the weights
    loaded later with model.load_weights() are matched to these layers.

    Returns:
        A Model over ``inputs`` with outputs
        [bottleneck, level5, level3, level2, level1]. The level-4 output is
        computed and pooled but is not exposed as a skip connection.
    """
    def conv_bn(t, n_filters, use_dropout):
        t = Conv2D(n_filters, (3, 3), activation='relu', padding='same',
                   kernel_initializer='he_normal')(t)
        t = BatchNormalization()(t)
        if use_dropout:
            t = Dropout(0.2)(t)
        return t

    # (filter multiplier, dropout flag for each successive conv in the level)
    level_plan = [
        (1, (True, False)),
        (2, (True, False)),
        (4, (True, False, False)),
        (8, (True, True, False)),
        (8, (True, True, False)),
    ]

    level_outputs = []
    x = inputs
    for multiplier, dropout_flags in level_plan:
        for use_dropout in dropout_flags:
            x = conv_bn(x, multiplier * filters, use_dropout)
        level_outputs.append(x)
        x = MaxPooling2D((2, 2))(x)

    bottleneck = conv_bn(x, 16 * filters, False)

    level1, level2, level3, _level4, level5 = level_outputs
    return Model(inputs=[inputs], outputs=[bottleneck, level5, level3, level2, level1])
def decoder(inputs, input_tensor):
    """Build the attention-gated expansive path of unet_2.

    Args:
        inputs: the encoder outputs in the order produced by encoder_unet,
            i.e. [bottleneck, skip_a, skip_b, skip_c, skip_d], deepest first.
        input_tensor: tensor used as the Model's input.

    Returns:
        A Model over ``input_tensor`` whose single output is the final
        ``filters``-channel feature map.

    At every merge point, the current tensor produces a gating signal. That
    signal re-weights the matching skip tensor through attention_block, and
    the result is concatenated with upsample(current, skip). One extra
    transposed-conv upsampling stage without a skip sits between the first
    and second merges. Layer creation order matches the pretrained weights
    and must not change.
    """
    def conv_stack(t, n_filters, depth):
        # `depth` x (3x3 conv -> BN), then a single Dropout(0.2).
        for _ in range(depth):
            t = Conv2D(n_filters, (3, 3), activation='relu', padding='same',
                       kernel_initializer='he_normal')(t)
            t = BatchNormalization()(t)
        return Dropout(0.2)(t)

    def attend_and_merge(t, skip, inter_channels):
        gate = gating_signal(t, inter_channels, True)
        attended = attention_block(skip, gate, inter_channels)
        merged = upsample(t, skip)
        return concatenate([merged, attended], axis=3)

    x = attend_and_merge(inputs[0], inputs[1], 16 * FILTER_NUM)
    x = conv_stack(x, 8 * filters, 3)

    # Extra upsampling stage with no skip connection (same filter-count rule
    # as upsample(): derived from the spatial dimension).
    x = Conv2DTranspose(int(x.shape[1] / 2), (3, 3), strides=(2, 2), padding='same')(x)
    x = conv_stack(x, 8 * filters, 3)

    x = attend_and_merge(x, inputs[2], 8 * FILTER_NUM)
    x = conv_stack(x, 4 * filters, 3)

    x = attend_and_merge(x, inputs[3], 4 * FILTER_NUM)
    x = conv_stack(x, 2 * filters, 2)

    x = attend_and_merge(x, inputs[4], 2 * FILTER_NUM)
    x = conv_stack(x, 1 * filters, 2)

    return Model(inputs=[input_tensor], outputs=[x])
def unet_2(n_classes=2, height=pach_size, width=pach_size, channels=3, metrics=None):
    """Build and compile the attention U-Net backbone (encoder_unet + decoder).

    Args:
        n_classes: used only to choose the loss. <= 2 gives binary
            cross-entropy; otherwise categorical cross-entropy.
        height, width, channels: input tensor shape (without batch axis).
        metrics: Keras metrics passed to compile(); defaults to ['accuracy'].

    Returns:
        A compiled Model mapping the input to the decoder's feature map. No
        classification head is attached here; unet_enssemble adds one.
    """
    # Fresh list per call instead of a shared mutable default argument.
    if metrics is None:
        metrics = ['accuracy']

    inputs = Input((height, width, channels))
    encode = encoder_unet(inputs)
    decode = decoder(encode.output, inputs)
    model = Model(inputs=[inputs], outputs=[decode.output])

    loss = 'binary_crossentropy' if n_classes <= 2 else 'categorical_crossentropy'
    # `learning_rate` is the supported keyword; `lr` is ignored by newer
    # tf.keras optimizers and rejected outright by Keras 3.
    model.compile(optimizer=Adam(learning_rate=1e-3), loss=loss, metrics=metrics)
    return model
def unet_enssemble(n_classes=2, height=64, width=64, channels=3, metrics=None):
    """Wrap unet_2 with a 1x1 softmax classification head and compile it.

    Args:
        n_classes: number of output channels of the softmax head. Also
            selects the loss: binary cross-entropy if <= 2, else categorical.
        height, width, channels: input tensor shape (without batch axis).
        metrics: Keras metrics passed to compile(); defaults to ['accuracy'].

    Returns:
        A compiled Model producing per-pixel class probabilities of shape
        (height, width, n_classes).
    """
    if metrics is None:
        metrics = ['accuracy']

    x = Input((height, width, channels))
    # The backbone must use the same channel count as the outer input.
    # Previously it was hard-coded to 3, which broke any channels != 3.
    backbone = unet_2(n_classes=n_classes, height=height, width=width, channels=channels)
    features = backbone(x)
    outputs = Conv2D(n_classes, (1, 1), activation='softmax', padding='same')(features)
    model = Model(inputs=[x], outputs=[outputs])

    loss = 'binary_crossentropy' if n_classes <= 2 else 'categorical_crossentropy'
    model.compile(optimizer=Adam(learning_rate=1e-3), loss=loss, metrics=metrics)
    return model
# Build the 23-class segmentation network at patch resolution. predict_2 feeds
# it pach_size x pach_size tiles, so the model and the tiling must share the
# single size/pach_size definition at the top of the file. Re-assigning them
# here could silently diverge from the values the model was built with.
n_classes = 23
n_channels = 3
model = unet_enssemble(n_classes=n_classes, height=pach_size, width=pach_size,
                       channels=n_channels)
def predict_2(image):
    """Segment an image and return a colour-coded class mask.

    The image is resized to ``size`` x ``size``, split into non-overlapping
    ``pach_size`` tiles, and each tile is classified per pixel by ``model``
    (argmax over classes). The tiles are stitched back together and every
    class index is mapped to its RGB colour from ``targets_classes_colors``.

    Args:
        image: numpy image array as supplied by the gr.Image input, or None
            when no image has been selected.

    Returns:
        (status text, uint8 RGB mask of shape (size, size, 3)), or
        (message, None) when no image was provided.
    """
    if image is None:
        return 'Please select an image first', None

    # convert('RGB') guarantees 3 channels (grayscale or RGBA inputs would
    # otherwise break the 3-channel patching below).
    image = np.array(Image.fromarray(image).convert('RGB').resize((size, size)))

    # step == window size -> non-overlapping tiles.
    patches_img = patchify(image, (pach_size, pach_size, 3), step=pach_size)
    patches_img = patches_img[:, :, 0, :, :, :]
    n_rows, n_cols = patches_img.shape[:2]

    # One predict() call over all tiles in row-major order; batch_size=1 keeps
    # the same per-step memory footprint as predicting tile by tile.
    batch = patches_img.reshape(-1, pach_size, pach_size, 3) / 255
    probs = model.predict(batch, batch_size=1)
    labels = np.argmax(probs, axis=3)
    labels = labels.reshape(n_rows, n_cols, pach_size, pach_size)

    mask = unpatchify(labels, (image.shape[0], image.shape[1]))
    # Colour lookup yields the table's integer dtype (int64); hand the UI an
    # explicit 8-bit RGB image instead.
    coloured = targets_classes_colors[mask].astype(np.uint8)
    return 'Predicted Masked Image', coloured
# RGB colour for each class index (row i paints class i) used to render the
# predicted mask. Note: there are 24 rows, while the model in this file is
# built with n_classes = 23, so the last row is unreachable unless a class is
# added.
targets_classes_colors = np.array([
    (0, 0, 0),          # 0
    (128, 64, 128),     # 1
    (130, 76, 0),       # 2
    (0, 102, 0),        # 3
    (112, 103, 87),     # 4
    (28, 42, 168),      # 5
    (48, 41, 30),       # 6
    (0, 50, 89),        # 7
    (107, 142, 35),     # 8
    (70, 70, 70),       # 9
    (102, 102, 156),    # 10
    (254, 228, 12),     # 11
    (254, 148, 12),     # 12
    (190, 153, 153),    # 13
    (153, 153, 153),    # 14
    (255, 22, 96),      # 15
    (102, 51, 0),       # 16
    (9, 143, 150),      # 17
    (119, 11, 32),      # 18
    (51, 51, 0),        # 19
    (190, 250, 190),    # 20
    (112, 150, 146),    # 21
    (2, 135, 115),      # 22
    (255, 0, 0),        # 23
])
# Per-class loss weights, listed in class-index order (position i -> class i).
# Classes 0, 1, 3, 4, 8 and 9 are down-weighted to 0.1, while some classes get
# very large weights (class 14 ~1077, class 16 ~856).
class_weights = dict(enumerate([
    0.1,                    # 0
    0.1,                    # 1
    2.171655596616696,      # 2
    0.1,                    # 3
    0.1,                    # 4
    2.2101197049812593,     # 5
    11.601519937899578,     # 6
    7.99072122367673,       # 7
    0.1,                    # 8
    0.1,                    # 9
    2.5426918173402457,     # 10
    11.187574445057574,     # 11
    241.57620214903147,     # 12
    9.234779790464515,      # 13
    1077.2745952165694,     # 14
    7.396021659003857,      # 15
    855.6730643687165,      # 16
    6.410869993189135,      # 17
    42.0186736125025,       # 18
    2.5648760196752947,     # 19
    4.089194047656931,      # 20
    27.984593442818955,     # 21
    2.0509251319694712,     # 22
]))
# Same weights as a plain list ordered by class index.
weight_list = [class_weights[class_id] for class_id in range(len(class_weights))]
def weighted_categorical_crossentropy(weights=None):
    """Build a class-weighted BCE + Jaccard loss for one-hot targets.

    Despite the name, the loss is not categorical cross-entropy. It is the
    ``bce_dice`` value multiplied, per pixel, by the weight of that pixel's
    true class (the sum of ``y_true * weights`` over the last axis).

    Args:
        weights: per-class weights indexed by class id. If None, the
            module-level ``weight_list`` is used. Previously the argument was
            ignored and ``weight_list`` was always used.

    Returns:
        A Keras loss function ``wcce(y_true, y_pred)``.
    """
    if weights is None:
        weights = weight_list

    def wcce(y_true, y_pred):
        class_weights_t = K.constant(weights)
        if not tf.is_tensor(y_pred):
            y_pred = K.constant(y_pred)
        y_true = K.cast(y_true, y_pred.dtype)
        pixel_weights = K.sum(y_true * class_weights_t, axis=-1)
        return bce_dice(y_true, y_pred) * pixel_weights

    return wcce
# Load the pretrained weights into the network built above.
# Model.load_weights() updates `model` in place and does not return the model,
# so its result must NOT be assigned back to `model`. Doing so left `model`
# as None/a status object and made predict_2 fail on model.predict().
# Earlier full-model loading alternatives:
#model = tf.keras.models.load_model("model.h5", custom_objects={"jacard":jacard, "wcce":weighted_categorical_crossentropy})
#model = tf.keras.models.load_model("model_2.h5", custom_objects={"jacard":jacard, "bce_dice":bce_dice})
model.load_weights("model_2_A (1).h5")
# Create a user interface for the model: select an image, click the button,
# and predict_2's status text and colour-coded mask are shown side by side.
my_app = gr.Blocks()
with my_app:
    gr.Markdown("Satellite Image Segmentation Application UI with Gradio")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    img_source = gr.Image(label="Please select source Image")
                    source_image_loader = gr.Button("Load above Image")
                with gr.Column():
                    output_label = gr.Label(label="Image Info")
                    img_output = gr.Image(label="Image Output")
    # Event listeners must be registered inside the Blocks context.
    # predict_2 returns (status text, mask image) -> (output_label, img_output).
    source_image_loader.click(
        predict_2,
        [img_source],
        [output_label, img_output],
    )

# debug=True makes launch() block the main thread, so close() runs only after
# the server has been stopped.
my_app.launch(debug=True, share=True)
my_app.close()