jens
UI first try
077fc91
raw
history blame
3.75 kB
import os
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
from inference import SegmentPredictor
# Shared SAM predictor, built once when this module is imported.
# NOTE: an earlier variant obtained it via service.get_sam(model_type, ckpt_path, device).
sam = SegmentPredictor()
# RGB colour tuples used when drawing clicked-point markers.
red, blue = (255, 0, 0), (0, 0, 255)
# Gradio UI: upload an image, click foreground/background points on it,
# encode the image with SAM and predict masks from the clicked points.
block = gr.Blocks()
with block:
    # ------------------------------------------------------------------ States
    def point_coords_empty():
        """Return a fresh, empty list of clicked ``[x, y]`` coordinates."""
        return []

    def point_labels_empty():
        """Return a fresh, empty list of point labels (one per clicked point)."""
        return []

    # Hidden copy of the uploaded image with no point markers drawn on it;
    # inference and reset always use this clean copy.
    raw_image = gr.Image(type='pil', visible=False)
    # A callable initial value is invoked by Gradio to build each session's state.
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    masks = gr.State()
    cutout_idx = gr.State(set())

    # ---------------------------------------------------------------------- UI
    with gr.Column():
        with gr.Row():
            input_image = gr.Image(label='Input', height=512, type='pil')
            masks_annotated_image = gr.AnnotatedImage(label='Segments', height=512)
            cutout_galary = gr.Gallery(label='Cutouts', object_fit='contain', height=512)
        with gr.Row():
            with gr.Column(scale=1):
                point_label_radio = gr.Radio(label='Point Label', choices=[1, 0], value=1)
                reset_btn = gr.Button('Reset')
                sam_sgmt_everything_btn = gr.Button('Segment Everything!', variant='primary')
                sam_encode_btn = gr.Button('Encode', variant='primary')
                sam_decode_btn = gr.Button('Predict using points!')

    # Passing a *set* of inputs makes Gradio call the handler with one dict
    # argument mapping each component to its current value.
    components = {point_coords, point_labels, raw_image, masks, cutout_idx, input_image,
                  point_label_radio, reset_btn, sam_sgmt_everything_btn, sam_encode_btn,
                  sam_decode_btn, masks_annotated_image}

    # Outputs cleared whenever a fresh image/point session starts. Each handler
    # below returns one value per component in this list, in this order.
    session_outputs = [point_coords, point_labels, masks_annotated_image,
                       cutout_galary, masks, cutout_idx]

    def _empty_session():
        """Return cleared values matching ``session_outputs`` (in order)."""
        return point_coords_empty(), point_labels_empty(), None, [], None, set()

    # event - reset points and results
    def on_reset_btn_click(raw_image):
        """Restore the marker-free image and clear points, segments and cutouts."""
        return (raw_image, *_empty_session())

    reset_btn.click(on_reset_btn_click, [raw_image],
                    [input_image, *session_outputs], queue=False)

    # event - new upload
    def on_input_image_upload(input_image):
        """Store a clean copy of the upload and clear the previous session.

        Encoding is NOT done here; the user must press "Encode" for the new
        image before predicting.
        """
        return (input_image, *_empty_session())

    input_image.upload(on_input_image_upload, [input_image],
                       [raw_image, *session_outputs], queue=False)

    # event - set coords
    def on_input_image_select(input_image, point_coords, point_labels, point_label_radio,
                              evt: gr.SelectData):
        """Record a clicked point and draw its marker on the displayed image.

        Points labelled 0 are drawn red; any other label is drawn blue.
        Returns the marked image and the extended coordinate/label lists.
        """
        if input_image is None:
            raise gr.Error('Upload an image before adding points.')
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        # Defensive: force 3 channels so the RGB colour tuple is drawn correctly
        # (on RGBA cv2 would write alpha=0; on L/P it would write a grey/index).
        img = np.array(input_image.convert('RGB'))
        cv2.circle(img, (x, y), 5, color, -1)
        img = Image.fromarray(img)
        # Build new lists rather than mutating the incoming state values.
        return img, point_coords + [[x, y]], point_labels + [point_label_radio]

    input_image.select(on_input_image_select,
                       [input_image, point_coords, point_labels, point_label_radio],
                       [input_image, point_coords, point_labels], queue=False)

    # event - inference
    def on_click_sam_encode_btn(inputs):
        """Run SAM's image encoder on the clean uploaded image.

        Any previously predicted masks belong to the old embedding, so they are
        cleared (the event is wired to these three outputs and must fill them).
        """
        image = inputs[raw_image]
        if image is None:
            raise gr.Error('Upload an image before encoding.')
        sam.encode(image)
        return {masks_annotated_image: None,
                masks: None,
                cutout_idx: set()}

    def on_click_sam_dencode_btn(inputs):
        """Predict masks from the clicked points and overlay them on the image.

        NOTE(review): presumably requires "Encode" to have been run for the
        current image first -- confirm against inference.SegmentPredictor.
        """
        image = inputs[raw_image]
        coords = inputs[point_coords]
        labels = inputs[point_labels]
        if image is None:
            raise gr.Error('Upload an image before predicting.')
        if not coords:
            raise gr.Error('Click on the image to add at least one point.')
        generated_masks, _ = sam.cond_pred(pts=np.array(coords), lbls=np.array(labels))
        annotated = (image, [(mask, f'Mask {i}') for i, mask in enumerate(generated_masks)])
        return {masks_annotated_image: annotated,
                masks: generated_masks,
                cutout_idx: set()}

    sam_encode_btn.click(on_click_sam_encode_btn, components,
                         [masks_annotated_image, masks, cutout_idx], queue=True)
    sam_decode_btn.click(on_click_sam_dencode_btn, components,
                         [masks_annotated_image, masks, cutout_idx], queue=True)
    # TODO: sam_sgmt_everything_btn has no handler yet; wire it once one exists, e.g.
    # sam_sgmt_everything_btn.click(on_sam_sgmt_everything_click, components,
    #                               [masks_annotated_image, masks, cutout_idx], queue=True)
if __name__ == '__main__':
    # Enable request queuing (the inference events opt into it) and serve the app;
    # Blocks.queue() returns the same Blocks object, so the calls can be chained.
    block.queue().launch()