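"""Gradio demo: point-prompted segmentation with SAM plus depth-based 3-D reconstruction.

Duplicated from the NN-BRD/hackathon_depth_segment Space. Users upload or capture an
image, encode it with SAM, pick masks either by clicking point prompts or from a
"Segment Everything" overview, and reconstruct a point cloud from monocular depth,
optionally restricted to the selected masks.
"""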
import os
import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageOps
import torch
from inference import SegmentPredictor, DepthPredictor
from utils import generate_PCL, PCL3, point_cloud
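# `inference` and `utils` are local project modules whose internals are not shown
# here: SegmentPredictor appears to wrap SAM (Segment Anything) and DepthPredictor
# a monocular depth model.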
# One GPU-capable predictor for encoding and "segment everything", plus a CPU copy
# for cheap re-prediction while the user clicks point prompts.
sam = SegmentPredictor()
sam_cpu = SegmentPredictor(device="cpu")
dpt = DepthPredictor()

# Marker colours for point prompts: red = background (label 0), blue = foreground (label 1).
red = (255, 0, 0)
blue = (0, 0, 255)
annos = []
block = gr.Blocks()
with block:
    # Per-session states. Passing a callable as the initial value means it is
    # invoked on load, so every session starts from its own fresh copy.
    def point_coords_empty():
        return []

    def point_labels_empty():
        return []

    image_edit_trigger = gr.State(True)
    point_coords = gr.State(point_coords_empty)
    point_labels = gr.State(point_labels_empty)
    masks = gr.State([])
    cutout_idx = gr.State(set())
    pred_masks = gr.State([])
    prompt_masks = gr.State([])
    embedding = gr.State()
    # UI
    with gr.Column():
        gr.Markdown(
            """# Segment Anything Model (SAM)
## a new AI model from Meta AI that can "cut out" any object, in any image, with a single click 🚀
SAM is a promptable segmentation system with zero-shot generalization to unfamiliar objects and images, without the need for additional training. [**Official Project**](https://segment-anything.com/) [**Code**](https://github.com/facebookresearch/segment-anything).
"""
        )
        with gr.Row():
            with gr.Column():
                with gr.Tab("Upload Image"):
                    # mirror_webcam = False
                    upload_image = gr.Image(label="Input", type="pil", tool=None)
                with gr.Tab("Webcam"):
                    # mirror_webcam = False
                    input_image = gr.Image(
                        label="Input", type="pil", tool=None, source="webcam"
                    )
                with gr.Row():
                    sam_encode_btn = gr.Button("Encode", variant="primary")
                    sam_sgmt_everything_btn = gr.Button(
                        "Segment Everything!", variant="primary"
                    )
                # sam_encode_status = gr.Label('Not encoded yet')
        with gr.Row():
            prompt_image = gr.Image(label="Segments")
            # prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
            lbl_image = gr.AnnotatedImage(label="Everything")
        with gr.Row():
            point_label_radio = gr.Radio(label="Point Label", choices=[1, 0], value=1)
            text = gr.Textbox(label="Mask Name")
            reset_btn = gr.Button("New Mask")
        selected_masks_image = gr.AnnotatedImage(label="Selected Masks")
        with gr.Row():
            with gr.Column():
                pcl_figure = gr.Model3D(
                    label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0]
                )
                with gr.Row():
                    max_depth = gr.Slider(
                        minimum=0, maximum=10, value=3, step=0.01, label="Max Depth"
                    )
                    min_depth = gr.Slider(
                        minimum=0, maximum=10, step=0.01, value=1, label="Min Depth"
                    )
                    n_samples = gr.Slider(
                        minimum=1e3,
                        maximum=1e6,
                        step=1e3,
                        value=1e5,
                        label="Number of Samples",
                    )
                    cube_size = gr.Slider(
                        minimum=0.00001,
                        maximum=0.001,
                        step=0.000001,
                        value=0.00001,
                        label="Cube size",
                    )
                depth_reconstruction_btn = gr.Button(
                    "3D Reconstruction", variant="primary"
                )
                depth_reconstruction_mask_btn = gr.Button(
                    "Mask Reconstruction", variant="primary"
                )
        sam_decode_btn = gr.Button("Predict using points!", variant="primary")
    # components
    components = {
        point_coords,
        point_labels,
        image_edit_trigger,
        masks,
        cutout_idx,
        input_image,
        embedding,
        point_label_radio,
        text,
        reset_btn,
        sam_sgmt_everything_btn,
        sam_decode_btn,
        depth_reconstruction_btn,
        prompt_image,
        lbl_image,
        n_samples,
        max_depth,
        min_depth,
        cube_size,
        selected_masks_image,
    }
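    # Because `components` is a *set*, Gradio passes each handler a single dict
    # mapping component -> current value, which is why the handlers below take one
    # `inputs` argument and index it as e.g. `inputs[input_image]`.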
    def on_upload_image(input_image, upload_image):
        # Mirror uploads because the webcam gr.Image mirrors its capture by
        # default, so both input paths reach the model in the same orientation.
        upload_image_mirror = ImageOps.mirror(upload_image)
        return [upload_image_mirror, upload_image]

    upload_image.upload(
        on_upload_image, [input_image, upload_image], [input_image, upload_image]
    )
    # event - reset point prompts and start a new mask
    def on_reset_btn_click(input_image):
        return input_image, point_coords_empty(), point_labels_empty(), None, []

    reset_btn.click(
        on_reset_btn_click,
        [input_image],
        [prompt_image, point_coords, point_labels, lbl_image, pred_masks],
        queue=False,
    )
    def on_prompt_image_select(
        input_image,
        prompt_image,
        point_coords,
        point_labels,
        point_label_radio,
        text,
        pred_masks,
        embedding,
        evt: gr.SelectData,
    ):
        # Each click adds one point prompt; prediction runs on the CPU predictor
        # against the cached embedding so interactive clicks stay cheap.
        sam_cpu.dummy_encode(input_image)
        x, y = evt.index
        color = red if point_label_radio == 0 else blue
        if prompt_image is None:
            prompt_image = np.array(input_image.copy())
        cv2.circle(prompt_image, (x, y), 5, color, -1)
        point_coords.append([x, y])
        point_labels.append(point_label_radio)
        sam_masks = sam_cpu.cond_pred(
            pts=np.array(point_coords), lbls=np.array(point_labels), embedding=embedding
        )
        return [
            prompt_image,
            (input_image, sam_masks),
            point_coords,
            point_labels,
            sam_masks,
        ]

    prompt_image.select(
        on_prompt_image_select,
        [
            input_image,
            prompt_image,
            point_coords,
            point_labels,
            point_label_radio,
            text,
            pred_masks,
            embedding,
        ],
        [prompt_image, lbl_image, point_coords, point_labels, pred_masks],
        queue=True,
    )
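    # Clicking a segment in the "Everything" AnnotatedImage copies that mask into
    # the user's named mask list; clicking an entry in "Selected Masks" removes it.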
    def on_everything_image_select(
        input_image, pred_masks, masks, text, evt: gr.SelectData
    ):
        i = evt.index
        mask = pred_masks[i][0]
        masks.append((mask, text))
        anno = (input_image, masks)
        return [masks, anno]

    lbl_image.select(
        on_everything_image_select,
        [input_image, pred_masks, masks, text],
        [masks, selected_masks_image],
        queue=False,
    )

    def on_selected_masks_image_select(input_image, masks, evt: gr.SelectData):
        i = evt.index
        del masks[i]
        anno = (input_image, masks)
        return [masks, anno]

    selected_masks_image.select(
        on_selected_masks_image_select,
        [input_image, masks],
        [masks, selected_masks_image],
        queue=False,
    )

    # prompt_lbl_image.select(on_everything_image_select,
    #                         [input_image, prompt_masks, masks, text],
    #                         [masks, selected_masks_image], queue=False)
    def on_click_sam_encode_btn(inputs):
        print("encoding")
        # Encode once on the GPU predictor and keep the embedding on CPU for the
        # interactive conditional predictions.
        embedding = sam.encode(inputs[input_image]).cpu()
        sam_cpu.dummy_encode(inputs[input_image])
        print("encoding done")
        return [inputs[input_image], embedding]

    sam_encode_btn.click(
        on_click_sam_encode_btn, components, [prompt_image, embedding], queue=False
    )

    def on_click_sam_decode_btn(inputs):
        print("inferencing")
        image = inputs[input_image]
        generated_mask, _, _ = sam.cond_pred(
            pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels])
        )
        inputs[masks].append((generated_mask, inputs[text]))
        return {prompt_image: (image, inputs[masks])}

    sam_decode_btn.click(
        on_click_sam_decode_btn,
        components,
        [prompt_image, masks, cutout_idx],
        queue=True,
    )
    def on_depth_reconstruction_btn_click(inputs):
        print("depth reconstruction")
        path = dpt.generate_obj_rgb(
            image=inputs[input_image],
            cube_size=inputs[cube_size],
            n_samples=inputs[n_samples],
            # masks=inputs[masks],
            min_depth=inputs[min_depth],
            max_depth=inputs[max_depth],
        )
        return {pcl_figure: path}

    depth_reconstruction_btn.click(
        on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False
    )

    def on_depth_reconstruction_mask_btn_click(inputs):
        print("mask reconstruction")
        path = dpt.generate_obj_masks2(
            image=inputs[input_image],
            cube_size=inputs[cube_size],
            n_samples=inputs[n_samples],
            masks=inputs[masks],
            min_depth=inputs[min_depth],
            max_depth=inputs[max_depth],
        )
        return {pcl_figure: path}

    depth_reconstruction_mask_btn.click(
        on_depth_reconstruction_mask_btn_click, components, [pcl_figure], queue=False
    )
    def on_sam_sgmt_everything_btn_click(inputs):
        print("segmenting everything")
        image = inputs[input_image]
        sam_masks = sam.segment_everything(image)
        return [(image, sam_masks), sam_masks]

    sam_sgmt_everything_btn.click(
        on_sam_sgmt_everything_btn_click,
        components,
        [lbl_image, pred_masks],
        queue=True,
    )
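# queue() enables Gradio's request queue, which the long-running handlers wired
# with queue=True (encoding, point prediction, segment everything) rely on;
# launch(auth=...) puts the demo behind a username/password login.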
if __name__ == "__main__":
    block.queue()
    block.launch(auth=("novouser", "bstad2023"))