Spaces:
Runtime error
Runtime error
File size: 6,176 Bytes
077fc91 5c0b534 7598e8a 077fc91 01bc85d 769894a 5c0b534 077fc91 1689431 01bc85d 077fc91 d4233b7 077fc91 d46e73c 077fc91 f07135c 077fc91 4d6f971 185ceb1 077fc91 c02210d 1758fb9 4d6f971 60edd6a 9e6e225 4d6f971 a5fefd2 4d6f971 6046fb8 60edd6a 0579ca3 706546d 0b87dc6 0579ca3 60edd6a 4f4f67b 60edd6a 077fc91 0579ca3 d4233b7 a2c6e8a 1689431 077fc91 0579ca3 077fc91 4d6f971 077fc91 fe0db59 9e6e225 fe0db59 077fc91 711582a fe0db59 711582a 613f332 185ceb1 711582a 9e6e225 f76bf44 4d6f971 0a54901 4d6f971 0a54901 4d6f971 c1a5086 0a54901 9503ae0 0a54901 4d6f971 185ceb1 4d6f971 0a54901 077fc91 c02210d 60edd6a c02210d 077fc91 b3873d0 0579ca3 d4233b7 7299967 d46e73c 1c4f487 077fc91 5a6d6d4 0579ca3 17dfbdb 5a6d6d4 d7bd88e c810b3c ae29f3e c69e375 4d6f971 d7bd88e 077fc91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import os
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch
from inference import SegmentPredictor, DepthPredictor
from utils import generate_PCL, PCL3, point_cloud
# Model wrappers are constructed once, at import time, and shared by all the
# event handlers defined inside the Blocks context below.
sam = SegmentPredictor()
dpt = DepthPredictor()

# Marker colours for prompt points (named in RGB order): the point-select
# handler draws red for label 0 and blue otherwise.
red, blue = (255, 0, 0), (0, 0, 255)

# NOTE(review): `annos` does not appear to be referenced anywhere in this
# module — confirm before removing.
annos = list()

