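# Gradio demo: prompt SAM with clicked points to segment an image, and
# reconstruct a 3-D point cloud from predicted depth (DPT).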
import os
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
from inference import SegmentPredictor, DepthPredictor
from utils import generate_PCL, PCL3, point_cloud

# Instantiate the segmentation (SAM) and depth (DPT) predictors once at startup.
sam = SegmentPredictor()
dpt = DepthPredictor()
# Marker colours (RGB): red marks negative (0) points, blue marks positive (1) points.
red = (255, 0, 0)
blue = (0, 0, 255)
annos = []
block = gr.Blocks()
with block:
    # States
    def point_coords_empty():
        return []

    def point_labels_empty():
        return []

    raw_image = gr.Image(type='pil', visible=False)
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    masks = gr.State([])
    cutout_idx = gr.State(set())
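    # raw_image keeps an unmarked copy of the upload; the visible input_image
    # accumulates the drawn point markers on top of it.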
    # UI
    with gr.Column():
        with gr.Row():
            input_image = gr.Image(label='Input', type='pil', source='webcam', tool=None)
            masks_annotated_image = gr.AnnotatedImage(label='Segments', height=512)
            pcl_figure = gr.Model3D(label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
        with gr.Column():
            n_samples = gr.Slider(minimum=1e3, maximum=1e6, step=1e3, value=1e3, label='Number of Samples')
            cube_size = gr.Slider(minimum=0.000001, maximum=0.001, step=0.000001, value=0.00001, label='Cube size')
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Row():
                    point_label_radio = gr.Radio(label='Point Label', choices=[1, 0], value=1)
                    text = gr.Textbox(label='Mask Name')
                    reset_btn = gr.Button('New Mask')
                sam_sgmt_everything_btn = gr.Button('Segment Everything!', variant='primary')
                sam_decode_btn = gr.Button('Predict using points!', variant='primary')
                depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant='primary')
    # components
    components = {point_coords, point_labels, raw_image, masks, cutout_idx, input_image,
                  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
                  sam_decode_btn, depth_reconstruction_btn, masks_annotated_image, n_samples, cube_size}
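    # Gradio passes a set of input components to a handler as a single dict
    # keyed by component, hence lookups such as inputs[raw_image] below.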
    # event - init coords
    def on_reset_btn_click(raw_image):
        # Restore the clean image and clear prompts, the annotated view, and masks.
        return raw_image, point_coords_empty(), point_labels_empty(), None, []

    reset_btn.click(on_reset_btn_click, [raw_image],
                    [input_image, point_coords, point_labels, masks_annotated_image, masks],
                    queue=False)

    def on_input_image_upload(input_image):
        print("encoding")
        # encode image on upload
        sam.encode(input_image)
        print("encoding done")
        return input_image, point_coords_empty(), point_labels_empty(), None

    input_image.upload(on_input_image_upload, [input_image],
                       [raw_image, point_coords, point_labels, masks_annotated_image],
                       queue=False)
    # event - set coords
    def on_input_image_select(input_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        img = np.array(input_image)
        cv2.circle(img, (x, y), 5, color, -1)
        img = Image.fromarray(img)
        point_coords.append([x, y])
        point_labels.append(point_label_radio)
        return img, point_coords, point_labels

    input_image.select(on_input_image_select,
                       [input_image, point_coords, point_labels, point_label_radio],
                       [input_image, point_coords, point_labels], queue=False)
    def on_click_sam_decode_btn(inputs):
        print("inferencing")
        image = inputs[raw_image]
        generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]),
                                             lbls=np.array(inputs[point_labels]))
        inputs[masks].append((generated_mask, inputs[text]))
        return {masks_annotated_image: (image, inputs[masks])}

    sam_decode_btn.click(on_click_sam_decode_btn, components,
                         [masks_annotated_image, masks, cutout_idx], queue=True)
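    # Returning a dict from a handler updates only the listed components,
    # leaving the other declared outputs unchanged.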
    def on_depth_reconstruction_btn_click(inputs):
        print("depth reconstruction")
        image = inputs[input_image]
        path = dpt.generate_obj_masks(image, inputs[n_samples], inputs[cube_size])
        return {pcl_figure: path}

    depth_reconstruction_btn.click(on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False)
if __name__ == '__main__':
    block.queue()
    block.launch()