Spaces:
Runtime error
Runtime error
File size: 9,251 Bytes
077fc91 5c0b534 7598e8a 9f0c4b3 c93bd34 077fc91 01bc85d 769894a 5c0b534 077fc91 1689431 9f0c4b3 01bc85d 9f0c4b3 d4233b7 077fc91 9f0c4b3 077fc91 9f0c4b3 d46e73c 077fc91 f07135c 077fc91 4d6f971 185ceb1 640f5b4 077fc91 daeee36 9f0c4b3 0457b5c 9f0c4b3 daeee36 077fc91 c02210d 43dcd18 9f0c4b3 43dcd18 9f0c4b3 4d6f971 9f0c4b3 9e6e225 9f0c4b3 6046fb8 9f0c4b3 0579ca3 706546d 9f0c4b3 0579ca3 9f0c4b3 60edd6a 9f0c4b3 60edd6a 077fc91 9f0c4b3 2f5a3a2 9f0c4b3 c93bd34 9f0c4b3 95eb778 077fc91 0579ca3 077fc91 9f0c4b3 8728327 077fc91 fe0db59 9e6e225 fe0db59 9f0c4b3 077fc91 9f0c4b3 0a54901 4d6f971 c1a5086 0a54901 9f0c4b3 9dd1448 9f0c4b3 4d6f971 0a54901 c02210d 640f5b4 39f3339 c02210d 640f5b4 9f0c4b3 c02210d 077fc91 b3873d0 0579ca3 9f0c4b3 7299967 d46e73c 1c4f487 9f0c4b3 077fc91 5a6d6d4 9f0c4b3 5a6d6d4 9f0c4b3 5a6d6d4 d7bd88e c810b3c ae29f3e c69e375 9f0c4b3 d7bd88e 077fc91 9f0c4b3 077fc91 9f0c4b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 |
import os
import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageOps
import torch
from inference import SegmentPredictor, DepthPredictor
from utils import generate_PCL, PCL3, point_cloud
# RGB colours for the point markers drawn by on_prompt_image_select:
# label 0 is drawn in red, any other label in blue.
red = (255, 0, 0)
blue = (0, 0, 255)

# Model wrappers are created once, at import time. Two SegmentPredictor
# instances are kept: the default one and one explicitly placed on the CPU.
sam = SegmentPredictor()
sam_cpu = SegmentPredictor(device="cpu")
dpt = DepthPredictor()

# NOTE(review): not referenced anywhere else in the visible part of the file.
annos = []
block = gr.Blocks()
with block:
# States
def point_coords_empty():
    """Return a new, empty list for collecting prompt point coordinates."""
    return list()
def point_labels_empty():
    """Return a new, empty list for collecting prompt point labels."""
    return list()
