Spaces:
Runtime error
Runtime error
import os | |
import gradio as gr | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
import torch | |
from inference import SegmentPredictor | |
from utils import generate_PCL, PCL3 | |
# Shared predictor instance; the event handlers call its encode() and cond_pred().
sam = SegmentPredictor()

# Prompt-dot colours, (R, G, B) order to match arrays built from PIL images.
red, blue = (255, 0, 0), (0, 0, 255)

# NOTE(review): not referenced by any code in this file as shown.
annos = []
block = gr.Blocks()
with block:
    # ---- Session state ---------------------------------------------------
    def point_coords_empty():
        """Return a fresh, empty list of clicked [x, y] prompt points."""
        return []

    def point_labels_empty():
        """Return a fresh, empty list of point labels (parallel to the coords)."""
        return []

    # Hidden, un-annotated copy of the last upload. The visible `input_image`
    # gets prompt dots drawn on it; this copy is what the segmentation/depth
    # handlers read and what "New Mask" restores.
    raw_image = gr.Image(type='pil', visible=False)
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    # List of (mask, name) tuples rendered in the "Segments" view.
    masks = gr.State([])
    # NOTE(review): declared as an output of the decode event but never written here.
    cutout_idx = gr.State(set())

    # ---- Layout ------------------------------------------------------------
    # NOTE(review): container nesting below was reconstructed; verify against
    # the intended layout.
    with gr.Column():
        with gr.Row():
            input_image = gr.Image(label='Input', height=512, type='pil')
            masks_annotated_image = gr.AnnotatedImage(label='Segments', height=512)
            pcl_figure = gr.Plot(label='3D Reconstruction')
            #cutout_galary = gr.Gallery(label='Cutouts', object_fit='contain', height=512)
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Row():
                    # presumably 1 = foreground, 0 = background (SAM convention);
                    # confirm against SegmentPredictor.cond_pred.
                    point_label_radio = gr.Radio(label='Point Label', choices=[1, 0], value=1)
                    text = gr.Textbox(label='Mask Name')
                reset_btn = gr.Button('New Mask')
                # NOTE(review): no click handler is registered for this button here.
                sam_sgmt_everything_btn = gr.Button('Segment Everything!', variant='primary')
                sam_decode_btn = gr.Button('Predict using points!', variant='primary')
                depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant='primary')

    # Passing a set of components makes Gradio hand the handler a single
    # {component: value} dict instead of positional arguments.
    components = {point_coords, point_labels, raw_image, masks, cutout_idx, input_image,
                  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
                  sam_decode_btn, depth_reconstruction_btn, masks_annotated_image}

    # ---- Event: depth reconstruction --------------------------------------
    def on_depth_reconstruction_btn_click(inputs):
        """Build the 3D-reconstruction figure for the current image.

        Args:
            inputs: dict mapping each component in `components` to its value.

        Returns:
            ``{pcl_figure: fig}`` with the figure produced by ``PCL3``.

        Raises:
            gr.Error: if no image has been uploaded yet.
        """
        print("depth reconstruction")
        image = inputs[raw_image]
        if image is None:
            # Show a readable message instead of handing None to PCL3.
            raise gr.Error("Upload an image first.")
        fig = PCL3(image)
        return {pcl_figure: fig}

    depth_reconstruction_btn.click(on_depth_reconstruction_btn_click, components,
                                   [pcl_figure], queue=False)

    # ---- Event: start a new mask -------------------------------------------
    def on_reset_btn_click(raw_image):
        """Clear the prompt points and restore the un-annotated image.

        Returns exactly one value per output:
        (input_image, point_coords, point_labels).
        """
        return raw_image, point_coords_empty(), point_labels_empty()

    reset_btn.click(on_reset_btn_click, [raw_image],
                    [input_image, point_coords, point_labels], queue=False)

    # ---- Event: new upload --------------------------------------------------
    def on_input_image_upload(input_image):
        """Encode a newly uploaded image with SAM and reset per-image state.

        Returns exactly one value per output:
        (raw_image, point_coords, point_labels, masks_annotated_image, masks).
        Masks collected for a previous image do not belong to the new one, so
        the mask list and the "Segments" view are cleared.
        """
        print("encoding")
        # encode image on upload
        sam.encode(input_image)
        print("encoding done")
        return input_image, point_coords_empty(), point_labels_empty(), None, []

    input_image.upload(on_input_image_upload, [input_image],
                       [raw_image, point_coords, point_labels, masks_annotated_image, masks],
                       queue=False)

    # ---- Event: click a prompt point ---------------------------------------
    def on_input_image_select(input_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
        """Record a clicked prompt point and draw it on the visible image.

        Label 0 is drawn red; any other label is drawn blue.

        Returns:
            (annotated image, point_coords, point_labels) with the new point appended.
        """
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        # The colours are RGB triples: convert so palette, greyscale and RGBA
        # uploads get a correctly coloured, opaque dot.
        img = np.array(input_image.convert('RGB'))
        cv2.circle(img, (x, y), 5, color, -1)
        img = Image.fromarray(img)
        point_coords.append([x, y])
        point_labels.append(point_label_radio)
        return img, point_coords, point_labels

    input_image.select(on_input_image_select,
                       [input_image, point_coords, point_labels, point_label_radio],
                       [input_image, point_coords, point_labels], queue=False)

    # ---- Event: predict a mask from the points -----------------------------
    def on_click_sam_dencode_btn(inputs):
        """Predict a mask from the clicked points and add it to the segments view.

        Args:
            inputs: dict mapping each component in `components` to its value.

        Returns:
            Dict updating the "Segments" view and the `masks` state.

        Raises:
            gr.Error: if no image is loaded or no points have been clicked.
        """
        print("inferencing")
        image = inputs[raw_image]
        if image is None:
            raise gr.Error("Upload an image first.")
        if not inputs[point_coords]:
            raise gr.Error("Click at least one point on the image first.")
        generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]),
                                             lbls=np.array(inputs[point_labels]))
        # Build a new list and return it so the declared `masks` output is set
        # explicitly rather than relying on in-place mutation of the state.
        new_masks = inputs[masks] + [(generated_mask, inputs[text])]
        return {masks_annotated_image: (image, new_masks), masks: new_masks}

    sam_decode_btn.click(on_click_sam_dencode_btn, components,
                         [masks_annotated_image, masks, cutout_idx], queue=True)
if __name__ == "__main__":
    # Turn on Gradio's request queue, then start serving the app.
    block.queue().launch()