wondervictor committed
Commit 73550bc · verified · 1 Parent(s): 7f8348b

Update app.py

Files changed (1)
  1. app.py +101 -22
app.py CHANGED
@@ -1,7 +1,7 @@
 import spaces
 import multiprocessing as mp
 import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw
 import torch
 try:
     import detectron2
@@ -32,6 +32,49 @@ def setup_cfg(config_file):
     cfg.freeze()
     return cfg
 
+class IMGState:
+    def __init__(self):
+        self.img = None
+        self.img_feat = None
+        self.selected_points = []
+        self.selected_points_labels = []
+        self.selected_bboxes = []
+
+        self.available_to_set = True
+
+    def set_img(self, img, img_feat):
+        self.img = img
+        self.img_feat = img_feat
+
+        self.available_to_set = False
+
+    def clear(self):
+        self.img = None
+        self.img_feat = None
+        self.selected_points = []
+        self.selected_points_labels = []
+        self.selected_bboxes = []
+
+        self.available_to_set = True
+
+    def clean(self):
+        self.selected_points = []
+        self.selected_points_labels = []
+        self.selected_bboxes = []
+
+    def to_device(self, device=torch.device("cuda")):
+        if self.img_feat is not None:
+            for k in self.img_feat:
+                if isinstance(self.img_feat[k], torch.Tensor):
+                    self.img_feat[k] = self.img_feat[k].to(device)
+                elif isinstance(self.img_feat[k], tuple):
+                    self.img_feat[k] = tuple(v.to(device) for v in self.img_feat[k])
+
+    @property
+    def available(self):
+        return self.available_to_set
+
+
 @spaces.GPU
 @torch.no_grad()
 @torch.autocast(device_type="cuda", dtype=torch.float32)
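The IMGState class added above is what lets Point Mode accumulate clicks across Gradio callbacks instead of segmenting from a single click event. A minimal sketch of its intended lifecycle, using only names from the diff (the dummy image is illustrative):

import numpy as np

state = IMGState()

# A click handler stores the image once and appends prompt points.
state.set_img(np.zeros((480, 640, 3), dtype=np.uint8), None)
state.selected_points.append([320, 240])    # (x, y) pixel coordinates
state.selected_points_labels.append(1)      # 1 = positive "Add Mask" click

# "Clean Prompt" drops the clicks but keeps the image,
state.clean()
# while "Restart" also forgets the image so a new one can be set.
state.clear()
assert state.available    # available_to_set is True again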
@@ -60,13 +103,12 @@ def inference_automatic(input_img, class_names):
 @spaces.GPU
 @torch.no_grad()
 @torch.autocast(device_type="cuda", dtype=torch.float32)
-def inference_point(input_img, evt: gr.SelectData,):
+def inference_point(input_img, img_state):
 
 
     mp.set_start_method("spawn", force=True)
 
-    x, y = evt.index[0], evt.index[1]
-    points = [[x, y]]
+    points = img_state.selected_points
     print(f"Selected point: {points}")
 
     config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
@@ -74,10 +116,8 @@ def inference_point(input_img, evt: gr.SelectData,):
 
     demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)
 
-    img = read_image(input_img, format="BGR")
-
 
-    _, visualized_output = demo.run_on_image_with_points(img, points)
+    _, visualized_output = demo.run_on_image_with_points(img_state.img, points)
     return visualized_output
 
 
@@ -85,6 +125,21 @@ sam2_model = None
 clip_model = None
 mask_adapter = None
 
+def get_points_with_draw(image, img_state, evt: gr.SelectData):
+    label = 'Add Mask'
+
+    x, y = evt.index[0], evt.index[1]
+    point_radius, point_color = 10, (97, 217, 54) if label == "Add Mask" else (237, 34, 13)
+
+    img_state.selected_points.append([x, y])
+    img_state.selected_points_labels.append(1 if label == "Add Mask" else 0)
+    img_state.set_img(np.array(image), None)
+    draw = ImageDraw.Draw(image)
+    draw.ellipse(
+        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
+        fill=point_color,
+    )
+    return img_state, image
 def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
     cfg = setup_cfg(cfg)
     global sam2_model, clip_model, mask_adapter
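get_points_with_draw, added above, both records the click in IMGState and paints a marker onto the PIL image that Gradio re-displays. The same drawing idea in isolation (a standalone sketch; canvas size and coordinates are made up):

from PIL import Image, ImageDraw

def draw_click_marker(image, x, y, positive=True, radius=10):
    # Green for "Add Mask" clicks, red for negative ones,
    # mirroring the colors hard-coded in the commit.
    color = (97, 217, 54) if positive else (237, 34, 13)
    ImageDraw.Draw(image).ellipse(
        [(x - radius, y - radius), (x + radius, y + radius)], fill=color
    )
    return image

canvas = draw_click_marker(Image.new("RGB", (640, 480), "white"), 320, 240)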
@@ -107,6 +162,16 @@ def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
     mask_adapter.load_state_dict(adapter_state_dict)
     print("Mask Adapter model initialized.")
 
+def clear_everything(img_state):
+    img_state.clear()
+    return img_state, None, None
+
+
+def clean_prompts(img_state):
+    img_state.clean()
+    return img_state, Image.fromarray(img_state.img), None
+
+# Initialize configuration and models
 model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
 sam_path = './sam2.1_hiera_large.pt'
 adapter_pth = './model_0279999_with_sem_new.pth'
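One caveat in the helpers above: clean_prompts calls Image.fromarray(img_state.img), which fails if no image has been set yet (img is still None). A defensive variant might look like this (hypothetical, not part of the commit):

def clean_prompts_safe(img_state):
    # Hypothetical guard: only rebuild the preview if an image exists.
    img_state.clean()
    preview = Image.fromarray(img_state.img) if img_state.img is not None else None
    return img_state, preview, None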
@@ -162,21 +227,35 @@ with gr.Blocks() as demo:
     gr.Examples(examples=examples, inputs=[input_image, class_names], outputs=output_image)
 
     with gr.TabItem("Point Mode"):
-        with gr.Row():
-            with gr.Column():
-                def init_state():
-                    return []
-                input_image = gr.Image(type='filepath', label="Upload Image", interactive=True)
-                points_input = gr.State(value=init_state())
-
-            with gr.Column():
-                output_image_point = gr.Image(type="pil", label='Segmentation Map')
-
-        input_image.select(inference_point, inputs=[input_image], outputs=output_image_point)
-
-        clear_button_point = gr.Button("Clear Segmentation Map")
-        clear_button_point.click(lambda: None, inputs=None, outputs=output_image_point)
-
+        img_state_points = gr.State(value=IMGState())
+        with gr.Row():  # side-by-side layout
+            with gr.Column(scale=1):
+                input_image = gr.Image(label="Input Image", type="pil")
+            with gr.Column(scale=1):  # second column: segmentation map
+                output_image_point = gr.Image(type="pil", label='Segmentation Map', interactive=False)
+
+        input_image.select(
+            get_points_with_draw,
+            [input_image, img_state_points],
+            outputs=[img_state_points, input_image]
+        ).then(
+            inference_point,
+            inputs=[input_image, img_state_points],
+            outputs=[output_image_point]
+        )
+        clear_prompt_button_point = gr.Button("Clean Prompt")
+        clear_prompt_button_point.click(
+            clean_prompts,
+            inputs=[img_state_points],
+            outputs=[img_state_points, input_image, output_image_point]
+        )
+        clear_button_point = gr.Button("Restart")
+        clear_button_point.click(
+            clear_everything,
+            inputs=[img_state_points],
+            outputs=[img_state_points, input_image, output_image_point]
+        )
+
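The new wiring relies on Gradio's event chaining: .select() fires when the user clicks inside the image, and .then() runs only after that handler returns, so the marker is drawn before segmentation starts. A stripped-down sketch of the same pattern with stand-in handlers (not the app's real ones):

import gradio as gr

def on_click(img, points, evt: gr.SelectData):
    points.append(list(evt.index))    # record the (x, y) click
    return points, img

def run_inference(img, points):
    return img                        # stand-in: the real app segments here

with gr.Blocks() as demo:
    points = gr.State(value=[])
    inp = gr.Image(type="pil", label="Input Image")
    out = gr.Image(type="pil", label="Segmentation Map", interactive=False)
    inp.select(on_click, [inp, points], [points, inp]).then(
        run_inference, inputs=[inp, points], outputs=[out]
    )

demo.launch()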
 