# State holders shared between the event handlers below.
# image_edit_trigger: only listed in `components`; no visible handler reads it.
image_edit_trigger = gr.State(True)
# The factories are passed uncalled; presumably Gradio invokes them to build
# a fresh list -- TODO confirm against the installed Gradio version.
point_coords = gr.State(point_coords_empty)
point_labels = gr.State(point_labels_empty)
# masks: list of (mask, name) tuples appended by the selection/decode handlers.
masks = gr.State([])
# cutout_idx: declared as an output of the decode button, never returned by it.
cutout_idx = gr.State(set())
# pred_masks: last prediction, written by the "segment everything" and
# prompt-click handlers and read by on_everything_image_select.
pred_masks = gr.State([])
# prompt_masks: only referenced from commented-out code further down.
prompt_masks = gr.State([])
# embedding: written by the encode handler, read by on_prompt_image_select.
embedding = gr.State()
# UI
# Page layout. NOTE(review): the nesting of these `with` blocks is not
# recoverable from this flattened copy; restore it from the original file.
with gr.Column():
# Header text. NOTE(review): the trailing "π" in the sub-heading looks like a
# mis-decoded emoji -- confirm against the original source.
gr.Markdown(
"""# Segment Anything Model (SAM)
## a new AI model from Meta AI that can "cut out" any object, in any image, with a single click π
SAM is a promptable segmentation system with zero-shot generalization to unfamiliar objects and images, without the need for additional training. [**Official Project**](https://segment-anything.com/) [**Code**](https://github.com/facebookresearch/segment-anything).
"""
)
with gr.Row():
with gr.Column():
# Two alternative image sources; on_upload_image copies an uploaded file
# into `input_image`, which is the component every other handler reads.
with gr.Tab("Upload Image"):
# mirror_webcam = False
upload_image = gr.Image(label="Input", type="pil", tool=None)
with gr.Tab("Webcam"):
# mirror_webcam = False
input_image = gr.Image(
label="Input", type="pil", tool=None, source="webcam"
)
with gr.Row():
sam_encode_btn = gr.Button("Encode", variant="primary")
sam_sgmt_everything_btn = gr.Button(
"Segment Everything!", variant="primary"
)
# sam_encode_status = gr.Label('Not encoded yet')
with gr.Row():
# prompt_image: clicked to place point prompts (see on_prompt_image_select).
prompt_image = gr.Image(label="Segments")
# prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
# lbl_image: shows predicted masks; clicking one adds it to the selection.
lbl_image = gr.AnnotatedImage(label="Everything")
with gr.Row():
# Label attached to the next clicked point (0 is drawn red, otherwise blue).
point_label_radio = gr.Radio(label="Point Label", choices=[1, 0], value=1)
# Name stored alongside a mask when it is added to the selection.
text = gr.Textbox(label="Mask Name")
reset_btn = gr.Button("New Mask")
# selected_masks_image: the current selection; clicking a mask removes it.
selected_masks_image = gr.AnnotatedImage(label="Selected Masks")
with gr.Row():
with gr.Column():
# Receives the path returned by on_depth_reconstruction_btn_click.
pcl_figure = gr.Model3D(
label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0]
)
with gr.Row():
# Depth-reconstruction controls; their values are forwarded to
# dpt.generate_obj_rgb by on_depth_reconstruction_btn_click.
# The starting position is passed as `value`, which is the keyword
# gr.Slider documents for it. The previous `default=` keyword is not a
# Slider parameter, so the intended initial values were not being applied.
max_depth = gr.Slider(
    minimum=0, maximum=10, step=0.01, value=1, label="Max Depth"
)
min_depth = gr.Slider(
    minimum=0, maximum=10, step=0.01, value=0.1, label="Min Depth"
)
n_samples = gr.Slider(
    minimum=1e3,
    maximum=1e6,
    step=1e3,
    value=1e3,
    label="Number of Samples",
)
cube_size = gr.Slider(
    minimum=0.00001,
    maximum=0.001,
    step=0.000001,
    value=0.00001,
    label="Cube size",
)
# Triggers on_depth_reconstruction_btn_click.
depth_reconstruction_btn = gr.Button(
"Depth Reconstruction", variant="primary"
)
# Triggers on_click_sam_dencode_btn.
sam_decode_btn = gr.Button("Predict using points!", variant="primary")
# components
# Set of components handed as `inputs` to the button handlers that take a
# single `inputs` argument and index it by component object.
components = {
    # state holders
    point_coords,
    point_labels,
    image_edit_trigger,
    masks,
    cutout_idx,
    embedding,
    # user-editable inputs
    input_image,
    point_label_radio,
    text,
    n_samples,
    max_depth,
    min_depth,
    cube_size,
    # buttons
    reset_btn,
    sam_sgmt_everything_btn,
    sam_decode_btn,
    depth_reconstruction_btn,
    # image displays
    prompt_image,
    lbl_image,
    selected_masks_image,
}
def on_upload_image(input_image, upload_image):
    """Copy an uploaded image into the webcam component, mirrored.

    Args:
        input_image: current value of the webcam image component (unused).
        upload_image: the PIL image that was just uploaded.

    Returns:
        A two-element list: the horizontally mirrored image (for
        ``input_image``) and the untouched upload (for ``upload_image``).
    """
    # The original author's note: the webcam image component mirrors what it
    # is given, so the upload is pre-mirrored to compensate.
    return [ImageOps.mirror(upload_image), upload_image]
