# Scraped file-view header (not code): jens · fix · ae29f3e · raw · history blame · 5.83 kB
import os
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
from inference import SegmentPredictor, DepthPredictor
from utils import generate_PCL, PCL3, point_cloud
# Model wrappers are built once at import time and shared by every event handler.
sam = SegmentPredictor()  # used below for encode / point-prompt decoding / segment-everything
dpt = DepthPredictor()    # used below to build the 3-D reconstruction file for the Model3D view

# Marker colours for clicked prompt points: red for label 0, blue otherwise.
# (Channel order presumably RGB, matching PIL/Gradio image arrays.)
red, blue = (255, 0, 0), (0, 0, 255)

# NOTE(review): not referenced anywhere else in this file as visible; kept for compatibility.
annos = list()
block = gr.Blocks()
with block:
    # NOTE(review): the original indentation of this file was lost; the
    # container nesting below is a reconstruction — confirm the intended layout.

    # ------------------------------------------------------------------ States
    def point_coords_empty():
        """Return a new empty list for clicked ``[x, y]`` prompt coordinates."""
        return []

    def point_labels_empty():
        """Return a new empty list for prompt labels (values of the 'Point Label' radio)."""
        return []

    image_edit_trigger = gr.State(True)  # NOTE(review): not read by any handler below
    point_coords = gr.State(point_coords_empty)  # [[x, y], ...] clicked on the Segments image
    point_labels = gr.State(point_labels_empty)  # one label (1 or 0) per clicked point
    masks = gr.State([])  # (mask, name) pairs saved by "Predict using points!"
    cutout_idx = gr.State(set())  # NOTE(review): not read or written by any handler below

    # ---------------------------------------------------------------------- UI
    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label='Input', type='pil', tool=None)  # mirror_webcam = False
                sam_encode_btn = gr.Button('Encode', variant='primary')
                sam_encode_status = gr.Label('Not encoded yet')
            with gr.Row():
                prompt_image = gr.Image(label='Segments')
                prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
                everything_image = gr.AnnotatedImage(label='Everything')
        with gr.Row():
            with gr.Column():
                pcl_figure = gr.Model3D(label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
                with gr.Row():
                    # `value=` (not `default=`) sets a Gradio 3 slider's initial value;
                    # `default=` was silently ignored, so cube_size started at its minimum.
                    n_samples = gr.Slider(minimum=1e3, maximum=1e6, step=1e3, value=1e3,
                                          label='Number of Samples')
                    cube_size = gr.Slider(minimum=0.000001, maximum=0.001, step=0.000001, value=0.00001,
                                          label='Cube size')
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Row():
                    point_label_radio = gr.Radio(label='Point Label', choices=[1, 0], value=1)
                    text = gr.Textbox(label='Mask Name')
                reset_btn = gr.Button('New Mask')
                sam_sgmt_everything_btn = gr.Button('Segment Everything!', variant='primary')
                sam_decode_btn = gr.Button('Predict using points!', variant='primary')
                depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant='primary')

    # Every component a dict-style handler may read: Gradio passes a *set* of
    # inputs to the handler as one ``{component: value}`` dict.
    components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image,
                  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
                  sam_decode_btn, depth_reconstruction_btn, prompt_image, prompt_lbl_image,
                  n_samples, cube_size}

    # ----------------------------------------------------- event - reset points
    def on_reset_btn_click(input_image):
        """Start a new mask: clear the clicked points and redraw a clean canvas.

        Returns values for ``[prompt_image, point_coords, point_labels,
        prompt_lbl_image]``: the unmarked input image (so old markers vanish),
        two empty lists, and ``None`` to clear the live mask preview.
        Already-saved masks are kept.
        """
        return input_image, point_coords_empty(), point_labels_empty(), None

    reset_btn.click(on_reset_btn_click, [input_image],
                    [prompt_image, point_coords, point_labels, prompt_lbl_image], queue=False)

    # ------------------------------------------------------ event - new upload
    def on_input_image_upload(input_image):
        """Encode a newly uploaded image and reset all per-image state.

        Returns values for ``[prompt_image, point_coords, point_labels,
        prompt_lbl_image, masks, sam_encode_status]``.
        """
        print("encoding")
        sam.encode(input_image)  # encode on upload so point prompts can be decoded immediately
        print("encoding done")
        # Points and masks from a previous image presumably do not line up with
        # the new one, so they are dropped together with the old canvas.
        return input_image, point_coords_empty(), point_labels_empty(), None, [], 'Image Encoded!'

    input_image.upload(on_input_image_upload, [input_image],
                       [prompt_image, point_coords, point_labels, prompt_lbl_image, masks,
                        sam_encode_status],
                       queue=False)

    # ---------------------------------------------------- event - add a point
    def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels,
                               point_label_radio, evt: gr.SelectData):
        """Record a clicked point, draw its marker, and preview the resulting mask.

        Label 0 is drawn red, any other label blue. The point lists are
        extended in place and also returned, together with the annotated
        preview ``(input_image, [(mask, "Mask")])`` for ``prompt_lbl_image``.
        """
        if input_image is None:
            raise gr.Error('Upload an image first.')
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        if prompt_image is None:
            prompt_image = np.array(input_image.copy())
        cv2.circle(prompt_image, (x, y), 5, color, -1)  # filled dot, radius 5 px
        point_coords.append([x, y])
        point_labels.append(point_label_radio)
        generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
        return [prompt_image,
                (input_image, [(generated_mask, "Mask")]),
                point_coords,
                point_labels]

    prompt_image.select(on_prompt_image_select,
                        [input_image, prompt_image, point_coords, point_labels, point_label_radio],
                        [prompt_image, prompt_lbl_image, point_coords, point_labels], queue=False)

    # ------------------------------------------------------- event - encode
    def on_click_sam_encode_btn(inputs):
        """Encode the input image and show a clean copy on the Segments canvas.

        Clicked points and the live mask preview are cleared too, because the
        redrawn canvas no longer shows their markers.
        """
        image = inputs[input_image]
        if image is None:
            raise gr.Error('Upload an image first.')
        print("encoding")
        sam.encode(image)
        print("encoding done")
        return {sam_encode_status: 'Image Encoded!',
                prompt_image: image,
                point_coords: point_coords_empty(),
                point_labels: point_labels_empty(),
                prompt_lbl_image: None}

    sam_encode_btn.click(on_click_sam_encode_btn, components,
                         [sam_encode_status, prompt_image, point_coords, point_labels,
                          prompt_lbl_image],
                         queue=False)

    # ---------------------------------------------- event - save a named mask
    def on_click_sam_dencode_btn(inputs):
        """Decode a mask from the clicked points and append it to the saved masks.

        The mask is named from the 'Mask Name' textbox, falling back to
        ``"Mask <n>"`` when it is blank. All saved masks are shown on the
        AnnotatedImage ``prompt_lbl_image`` (a plain ``gr.Image`` cannot
        display ``(image, [(mask, label), ...])``). The updated list is
        returned to the ``masks`` state instead of being mutated in place.
        """
        image = inputs[input_image]
        coords = inputs[point_coords]
        if image is None:
            raise gr.Error('Upload and encode an image first.')
        if not coords:
            raise gr.Error('Click at least one point on the Segments image first.')
        print("inferencing")
        generated_mask, _, _ = sam.cond_pred(pts=np.array(coords), lbls=np.array(inputs[point_labels]))
        saved = list(inputs[masks])
        name = (inputs[text] or '').strip() or f"Mask {len(saved) + 1}"
        saved.append((generated_mask, name))
        return {prompt_lbl_image: (image, saved), masks: saved}

    sam_decode_btn.click(on_click_sam_dencode_btn, components, [prompt_lbl_image, masks], queue=True)

    # ------------------------------------------------ event - 3-D reconstruction
    def on_depth_reconstruction_btn_click(inputs):
        """Build the 3-D reconstruction for the saved masks and show it in ``pcl_figure``."""
        image = inputs[input_image]
        if image is None:
            raise gr.Error('Upload an image first.')
        print("depth reconstruction")
        path = dpt.generate_obj_masks(image=image,
                                      # Slider values may arrive as floats; a sample count is integral.
                                      n_samples=int(inputs[n_samples]),
                                      cube_size=inputs[cube_size],
                                      masks=inputs[masks])
        return {pcl_figure: path}

    depth_reconstruction_btn.click(on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False)

    # ------------------------------------------------ event - segment everything
    def on_sam_sgmt_everything_btn_click(inputs):
        """Run automatic segmentation on the input image and show it on ``everything_image``."""
        image = inputs[input_image]
        if image is None:
            raise gr.Error('Upload an image first.')
        print("segmenting everything")
        sam_masks = sam.segment_everything(image)
        return {everything_image: (image, sam_masks)}

    sam_sgmt_everything_btn.click(on_sam_sgmt_everything_btn_click, components, [everything_image],
                                  queue=False)
if __name__ == '__main__':
    # Turn on Gradio's request queue, then start serving the app
    # (queue() hands back the same Blocks object, so the calls chain).
    block.queue().launch()