Update app.py

app.py CHANGED

@@ -1,3 +1,7 @@
+## Some code was modified from Ovseg and OV-Sam.Thanks to their excellent work.
+## Ovseg Code:https://github.com/facebookresearch/ov-seg
+## OV-Sam Code:https://github.com/HarborYuan/ovsam
+
 import spaces
 import multiprocessing as mp
 import numpy as np
@@ -18,6 +22,17 @@ import gradio as gr
 import open_clip
 from sam2.build_sam import build_sam2
 from mask_adapter.modeling.meta_arch.mask_adapter_head import build_mask_adapter
+from mask_adapter.data.datasets import openseg_classes
+
+COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
+thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
+stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
+ADE20K_150_CATEGORIES_ = openseg_classes.get_ade20k_categories_with_prompt_eng()
+ade20k_thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES_ if k["isthing"] == 1]
+ade20k_stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES_]
+class_names_coco_ade20k = thing_classes + stuff_classes + ade20k_thing_classes+ ade20k_stuff_classes
+
+

@@ -93,7 +108,7 @@ def inference_automatic(input_img, class_names):
 @spaces.GPU
 @torch.no_grad()
 @torch.autocast(device_type="cuda", dtype=torch.float32)
-def inference_point(input_img, img_state,):
+def inference_point(input_img, img_state,class_names_input):


     mp.set_start_method("spawn", force=True)
@@ -106,8 +121,20 @@ def inference_point(input_img, img_state,):

     demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model,mask_adapter)

-
-
+    if not class_names_input:
+        class_names_input = class_names_coco_ade20k
+
+    if class_names_input == class_names_coco_ade20k:
+        text_features = torch.from_numpy(np.load("./text_embedding/coco_ade20k_text_embedding.npy")).cuda()
+        _, visualized_output = demo.run_on_image_with_points(img_state.img, points,text_features)
+    else:
+        class_names_input = class_names_input.split(',')
+        txts = [f'a photo of {cls_name}' for cls_name in class_names_input]
+        text = open_clip.tokenize(txts)
+        text_features = clip_model.encode_text(text.cuda())
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+        _, visualized_output = demo.run_on_image_with_points(img_state.img, points,text_features,class_names_input)
+
     return visualized_output

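Note on the hunk above: when the user types their own class names, the new else branch builds CLIP text features on the fly with the same open_clip calls already imported at the top of app.py. A minimal standalone sketch of that path (the checkpoint name and the example class list below are illustrative, not taken from the Space):

    import open_clip
    import torch

    # Illustrative checkpoint; the Space loads its own CLIP weights elsewhere in app.py.
    clip_model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="laion2b_s34b_b79k")
    clip_model.eval()

    class_names_input = "cat,dog,grass"                  # what the user typed into the textbox
    class_names = class_names_input.split(",")
    prompts = [f"a photo of {name}" for name in class_names]

    with torch.no_grad():
        tokens = open_clip.tokenize(prompts)             # (N, 77) token ids
        text_features = clip_model.encode_text(tokens)   # (N, D) text embeddings
        text_features /= text_features.norm(dim=-1, keepdim=True)  # unit norm, so dot products act as cosine similarities
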
@@ -136,8 +163,20 @@ def inference_box(input_img, img_state,):

     demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model,mask_adapter)

-
-
+    if not class_names_input:
+        class_names_input = class_names_coco_ade20k
+
+    if class_names_input == class_names_coco_ade20k:
+        text_features = torch.from_numpy(np.load("./text_embedding/coco_ade20k_text_embedding.npy")).cuda()
+        _, visualized_output = demo.run_on_image_with_boxes(img_state.img, bbox,text_features)
+    else:
+        class_names_input = class_names_input.split(',')
+        txts = [f'a photo of {cls_name}' for cls_name in class_names_input]
+        text = open_clip.tokenize(txts)
+        text_features = clip_model.encode_text(text.cuda())
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+        _, visualized_output = demo.run_on_image_with_boxes(img_state.img, bbox,text_features,class_names_input)
+
     return visualized_output

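The default-vocabulary branch in both functions skips the text encoder and loads a cached embedding matrix instead. The diff does not show how ./text_embedding/coco_ade20k_text_embedding.npy was produced; presumably it was exported once offline with the same encoder and prompt template, roughly along the lines of this hedged sketch (the random tensor stands in for the real embeddings):

    import numpy as np
    import torch

    # Stand-in for the unit-normalized (N, D) text features of the full COCO + ADE20K vocabulary.
    text_features = torch.randn(10, 512)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # One-time export (the app keeps the file under ./text_embedding/).
    np.save("coco_ade20k_text_embedding.npy", text_features.cpu().numpy())

    # Cheap load path used at inference time, moved to GPU when one is available.
    cached = torch.from_numpy(np.load("coco_ade20k_text_embedding.npy"))
    if torch.cuda.is_available():
        cached = cached.cuda()
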
@@ -234,7 +273,7 @@ def preprocess_example(input_img, img_state):

 def clear_everything(img_state):
     img_state.clear()
-    return img_state, None, None
+    return img_state, None, None, gr.Textbox(value='',lines=1, placeholder=class_names_coco_ade20k, label='Class Names')


 def clean_prompts(img_state):
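clear_everything now returns a fourth value, a fresh gr.Textbox, which the event wirings further down consume by listing the class-name textbox as an extra output. In Gradio, returning a component instance (or gr.update(...)) from a callback updates the matching output, so this blanks the textbox along with the images. A minimal sketch of the pattern, with illustrative component names:

    import gradio as gr

    def reset_all():
        # Component instances returned from a callback update the corresponding outputs.
        return None, None, gr.Textbox(value="", placeholder="comma separated class names")

    with gr.Blocks() as demo:
        inp = gr.Image(type="pil", label="Input Image")
        out = gr.Image(type="pil", label="Segmentation Map")
        names = gr.Textbox(label="Class Names")
        restart = gr.Button("Restart")
        restart.click(reset_all, inputs=None, outputs=[inp, out, names])

    # demo.launch()
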
@@ -296,7 +335,7 @@ with gr.Blocks() as demo:
             output_image = gr.Image(type="pil", label='Segmentation Map')

             # Buttons below segmentation map (now placed under segmentation map)
-            run_button = gr.Button("Run Automatic Segmentation")
+            run_button = gr.Button("Run Automatic Segmentation", elem_id="run_button",variant='primary')
             run_button.click(inference_automatic, inputs=[input_image, class_names], outputs=output_image)

             clear_button = gr.Button("Clear")
@@ -310,9 +349,12 @@ with gr.Blocks() as demo:
         with gr.Row():  # horizontal layout
             with gr.Column(scale=1):
                 input_image = gr.Image( label="Input Image", type="pil")
-
+                class_names_input_box = gr.Textbox(lines=1, placeholder=class_names_coco_ade20k, label='Class Names')
+            with gr.Column(scale=1):
                 output_image_box = gr.Image(type="pil", label='Segmentation Map',interactive=False)  # output segmentation map
-
+                clear_prompt_button_box = gr.Button("Clean Prompt")
+                clear_button_box = gr.Button("Restart")
+
         gr.Markdown("Click the top-left and bottom-right corners of the image to select a rectangular area")

         input_image.select(
@@ -321,30 +363,31 @@ with gr.Blocks() as demo:
             outputs=[img_state_bbox, input_image]
         ).then(
             inference_box,
-            inputs=[input_image, img_state_bbox],
+            inputs=[input_image, img_state_bbox,class_names_input_box],
             outputs=[output_image_box]
         )
-
+
+
         clear_prompt_button_box.click(
             clean_prompts,
             inputs=[img_state_bbox],
             outputs=[img_state_bbox, input_image, output_image_box]
         )
-
+
         clear_button_box.click(
             clear_everything,
             inputs=[img_state_bbox],
-            outputs=[img_state_bbox, input_image, output_image_box]
+            outputs=[img_state_bbox, input_image, output_image_box,class_names_input_box]
         )
         input_image.clear(
             clear_everything,
             inputs=[img_state_bbox],
-            outputs=[img_state_bbox, input_image, output_image_box]
+            outputs=[img_state_bbox, input_image, output_image_box,class_names_input_box]
         )
         output_image_box.clear(
             clear_everything,
             inputs=[img_state_bbox],
-            outputs=[img_state_bbox, input_image, output_image_box]
+            outputs=[img_state_bbox, input_image, output_image_box,class_names_input_box]
         )

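The box tab keeps the same two-step event chain: .select() records the clicked corner into img_state_bbox and redraws the prompt, then .then() runs inference_box once that handler finishes, now with the class-name textbox passed as an extra input. A stripped-down sketch of the select/then pattern (handler bodies and names here are illustrative, not the Space's real ones):

    import gradio as gr

    def record_click(img, state, evt: gr.SelectData):
        state.append(evt.index)   # (x, y) of the click on the image
        return state, img         # a real app would redraw the prompt markers here

    def run_box_inference(img, state, class_names):
        return img                # placeholder for the real segmentation call

    with gr.Blocks() as demo:
        state = gr.State([])
        inp = gr.Image(type="pil", label="Input Image")
        names = gr.Textbox(label="Class Names")
        out = gr.Image(type="pil", label="Segmentation Map", interactive=False)

        inp.select(record_click, [inp, state], [state, inp]).then(
            run_box_inference, inputs=[inp, state, names], outputs=[out]
        )

    # demo.launch()
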
@@ -363,44 +406,41 @@ with gr.Blocks() as demo:
         with gr.Row():  # horizontal layout
             with gr.Column(scale=1):
                 input_image = gr.Image( label="Input Image", type="pil")
-
+                class_names_input_point = gr.Textbox(lines=1, placeholder=class_names_coco_ade20k, label='Class Names')
+            with gr.Column(scale=1):
                 output_image_point = gr.Image(type="pil", label='Segmentation Map',interactive=False)  # output segmentation map
-
+                clear_prompt_button_point = gr.Button("Clean Prompt")
+                clear_button_point = gr.Button("Restart")
+
         input_image.select(
             get_points_with_draw,
             [input_image, img_state_points],
             outputs=[img_state_points, input_image]
         ).then(
             inference_point,
-            inputs=[input_image, img_state_points],
+            inputs=[input_image, img_state_points,class_names_input_point],
             outputs=[output_image_point]
         )
-        clear_prompt_button_point = gr.Button("Clean Prompt")
        clear_prompt_button_point.click(
             clean_prompts,
             inputs=[img_state_points],
             outputs=[img_state_points, input_image, output_image_point]
         )
-        clear_button_point = gr.Button("Restart")
        clear_button_point.click(
             clear_everything,
             inputs=[img_state_points],
-            outputs=[img_state_points, input_image, output_image_point]
+            outputs=[img_state_points, input_image, output_image_point,class_names_input_point]
         )
         input_image.clear(
             clear_everything,
             inputs=[img_state_points],
-            outputs=[img_state_points, input_image, output_image_point]
+            outputs=[img_state_points, input_image, output_image_point,class_names_input_point]
         )
         output_image_point.clear(
             clear_everything,
             inputs=[img_state_points],
-            outputs=[img_state_points, input_image, output_image_point]
+            outputs=[img_state_points, input_image, output_image_point,class_names_input_point]
         )
-        def clear_and_set_example_point(example):
-            clear_everything(img_state_points)
-            return example
-
         gr.Examples(
             examples=examples_point,
             inputs=[input_image, img_state_points],