# On upload, run on_upload_image and write its two results back to the
# webcam component and the upload component, in that order.
upload_image.upload(
on_upload_image, [input_image, upload_image], [input_image, upload_image]
)
# event - init coords
def on_reset_btn_click(input_image):
    """Start a new mask: keep the image and clear the point prompts.

    The listener registered on ``reset_btn`` declares exactly three outputs
    (``input_image``, ``point_coords``, ``point_labels``), so exactly three
    values are returned. Earlier this returned two extra trailing values
    that no output component could receive.

    Args:
        input_image: current value of the input image component.

    Returns:
        A 3-tuple ``(input_image, [], [])``; the two lists are fresh, distinct
        objects so the coordinate and label states never alias each other.
    """
    return input_image, [], []
# "New Mask" button: the outputs are, in order, the input image and the two
# point-prompt states (coordinates, labels).
reset_btn.click(
on_reset_btn_click,
[input_image],
[input_image, point_coords, point_labels],
queue=False,
)
def on_prompt_image_select(
    input_image,
    prompt_image,
    point_coords,
    point_labels,
    point_label_radio,
    text,
    pred_masks,
    embedding,
    evt: gr.SelectData,
):
    """Record a clicked point prompt and re-run the point-conditioned prediction.

    Args:
        input_image: the source image.
        prompt_image: the image being annotated with point markers, or None.
        point_coords: list of ``[x, y]`` prompts so far; appended to in place.
        point_labels: list of labels so far; appended to in place.
        point_label_radio: label to attach to the newly clicked point.
        text: unused.
        pred_masks: unused.
        embedding: stored image embedding forwarded to ``sam_cpu.cond_pred``.
        evt: selection event; ``evt.index`` is unpacked as ``(x, y)``.

    Returns:
        ``[prompt_image, (input_image, masks), point_coords, point_labels,
        masks]`` where ``masks`` is whatever ``sam_cpu.cond_pred`` returned.
    """
    sam_cpu.dummy_encode(input_image)

    click_x, click_y = evt.index

    # Label 0 is marked in red, every other label in blue.
    if point_label_radio == 0:
        marker_color = red
    else:
        marker_color = blue

    # First click: start drawing on an array copy of the source image.
    if prompt_image is None:
        prompt_image = np.array(input_image.copy())
    cv2.circle(prompt_image, (click_x, click_y), 5, marker_color, -1)

    point_coords.append([click_x, click_y])
    point_labels.append(point_label_radio)

    predicted = sam_cpu.cond_pred(
        pts=np.array(point_coords),
        lbls=np.array(point_labels),
        embedding=embedding,
    )
    annotated = (input_image, predicted)
    return [prompt_image, annotated, point_coords, point_labels, predicted]
# Clicking on the prompt image adds a point prompt. The input list supplies
# the handler's positional parameters in order; the outputs receive the
# handler's five return values in order.
prompt_image.select(
on_prompt_image_select,
[
input_image,
prompt_image,
point_coords,
point_labels,
point_label_radio,
text,
pred_masks,
embedding,
],
[prompt_image, lbl_image, point_coords, point_labels, pred_masks],
queue=True,
)
def on_everything_image_select(
    input_image, pred_masks, masks, text, evt: gr.SelectData
):
    """Add the clicked predicted mask to the selection under the given name.

    Args:
        input_image: the source image.
        pred_masks: indexable collection of predictions; element ``[i][0]``
            is taken as the mask.
        masks: current selection, a list of ``(mask, name)``; appended to in
            place.
        text: name to store with the mask.
        evt: selection event; ``evt.index`` picks the entry of ``pred_masks``.

    Returns:
        ``[masks, (input_image, masks)]``.
    """
    chosen = pred_masks[evt.index][0]
    print(chosen)
    print(type(chosen))
    masks.append((chosen, text))
    return [masks, (input_image, masks)]
# Clicking a mask in the "Everything" view appends it to the selection and
# refreshes the "Selected Masks" view.
lbl_image.select(
on_everything_image_select,
[input_image, pred_masks, masks, text],
[masks, selected_masks_image],
queue=False,
)
def on_selected_masks_image_select(input_image, masks, evt: gr.SelectData):
    """Remove the clicked mask from the selection.

    Args:
        input_image: the source image.
        masks: current selection list; the clicked entry is deleted in place.
        evt: selection event; ``evt.index`` is the position to delete.

    Returns:
        ``[masks, (input_image, masks)]`` reflecting the shortened selection.
    """
    del masks[evt.index]
    return [masks, (input_image, masks)]
