wondervictor committed (verified)
Commit: 55249e4
Parent(s): 6a560c1

Update app.py

Files changed (1):
  1. app.py (+67 -27)

app.py CHANGED
@@ -1,3 +1,7 @@
+## Some code was modified from Ovseg and OV-Sam. Thanks to their excellent work.
+## Ovseg code: https://github.com/facebookresearch/ov-seg
+## OV-Sam code: https://github.com/HarborYuan/ovsam
+
 import spaces
 import multiprocessing as mp
 import numpy as np
@@ -18,6 +22,17 @@ import gradio as gr
 import open_clip
 from sam2.build_sam import build_sam2
 from mask_adapter.modeling.meta_arch.mask_adapter_head import build_mask_adapter
+from mask_adapter.data.datasets import openseg_classes
+
+COCO_CATEGORIES_pan = openseg_classes.get_coco_categories_with_prompt_eng()
+thing_classes = [k["name"] for k in COCO_CATEGORIES_pan if k["isthing"] == 1]
+stuff_classes = [k["name"] for k in COCO_CATEGORIES_pan]
+ADE20K_150_CATEGORIES_ = openseg_classes.get_ade20k_categories_with_prompt_eng()
+ade20k_thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES_ if k["isthing"] == 1]
+ade20k_stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES_]
+class_names_coco_ade20k = thing_classes + stuff_classes + ade20k_thing_classes + ade20k_stuff_classes
+
+
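The new module-level block assembles a default vocabulary, class_names_coco_ade20k, from the COCO panoptic and ADE20K category lists with prompt engineering. The ./text_embedding/coco_ade20k_text_embedding.npy file loaded later in this diff holds text embeddings that correspond to such a default vocabulary. Below is a minimal sketch of how an embedding file like that could be precomputed offline with open_clip; the checkpoint name, the stand-in class list, and the output path are assumptions, not taken from this commit.

```python
# Illustrative sketch only (not part of this commit): precompute unit-normalized
# CLIP text features for a fixed vocabulary and save them as a .npy file, using
# the same "a photo of {cls}" template as app.py. Checkpoint, class list, and
# output path below are assumptions.
import numpy as np
import torch
import open_clip

model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="laion2b_s34b_b79k")
model.eval()

class_names = ["person", "car", "tree"]                    # stand-in for class_names_coco_ade20k
prompts = [f"a photo of {name}" for name in class_names]   # template used by the demo

with torch.no_grad():
    tokens = open_clip.tokenize(prompts)                   # (N, 77) token ids
    text_features = model.encode_text(tokens)              # (N, D) embeddings
    text_features /= text_features.norm(dim=-1, keepdim=True)  # unit-normalize

np.save("coco_ade20k_text_embedding.npy", text_features.cpu().numpy())
```

At inference time such a file only needs to be loaded and moved to the GPU, which is what the inference_point and inference_box hunks below do for the default vocabulary.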
 
@@ -93,7 +108,7 @@ def inference_automatic(input_img, class_names):
 @spaces.GPU
 @torch.no_grad()
 @torch.autocast(device_type="cuda", dtype=torch.float32)
-def inference_point(input_img, img_state,):
+def inference_point(input_img, img_state,class_names_input):
 
 
     mp.set_start_method("spawn", force=True)
@@ -106,8 +121,20 @@ def inference_point(input_img, img_state,):
 
     demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model,mask_adapter)
 
-    text_features = torch.from_numpy(np.load("./text_embedding/lvis_coco_text_embedding.npy")).cuda()
-    _, visualized_output = demo.run_on_image_with_points(img_state.img, points,text_features)
+    if not class_names_input:
+        class_names_input = class_names_coco_ade20k
+
+    if class_names_input == class_names_coco_ade20k:
+        text_features = torch.from_numpy(np.load("./text_embedding/coco_ade20k_text_embedding.npy")).cuda()
+        _, visualized_output = demo.run_on_image_with_points(img_state.img, points,text_features)
+    else:
+        class_names_input = class_names_input.split(',')
+        txts = [f'a photo of {cls_name}' for cls_name in class_names_input]
+        text = open_clip.tokenize(txts)
+        text_features = clip_model.encode_text(text.cuda())
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+        _, visualized_output = demo.run_on_image_with_points(img_state.img, points,text_features,class_names_input)
+
     return visualized_output
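inference_point (and inference_box in the next hunk) now takes the extra class_names_input argument and picks text features the same way: the precomputed COCO + ADE20K embeddings when the textbox is left empty, or on-the-fly CLIP text encoding of a comma-separated custom vocabulary. Below is a sketch of that shared branch pulled out into a helper; the helper name is hypothetical, while clip_model and class_names_coco_ade20k refer to the globals defined in app.py.

```python
# Hypothetical refactor sketch (not in the commit): the text-feature branch that
# the diff repeats in inference_point and inference_box, written as one helper.
# clip_model and class_names_coco_ade20k are the globals defined in app.py.
import numpy as np
import torch
import open_clip

def select_text_features(class_names_input, clip_model, class_names_coco_ade20k):
    """Return (text_features, class_names) for the default or a custom vocabulary."""
    if not class_names_input:
        class_names_input = class_names_coco_ade20k

    if class_names_input == class_names_coco_ade20k:
        # Default vocabulary: load the embeddings shipped with the Space.
        feats = torch.from_numpy(np.load("./text_embedding/coco_ade20k_text_embedding.npy")).cuda()
        return feats, class_names_coco_ade20k

    # Custom vocabulary: comma-separated names, encoded on the fly with CLIP.
    names = class_names_input.split(",")
    prompts = [f"a photo of {name}" for name in names]
    feats = clip_model.encode_text(open_clip.tokenize(prompts).cuda())
    feats /= feats.norm(dim=-1, keepdim=True)
    return feats, names
```

With a helper like this, each inference function would reduce to one call followed by the matching run_on_image_with_points or run_on_image_with_boxes invocation.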
 
@@ -136,8 +163,20 @@ def inference_box(input_img, img_state,):
 
     demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model,mask_adapter)
 
-    text_features = torch.from_numpy(np.load("./text_embedding/lvis_coco_text_embedding.npy")).cuda()
-    _, visualized_output = demo.run_on_image_with_boxes(img_state.img, bbox,text_features)
+    if not class_names_input:
+        class_names_input = class_names_coco_ade20k
+
+    if class_names_input == class_names_coco_ade20k:
+        text_features = torch.from_numpy(np.load("./text_embedding/coco_ade20k_text_embedding.npy")).cuda()
+        _, visualized_output = demo.run_on_image_with_boxes(img_state.img, bbox,text_features)
+    else:
+        class_names_input = class_names_input.split(',')
+        txts = [f'a photo of {cls_name}' for cls_name in class_names_input]
+        text = open_clip.tokenize(txts)
+        text_features = clip_model.encode_text(text.cuda())
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+        _, visualized_output = demo.run_on_image_with_boxes(img_state.img, bbox,text_features,class_names_input)
+
     return visualized_output
@@ -234,7 +273,7 @@ def preprocess_example(input_img, img_state):
 
 def clear_everything(img_state):
     img_state.clear()
-    return img_state, None, None
+    return img_state, None, None, gr.Textbox(value='',lines=1, placeholder=class_names_coco_ade20k, label='Class Names')
 
 
 def clean_prompts(img_state):
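clear_everything now returns a fourth value, a fresh gr.Textbox, so the Restart buttons and the image clear events can also reset the class-name textbox (their outputs lists gain class_names_input_box / class_names_input_point further down in this diff). Below is a minimal self-contained sketch of this Gradio pattern, where returning a component instance from a callback updates the matching output component; the component names are illustrative.

```python
# Minimal sketch (assumes Gradio 4.x; names are illustrative, not from app.py):
# returning a gr.Textbox(...) instance from a callback resets the textbox listed
# in outputs, which is how the new clear_everything clears the class-name field.
import gradio as gr

def restart():
    # One return value per component in outputs=[...]: clear the image and
    # reset the textbox to an empty value with a fresh placeholder.
    return None, gr.Textbox(value="", placeholder="comma-separated class names")

with gr.Blocks() as demo:
    image = gr.Image(type="pil", label="Input Image")
    class_names = gr.Textbox(label="Class Names")
    gr.Button("Restart").click(restart, inputs=[], outputs=[image, class_names])

if __name__ == "__main__":
    demo.launch()
```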
@@ -296,7 +335,7 @@ with gr.Blocks() as demo:
             output_image = gr.Image(type="pil", label='Segmentation Map')
 
             # Buttons below segmentation map (now placed under segmentation map)
-            run_button = gr.Button("Run Automatic Segmentation")
+            run_button = gr.Button("Run Automatic Segmentation", elem_id="run_button",variant='primary')
             run_button.click(inference_automatic, inputs=[input_image, class_names], outputs=output_image)
 
             clear_button = gr.Button("Clear")
@@ -310,9 +349,12 @@ with gr.Blocks() as demo:
         with gr.Row():  # horizontal layout
             with gr.Column(scale=1):
                 input_image = gr.Image( label="Input Image", type="pil")
-            with gr.Column(scale=1):  # second column: segmentation map output
+                class_names_input_box = gr.Textbox(lines=1, placeholder=class_names_coco_ade20k, label='Class Names')
+            with gr.Column(scale=1):
                 output_image_box = gr.Image(type="pil", label='Segmentation Map',interactive=False)  # output segmentation map
-
+                clear_prompt_button_box = gr.Button("Clean Prompt")
+                clear_button_box = gr.Button("Restart")
+
         gr.Markdown("Click the top-left and bottom-right corners of the image to select a rectangular area")
 
         input_image.select(
@@ -321,30 +363,31 @@ with gr.Blocks() as demo:
             outputs=[img_state_bbox, input_image]
         ).then(
             inference_box,
-            inputs=[input_image, img_state_bbox],
+            inputs=[input_image, img_state_bbox,class_names_input_box],
             outputs=[output_image_box]
         )
-        clear_prompt_button_box = gr.Button("Clean Prompt")
+
+
         clear_prompt_button_box.click(
             clean_prompts,
             inputs=[img_state_bbox],
             outputs=[img_state_bbox, input_image, output_image_box]
         )
-        clear_button_box = gr.Button("Restart")
+
         clear_button_box.click(
             clear_everything,
             inputs=[img_state_bbox],
-            outputs=[img_state_bbox, input_image, output_image_box]
+            outputs=[img_state_bbox, input_image, output_image_box,class_names_input_box]
         )
         input_image.clear(
             clear_everything,
             inputs=[img_state_bbox],
-            outputs=[img_state_bbox, input_image, output_image_box]
+            outputs=[img_state_bbox, input_image, output_image_box,class_names_input_box]
         )
         output_image_box.clear(
             clear_everything,
             inputs=[img_state_bbox],
-            outputs=[img_state_bbox, input_image, output_image_box]
+            outputs=[img_state_bbox, input_image, output_image_box,class_names_input_box]
         )
@@ -363,44 +406,41 @@ with gr.Blocks() as demo:
         with gr.Row():  # horizontal layout
             with gr.Column(scale=1):
                 input_image = gr.Image( label="Input Image", type="pil")
-            with gr.Column(scale=1):  # second column: segmentation map output
+                class_names_input_point = gr.Textbox(lines=1, placeholder=class_names_coco_ade20k, label='Class Names')
+            with gr.Column(scale=1):
                 output_image_point = gr.Image(type="pil", label='Segmentation Map',interactive=False)  # output segmentation map
-
+                clear_prompt_button_point = gr.Button("Clean Prompt")
+                clear_button_point = gr.Button("Restart")
+
         input_image.select(
             get_points_with_draw,
             [input_image, img_state_points],
             outputs=[img_state_points, input_image]
         ).then(
             inference_point,
-            inputs=[input_image, img_state_points],
+            inputs=[input_image, img_state_points,class_names_input_point],
             outputs=[output_image_point]
         )
-        clear_prompt_button_point = gr.Button("Clean Prompt")
         clear_prompt_button_point.click(
             clean_prompts,
             inputs=[img_state_points],
             outputs=[img_state_points, input_image, output_image_point]
         )
-        clear_button_point = gr.Button("Restart")
        clear_button_point.click(
             clear_everything,
             inputs=[img_state_points],
-            outputs=[img_state_points, input_image, output_image_point]
+            outputs=[img_state_points, input_image, output_image_point,class_names_input_point]
         )
         input_image.clear(
             clear_everything,
             inputs=[img_state_points],
-            outputs=[img_state_points, input_image, output_image_point]
+            outputs=[img_state_points, input_image, output_image_point,class_names_input_point]
         )
         output_image_point.clear(
             clear_everything,
             inputs=[img_state_points],
-            outputs=[img_state_points, input_image, output_image_point]
+            outputs=[img_state_points, input_image, output_image_point,class_names_input_point]
         )
-        def clear_and_set_example_point(example):
-            clear_everything(img_state_points)
-            return example
-
         gr.Examples(
             examples=examples_point,
             inputs=[input_image, img_state_points],
 