block = gr.Blocks()
with block:
    # ----------------------------------------------------------------- state
    # gr.State accepts a callable initial value; Gradio calls it to seed the
    # state, so each call hands back a brand-new empty list.
    def point_coords_empty():
        """Return a new empty list for clicked ``[x, y]`` prompt points."""
        return []

    def point_labels_empty():
        """Return a new empty list for the labels of the prompt points."""
        return []

    image_edit_trigger = gr.State(True)
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    # (mask, name) pairs the user has chosen to keep.
    masks = gr.State([])
    cutout_idx = gr.State(set())
    # Masks from the latest SAM call (point-prompted or "segment everything").
    pred_masks = gr.State([])
    prompt_masks = gr.State([])

    # -------------------------------------------------------------------- UI
    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label='Input', type='pil', tool=None)  # mirror_webcam = False
                with gr.Row():
                    sam_encode_btn = gr.Button('Encode', variant='primary')
                    sam_sgmt_everything_btn = gr.Button('Segment Everything!', variant='primary')
                    #sam_encode_status = gr.Label('Not encoded yet')
            with gr.Row():
                prompt_image = gr.Image(label='Segments')
                #prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
                lbl_image = gr.AnnotatedImage(label='Everything')
        with gr.Row():
            point_label_radio = gr.Radio(label='Point Label', choices=[1, 0], value=1)
            text = gr.Textbox(label='Mask Name')
            reset_btn = gr.Button('New Mask')
        selected_masks_image = gr.AnnotatedImage(label='Selected Masks')
        with gr.Row():
            with gr.Column():
                pcl_figure = gr.Model3D(label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
                with gr.Row():
                    # `value=` (not `default=`) sets a component's initial value.
                    n_samples = gr.Slider(minimum=1e3, maximum=1e6, step=1e3, value=1e3, label='Number of Samples')
                    cube_size = gr.Slider(minimum=0.000001, maximum=0.001, step=0.000001, value=0.00001, label='Cube size')
                depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant='primary')
                sam_decode_btn = gr.Button('Predict using points!', variant='primary')

    # Passing a *set* of components as `inputs` makes Gradio call the handler
    # with a single dict keyed by component, which is how the button handlers
    # below read their values (e.g. ``inputs[input_image]``).
    components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image,
                  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
                  sam_decode_btn, depth_reconstruction_btn, prompt_image, lbl_image, n_samples, cube_size,
                  selected_masks_image}

    def _require_input_image(image):
        """Return *image*, or raise a user-facing error if nothing was uploaded."""
        if image is None:
            raise gr.Error('Please upload an input image first.')
        return image

    # ------------------------------------------------------ event: new mask
    def on_reset_btn_click(input_image):
        """Start a new mask: drop all clicked points and their markers.

        Returns, in output order: an empty point-coordinate list, an empty
        point-label list, and the untouched input image as a clean prompt
        canvas (so no previous markers remain and it stays clickable).
        """
        # Exactly one value per declared output; returning more values than
        # outputs makes Gradio raise at runtime.
        return point_coords_empty(), point_labels_empty(), input_image
    reset_btn.click(on_reset_btn_click, [input_image], [point_coords, point_labels, prompt_image], queue=False)

    # ------------------------------------------- event: click a prompt point
    def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks, evt: gr.SelectData):
        """Add a prompt point at the clicked pixel and re-run point-prompted SAM.

        Draws a filled marker (red for label 0, blue otherwise) on the prompt
        canvas, records the point and its label, and predicts masks from all
        points collected so far.

        Returns, in output order: the marked prompt canvas, the
        ``(input_image, sam_masks)`` pair for the "Everything" view, the
        updated coordinates, the updated labels, and ``sam_masks`` again as
        the new ``pred_masks`` state.
        """
        x, y = evt.index  # pixel position of the click on the image
        color = red if point_label_radio == 0 else blue
        if prompt_image is None:
            prompt_image = np.array(input_image.copy())
        cv2.circle(prompt_image, (x, y), 5, color, -1)
        point_coords.append([x, y])
        point_labels.append(point_label_radio)
        # NOTE(review): here the whole cond_pred result is used as a list of
        # (mask, label) pairs, while on_click_sam_dencode_btn unpacks it into
        # three values — confirm which matches SegmentPredictor.cond_pred.
        sam_masks = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
        return [prompt_image,
                (input_image, sam_masks),
                point_coords,
                point_labels,
                sam_masks]
    prompt_image.select(on_prompt_image_select,
                        [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks],
                        [prompt_image, lbl_image, point_coords, point_labels, pred_masks], queue=False)

    # ----------------------------------- event: pick a mask from "Everything"
    def on_everything_image_select(input_image, pred_masks, masks, text, evt: gr.SelectData):
        """Keep the clicked predicted mask, named by the "Mask Name" textbox.

        Returns the updated ``masks`` state and the ``(image, masks)`` pair
        shown in the "Selected Masks" view.
        """
        # evt.index is used as the position of the clicked annotation.
        mask = pred_masks[evt.index][0]
        masks.append((mask, text))
        return [masks, (input_image, masks)]
    lbl_image.select(on_everything_image_select,
                     [input_image, pred_masks, masks, text],
                     [masks, selected_masks_image], queue=False)
    #prompt_lbl_image.select(on_everything_image_select,
    #    [input_image, prompt_masks, masks, text],
    #    [masks, selected_masks_image], queue=False)

    # --------------------------------------------------------- event: encode
    def on_click_sam_encode_btn(inputs):
        """Encode the input image with SAM and copy it onto the prompt canvas."""
        print("encoding")
        image = _require_input_image(inputs[input_image])
        sam.encode(image)
        print("encoding done")
        return {prompt_image: image}
    sam_encode_btn.click(on_click_sam_encode_btn, components, [prompt_image], queue=False)

    # ------------------------------------------- event: predict from points
    def on_click_sam_dencode_btn(inputs):
        """Predict a mask from the clicked points and add it to the kept masks.

        The new ``(mask, name)`` pair is appended to the ``masks`` state and
        the full set is rendered in the "Selected Masks" annotated view (the
        same view ``on_everything_image_select`` uses).
        """
        print("inferencing")
        image = inputs[input_image]
        if not inputs[point_coords]:
            raise gr.Error('Click on the Segments image to add at least one point first.')
        generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels]))
        selected = inputs[masks]
        selected.append((generated_mask, inputs[text]))
        # An (image, [(mask, label), ...]) pair is AnnotatedImage data; the
        # plain gr.Image prompt canvas cannot display it.
        return {masks: selected, selected_masks_image: (image, selected)}
    sam_decode_btn.click(on_click_sam_dencode_btn, components, [masks, selected_masks_image], queue=True)

    # ------------------------------------------- event: depth reconstruction
    def on_depth_reconstruction_btn_click(inputs):
        """Build a 3-D reconstruction of the input image and show it in the viewer."""
        print("depth reconstruction")
        image = _require_input_image(inputs[input_image])
        # The slider yields floats (step=1e3); pass the sample count as an int.
        path = dpt.generate_obj_rgb(image=image, n_samples=int(inputs[n_samples]), cube_size=inputs[cube_size])
        return {pcl_figure: path}
    depth_reconstruction_btn.click(on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False)

    # -------------------------------------------- event: segment everything
    def on_sam_sgmt_everything_btn_click(inputs):
        """Run SAM automatic segmentation on the input image.

        Returns the ``(image, masks)`` pair for the "Everything" view and the
        masks themselves as the new ``pred_masks`` state.
        """
        print("segmenting everything")
        image = _require_input_image(inputs[input_image])
        sam_masks = sam.segment_everything(image)
        return [(image, sam_masks), sam_masks]
    sam_sgmt_everything_btn.click(on_sam_sgmt_everything_btn_click, components, [lbl_image, pred_masks], queue=False)
if __name__ == '__main__':
    # Enable the request queue, then start the web server; Blocks.queue()
    # returns the Blocks instance, so the two calls chain.
    block.queue().launch()