import os
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
from inference import SegmentPredictor, DepthPredictor
from utils import generate_PCL, PCL3, point_cloud
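
# SegmentPredictor and DepthPredictor come from the local inference module; judging by how they
# are used below, they presumably wrap a SAM-style image encoder / mask decoder and a monocular
# depth estimator. Note that os, torch and the utils imports (generate_PCL, PCL3, point_cloud)
# are not referenced anywhere in this file.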
sam = SegmentPredictor()
dpt = DepthPredictor()
red = (255, 0, 0)
blue = (0, 0, 255)
annos = []

block = gr.Blocks()
with block:
    # States
    def point_coords_empty():
        return []

    def point_labels_empty():
        return []

    raw_image = gr.Image(type='pil', visible=False)
    image_edit_trigger = gr.State(True)
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    masks = gr.State([])
    cutout_idx = gr.State(set())
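
    # State survey: raw_image keeps the last un-annotated upload, point_coords / point_labels
    # accumulate the click prompts, masks collects (mask, name) tuples for the AnnotatedImage,
    # image_edit_trigger marks whether the displayed image still needs to be snapshotted, and
    # cutout_idx is declared and wired as an output below but never updated by any handler here.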
    # UI
    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label='Input', type='pil', tool=None)  # mirror_webcam = False
                sam_encode_btn = gr.Button('Encode', variant='primary')
                sam_encode_status = gr.Label('Not encoded yet')
                masks_annotated_image = gr.AnnotatedImage(label='Segments', height=512)
            with gr.Column():
                pcl_figure = gr.Model3D(label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
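                # These sliders feed the depth reconstruction below: n_samples is presumably the
                # number of points sampled from the depth map and cube_size the size of each rendered
                # point/voxel (the exact semantics live in DepthPredictor.generate_obj_masks).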
                n_samples = gr.Slider(minimum=1e3, maximum=1e6, step=1e3, value=1e3, label='Number of Samples')
                cube_size = gr.Slider(minimum=0.000001, maximum=0.001, step=0.000001, value=0.00001, label='Cube size')
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Row():
                    point_label_radio = gr.Radio(label='Point Label', choices=[1, 0], value=1)
                    text = gr.Textbox(label='Mask Name')
                    reset_btn = gr.Button('New Mask')
                sam_sgmt_everything_btn = gr.Button('Segment Everything!', variant='primary')
                sam_decode_btn = gr.Button('Predict using points!', variant='primary')
                depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant='primary')
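            # NOTE: sam_sgmt_everything_btn is created but no click handler is attached to it
            # below, so the "Segment Everything" path is presumably unfinished or handled elsewhere.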

    # components
    components = {point_coords, point_labels, raw_image, image_edit_trigger, masks, cutout_idx, input_image,
                  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
                  sam_decode_btn, depth_reconstruction_btn, masks_annotated_image, n_samples, cube_size}
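
    # Passing a *set* of components as an event's inputs makes Gradio hand the handler a single
    # dict keyed by component, which is why the handlers below index `inputs[...]`.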

    # event - init coords
    def on_reset_btn_click(raw_image):
        return raw_image, point_coords_empty(), point_labels_empty(), None, []

    reset_btn.click(on_reset_btn_click, [raw_image],
                    [input_image, point_coords, point_labels, masks_annotated_image, masks], queue=False)

    def on_input_image_upload(input_image):
        print("encoding")
        # encode image on upload
        sam.encode(input_image)
        print("encoding done")
        return input_image, point_coords_empty(), point_labels_empty(), None

    input_image.upload(on_input_image_upload, [input_image],
                       [raw_image, point_coords, point_labels, masks_annotated_image], queue=False)
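
    # Point prompts: clicking the input image fires a select event; gr.SelectData.index carries
    # the (x, y) pixel that was clicked. The handler snapshots the clean image once, draws a
    # blue (label 1, typically foreground) or red (label 0, background) dot, and records the prompt.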
    # event - set coords
    def on_input_image_select(raw_image, input_image, image_edit_trigger, point_coords, point_labels,
                              point_label_radio, evt: gr.SelectData):
        if image_edit_trigger:
            unedited_image = input_image.copy()
            image_edit_trigger = False
        else:
            unedited_image = raw_image
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        img = np.array(input_image)
        cv2.circle(img, (x, y), 5, color, -1)
        img = Image.fromarray(img)
        point_coords.append([x, y])
        point_labels.append(point_label_radio)
        return [unedited_image,
                img,
                point_coords,
                point_labels,
                image_edit_trigger]

    input_image.select(on_input_image_select,
                       [raw_image, input_image, image_edit_trigger, point_coords, point_labels, point_label_radio],
                       [raw_image, input_image, point_coords, point_labels, image_edit_trigger], queue=False)

    def on_click_sam_encode_btn(inputs):
        print("encoding")
        # encode image on click
        sam.encode(inputs[input_image])
        print("encoding done")
        return {sam_encode_status: 'Image Encoded!'}

    sam_encode_btn.click(on_click_sam_encode_btn, components, [sam_encode_status], queue=False)
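
    # sam.cond_pred presumably runs the mask decoder on the previously encoded image using the
    # accumulated point prompts (its exact signature is defined in inference.py). The resulting
    # mask is stored with the user-supplied name and rendered via gr.AnnotatedImage.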
    def on_click_sam_decode_btn(inputs):
        print("inferencing")
        image = inputs[raw_image]
        generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels]))
        inputs[masks].append((generated_mask, inputs[text]))
        print(inputs[masks][0])
        return {masks_annotated_image: (image, inputs[masks])}

    sam_decode_btn.click(on_click_sam_decode_btn, components, [masks_annotated_image, masks, cutout_idx], queue=True)
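
    # dpt.generate_obj_masks presumably estimates depth for the image, back-projects the sampled
    # points together with the collected masks into a point cloud, and writes it to a file whose
    # path gr.Model3D can display.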
    def on_depth_reconstruction_btn_click(inputs):
        print("depth reconstruction")
        image = inputs[raw_image]
        path = dpt.generate_obj_masks(image=image, n_samples=inputs[n_samples],
                                      cube_size=inputs[cube_size], masks=inputs[masks])
        return {pcl_figure: path}

    depth_reconstruction_btn.click(on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False)
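
# Enable Gradio's request queue (the point-prediction handler above is registered with queue=True)
# before launching the app.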
if __name__ == '__main__':
    block.queue()
    block.launch()