# Clicking a mask in the "Selected Masks" view removes it and refreshes the
# same view.
selected_masks_image.select(
on_selected_masks_image_select,
[input_image, masks],
[masks, selected_masks_image],
queue=False,
)
# prompt_lbl_image.select(on_everything_image_select,
# [input_image, prompt_masks, masks, text],
# [masks, selected_masks_image], queue=False)
def on_click_sam_encode_btn(inputs):
    """Encode the current input image and publish it for point prompting.

    Args:
        inputs: mapping from component object to its current value.

    Returns:
        ``[image, embedding]`` where ``embedding`` is the result of
        ``sam.encode(image).cpu()``.
    """
    print("encoding")
    image = inputs[input_image]
    # Encode with the default predictor, then move the result to the CPU.
    image_embedding = sam.encode(image).cpu()
    sam_cpu.dummy_encode(image)
    print("encoding done")
    return [image, image_embedding]
# "Encode" button: the handler gets the whole `components` set as input and
# fills the prompt image and the embedding state.
sam_encode_btn.click(
on_click_sam_encode_btn, components, [prompt_image, embedding], queue=False
)
def on_click_sam_dencode_btn(inputs):
    """Predict a mask from the collected point prompts and add it to the selection.

    Args:
        inputs: mapping from component object to its current value.

    Returns:
        A dict with the single key ``prompt_image`` mapped to
        ``(image, masks)``.
    """
    print("inferencing")
    image = inputs[input_image]
    selection = inputs[masks]

    # NOTE(review): unlike on_prompt_image_select, this call passes no
    # `embedding` and unpacks three return values -- confirm both against
    # SegmentPredictor.cond_pred.
    prediction = sam.cond_pred(
        pts=np.array(inputs[point_coords]),
        lbls=np.array(inputs[point_labels]),
    )
    generated_mask, _, _ = prediction

    selection.append((generated_mask, inputs[text]))
    print(selection[0])

    # NOTE(review): `prompt_image` is created as a gr.Image, yet it is given
    # an (image, masks) tuple here, the shape used for the AnnotatedImage
    # components elsewhere in this file -- confirm the intended target.
    return {prompt_image: (image, selection)}
# "Predict using points!" button.
# NOTE(review): the handler returns a dict whose only key is prompt_image;
# masks and cutout_idx are declared as outputs but never returned by it.
sam_decode_btn.click(
on_click_sam_dencode_btn,
components,
[prompt_image, masks, cutout_idx],
queue=True,
)
def on_depth_reconstruction_btn_click(inputs):
    """Run the depth predictor on the input image and show the result in 3-D.

    Args:
        inputs: mapping from component object to its current value.

    Returns:
        A dict mapping ``pcl_figure`` to the value returned by
        ``dpt.generate_obj_rgb``.
    """
    print("depth reconstruction")
    # The selected masks are not forwarded; only the image and slider values.
    reconstruction_args = {
        "image": inputs[input_image],
        "cube_size": inputs[cube_size],
        "n_samples": inputs[n_samples],
        "min_depth": inputs[min_depth],
        "max_depth": inputs[max_depth],
    }
    model_path = dpt.generate_obj_rgb(**reconstruction_args)
    return {pcl_figure: model_path}
# "Depth Reconstruction" button: the handler's result is shown in the 3-D
# model viewer.
depth_reconstruction_btn.click(
on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False
)
def on_sam_sgmt_everything_btn_click(inputs):
    """Segment the whole input image without prompts.

    Args:
        inputs: mapping from component object to its current value.

    Returns:
        ``[(image, masks), masks]`` where ``masks`` is the result of
        ``sam.segment_everything(image)``.
    """
    print("segmenting everything")
    image = inputs[input_image]
    everything = sam.segment_everything(image)
    print(image)
    print(everything)
    annotated = (image, everything)
    return [annotated, everything]
# "Segment Everything!" button: fills the "Everything" view and stores the
# predictions in pred_masks for later selection.
sam_sgmt_everything_btn.click(
on_sam_sgmt_everything_btn_click,
components,
[lbl_image, pred_masks],
queue=True,
)
# Script entry point: enable the queue (some listeners above are registered
# with queue=True), then start the app.
if __name__ == "__main__":
block.queue()
block.launch()
|