import os
import gradio as gr
import numpy as np
import cv2
from PIL import Image
import torch

from inference import SegmentPredictor, DepthPredictor
from utils import generate_PCL, PCL3, point_cloud

sam = SegmentPredictor()
dpt = DepthPredictor()
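# NOTE: inference.py and utils.py are not shown in this file. SegmentPredictor
# is assumed to wrap a SAM checkpoint (image encoder plus prompt-conditioned
# mask decoder) and DepthPredictor a monocular depth model; treat the comments
# about their methods below as assumptions, not documented API.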
red = (255, 0, 0)   # marker color for background points (label 0)
blue = (0, 0, 255)  # marker color for foreground points (label 1)
annos = []
block = gr.Blocks()
with block:
    # States
    def point_coords_empty():
        return []

    def point_labels_empty():
        return []

    image_edit_trigger = gr.State(True)
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    masks = gr.State([])
    cutout_idx = gr.State(set())
    pred_masks = gr.State([])
    prompt_masks = gr.State([])
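    # gr.State holds per-session values, deep-copied from the default for each
    # visitor; the factory functions above give every session its own fresh
    # list of point prompts.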
    # UI
    with gr.Column():
        gr.Markdown(
            '''# Segment Anything!
            The Segment Anything Model (SAM) produces high-quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. More information can be found in the [**Official Project**](https://segment-anything.com/).
            [Duplicate this Space](https://huggingface.co/spaces/AIBoy1993/segment_anything_webui?duplicate=true)
            '''
        )
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label='Input', type='pil', tool=None)  # mirror_webcam=False
                with gr.Row():
                    sam_encode_btn = gr.Button('Encode', variant='primary')
                    sam_sgmt_everything_btn = gr.Button('Segment Everything!', variant='primary')
                    # sam_encode_status = gr.Label('Not encoded yet')
                with gr.Row():
                    prompt_image = gr.Image(label='Segments')
                    # prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
                    lbl_image = gr.AnnotatedImage(label='Everything')
                with gr.Row():
                    point_label_radio = gr.Radio(label='Point Label', choices=[1, 0], value=1)
                    text = gr.Textbox(label='Mask Name')
                    reset_btn = gr.Button('New Mask')
                selected_masks_image = gr.AnnotatedImage(label='Selected Masks')
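        # gr.AnnotatedImage expects a (base_image, annotations) tuple, where
        # annotations is a list of (mask_or_bbox, label) pairs; the handlers
        # below build their return values in that shape.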
        with gr.Row():
            with gr.Column():
                pcl_figure = gr.Model3D(label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
                with gr.Row():
                    n_samples = gr.Slider(minimum=1e3, maximum=1e6, step=1e3, value=1e3, label='Number of Samples')
                    cube_size = gr.Slider(minimum=0.000001, maximum=0.001, step=0.000001, value=0.00001, label='Cube size')
            depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant='primary')
            sam_decode_btn = gr.Button('Predict using points!', variant='primary')
    # components
    components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx,
                  input_image, point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
                  sam_decode_btn, depth_reconstruction_btn, prompt_image, lbl_image,
                  n_samples, cube_size, selected_masks_image}
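    # When a set of components is passed as `inputs`, Gradio calls the handler
    # with a single dict mapping each component to its current value; that is
    # why the handlers below index `inputs[...]` by component object.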
    # event - init coords
    def on_reset_btn_click(input_image):
        return input_image, point_coords_empty(), point_labels_empty(), None, []

    # The handler returns five values, so the outputs list must name five
    # components; the trailing None and [] clear the prompt image and masks.
    reset_btn.click(on_reset_btn_click, [input_image],
                    [input_image, point_coords, point_labels, prompt_image, masks],
                    queue=False)
    def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels,
                               point_label_radio, text, pred_masks, evt: gr.SelectData):
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        if prompt_image is None:
            prompt_image = np.array(input_image.copy())
        cv2.circle(prompt_image, (x, y), 5, color, -1)
        point_coords.append([x, y])
        point_labels.append(point_label_radio)
        sam_masks = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
        return [prompt_image,
                (input_image, sam_masks),
                point_coords,
                point_labels,
                sam_masks]

    prompt_image.select(on_prompt_image_select,
                        [input_image, prompt_image, point_coords, point_labels,
                         point_label_radio, text, pred_masks],
                        [prompt_image, lbl_image, point_coords, point_labels, pred_masks],
                        queue=False)
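    # NOTE: cond_pred is re-run on the full point history after every click.
    # Its return value is used directly as AnnotatedImage annotations here,
    # while on_click_sam_decode_btn below unpacks a 3-tuple from the same
    # call; without inference.py, the exact return shape is an assumption.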
    def on_everything_image_select(input_image, pred_masks, masks, text, evt: gr.SelectData):
        i = evt.index
        mask = pred_masks[i][0]
        print(mask)
        print(type(mask))
        masks.append((mask, text))
        anno = (input_image, masks)
        return [masks, anno]

    lbl_image.select(on_everything_image_select,
                     [input_image, pred_masks, masks, text],
                     [masks, selected_masks_image], queue=False)
    # prompt_lbl_image.select(on_everything_image_select,
    #                         [input_image, prompt_masks, masks, text],
    #                         [masks, selected_masks_image], queue=False)
    def on_click_sam_encode_btn(inputs):
        print("encoding")
        # encode image on click
        sam.encode(inputs[input_image])
        print("encoding done")
        return {prompt_image: inputs[input_image]}

    sam_encode_btn.click(on_click_sam_encode_btn, components, [prompt_image], queue=False)
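    # SAM runs its heavy image encoder once per image; point prompts then only
    # exercise the lightweight mask decoder, which is why encoding is exposed
    # as a separate, explicit step in this UI.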
    def on_click_sam_decode_btn(inputs):
        print("inferencing")
        image = inputs[input_image]
        generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]),
                                             lbls=np.array(inputs[point_labels]))
        inputs[masks].append((generated_mask, inputs[text]))
        print(inputs[masks][0])
        # NOTE: this (image, annotations) payload matches gr.AnnotatedImage,
        # but prompt_image is a plain gr.Image above.
        return {prompt_image: (image, inputs[masks])}

    sam_decode_btn.click(on_click_sam_decode_btn, components,
                         [prompt_image, masks, cutout_idx], queue=True)
    def on_depth_reconstruction_btn_click(inputs):
        print("depth reconstruction")
        image = inputs[input_image]
        # pass the current mask values, not the gr.State component itself
        path = dpt.generate_obj_masks2(image=image, cube_size=inputs[cube_size],
                                       masks=inputs[masks])
        return {pcl_figure: path}

    depth_reconstruction_btn.click(on_depth_reconstruction_btn_click, components,
                                   [pcl_figure], queue=False)
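    # generate_obj_masks2 (from inference.py, not shown) presumably renders the
    # predicted depth map as a colored point cloud and returns a file path that
    # gr.Model3D can display; that is an assumption about the helper, not
    # documented behavior.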
    def on_sam_sgmt_everything_btn_click(inputs):
        print("segmenting everything")
        image = inputs[input_image]
        sam_masks = sam.segment_everything(image)
        print(image)
        print(sam_masks)
        return [(image, sam_masks), sam_masks]

    sam_sgmt_everything_btn.click(on_sam_sgmt_everything_btn_click, components,
                                  [lbl_image, pred_masks], queue=False)
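    # segment_everything presumably wraps SAM's automatic mask generation (a
    # grid of point prompts over the whole image); its return value is assumed
    # to already be a list of (mask, label) pairs, since it is fed straight
    # into gr.AnnotatedImage.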
if __name__ == '__main__':
    block.queue()
    block.launch()