import os
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
from inference import SegmentPredictor, DepthPredictor
from utils import generate_PCL, PCL3, point_cloud

sam = SegmentPredictor()
dpt = DepthPredictor()
red = (255, 0, 0)
blue = (0, 0, 255)
annos = []
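# The UI below is a Gradio Blocks app: the left-hand widgets drive SAM-style
# segmentation (encode the image once, then add point prompts or segment
# everything), and the right-hand widgets turn a predicted depth map into a
# 3-D reconstruction. SegmentPredictor and DepthPredictor are project-local
# wrappers defined in inference.py; their model backends are not shown here.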
block = gr.Blocks()
with block:
    # States
    def point_coords_empty():
        return []

    def point_labels_empty():
        return []

    image_edit_trigger = gr.State(True)
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    masks = gr.State([])
    cutout_idx = gr.State(set())
    pred_masks = gr.State([])
    prompt_masks = gr.State([])
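    # How the handlers below use these per-session states:
    #   point_coords / point_labels - accumulated click prompts and their 1/0 labels
    #                                 (1 = include, 0 = exclude in the usual SAM convention)
    #   masks                       - (mask, name) pairs the user has selected so far
    #   pred_masks                  - the most recent predictions from the segmenter
    #   image_edit_trigger, cutout_idx, prompt_masks - declared but not updated by the
    #                                 handlers currently wired up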
    # UI
    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label='Input', type='pil', tool=None)  # mirror_webcam = False
                with gr.Row():
                    sam_encode_btn = gr.Button('Encode', variant='primary')
                    sam_sgmt_everything_btn = gr.Button('Segment Everything!', variant='primary')
                    # sam_encode_status = gr.Label('Not encoded yet')
                with gr.Row():
                    prompt_image = gr.Image(label='Segments')
                    # prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
                    lbl_image = gr.AnnotatedImage(label='Everything')
                with gr.Row():
                    point_label_radio = gr.Radio(label='Point Label', choices=[1, 0], value=1)
                    text = gr.Textbox(label='Mask Name')
                    reset_btn = gr.Button('New Mask')
                    selected_masks_image = gr.AnnotatedImage(label='Selected Masks')
        with gr.Row():
            with gr.Column():
                pcl_figure = gr.Model3D(label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
                with gr.Row():
                    n_samples = gr.Slider(minimum=1e3, maximum=1e6, step=1e3, value=1e3, label='Number of Samples')
                    cube_size = gr.Slider(minimum=0.000001, maximum=0.001, step=0.000001, value=0.00001, label='Cube size')
                depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant='primary')
        sam_decode_btn = gr.Button('Predict using points!', variant='primary')
    # components
    components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image,
                  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
                  sam_decode_btn, depth_reconstruction_btn, prompt_image, lbl_image, n_samples, cube_size, selected_masks_image}
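    # Passing a set of components as `inputs` makes Gradio call the handler with a single
    # dict keyed by component, which is why several handlers below read their values as
    # inputs[input_image], inputs[point_coords], and so on.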
    # event - init coords
    def on_reset_btn_click(input_image):
        # clear the prompt points, the prompt overlay and the selected masks
        return input_image, point_coords_empty(), point_labels_empty(), None, []

    reset_btn.click(on_reset_btn_click, [input_image],
                    [input_image, point_coords, point_labels, prompt_image, masks], queue=False)
    def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks, evt: gr.SelectData):
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        if prompt_image is None:
            prompt_image = np.array(input_image.copy())
        cv2.circle(prompt_image, (x, y), 5, color, -1)
        point_coords.append([x, y])
        point_labels.append(point_label_radio)
        sam_masks = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
        return [prompt_image,
                (input_image, sam_masks),
                point_coords,
                point_labels,
                sam_masks]

    prompt_image.select(on_prompt_image_select,
                        [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks],
                        [prompt_image, lbl_image, point_coords, point_labels, pred_masks], queue=False)
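    # Assumption: the value sam.cond_pred returns at this call site can be used directly
    # as the annotation list gr.AnnotatedImage expects (a list of (region, label) pairs),
    # so (input_image, sam_masks) renders as an overlay on lbl_image.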
    def on_everything_image_select(input_image, pred_masks, masks, text, evt: gr.SelectData):
        i = evt.index
        mask = pred_masks[i][0]
        print(mask)
        print(type(mask))
        masks.append((mask, text))
        anno = (input_image, masks)
        return [masks, anno]

    lbl_image.select(on_everything_image_select,
                     [input_image, pred_masks, masks, text],
                     [masks, selected_masks_image], queue=False)
    # prompt_lbl_image.select(on_everything_image_select,
    #                         [input_image, prompt_masks, masks, text],
    #                         [masks, selected_masks_image], queue=False)
    def on_click_sam_encode_btn(inputs):
        print("encoding")
        # encode image on click
        sam.encode(inputs[input_image])
        print("encoding done")
        return {prompt_image: inputs[input_image]}

    sam_encode_btn.click(on_click_sam_encode_btn, components, [prompt_image], queue=False)
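    # SAM-style predictors compute the image embedding once ('Encode') and reuse it for
    # every point prompt, which keeps each cond_pred call cheap. This assumes sam.encode
    # caches the embedding inside the SegmentPredictor wrapper.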
    def on_click_sam_decode_btn(inputs):
        print("inferencing")
        image = inputs[input_image]
        generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels]))
        inputs[masks].append((generated_mask, inputs[text]))
        print(inputs[masks][0])
        # (image, annotations) pairs are rendered by gr.AnnotatedImage, so the
        # point-prompt result is shown in the selected-masks panel
        return {selected_masks_image: (image, inputs[masks]), masks: inputs[masks]}

    sam_decode_btn.click(on_click_sam_decode_btn, components, [selected_masks_image, masks, cutout_idx], queue=True)
    def on_depth_reconstruction_btn_click(inputs):
        print("depth reconstruction")
        image = inputs[input_image]
        path = dpt.generate_obj_rgb(image=image, n_samples=inputs[n_samples], cube_size=inputs[cube_size])
        return {pcl_figure: path}

    depth_reconstruction_btn.click(on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False)
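    # Assumption: dpt.generate_obj_rgb writes a mesh / point-cloud file (e.g. a .obj) and
    # returns its path; gr.Model3D then renders that file in the browser.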
    def on_sam_sgmt_everything_btn_click(inputs):
        print("segmenting everything")
        image = inputs[input_image]
        sam_masks = sam.segment_everything(image)
        print(image)
        print(sam_masks)
        return [(image, sam_masks), sam_masks]

    sam_sgmt_everything_btn.click(on_sam_sgmt_everything_btn_click, components, [lbl_image, pred_masks], queue=False)
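    # Assumption: sam.segment_everything returns the same (region, label) annotation list,
    # so one call can feed both the AnnotatedImage view and the pred_masks state that the
    # lbl_image click handler indexes into.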
if __name__ == '__main__':
    block.queue()
    block.launch()
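# To run locally (a sketch, assuming inference.py / utils.py and their model weights are
# present alongside this file):
#   pip install gradio opencv-python torch numpy pillow
#   python app.py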