wondervictor commited on
Commit
c570f20
·
verified ·
1 Parent(s): 45caf8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py CHANGED
@@ -125,6 +125,33 @@ sam2_model = None
125
  clip_model = None
126
  mask_adapter = None
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  def get_points_with_draw(image, img_state, evt: gr.SelectData):
129
  label = 'Add Mask'
130
 
@@ -140,6 +167,41 @@ def get_points_with_draw(image, img_state, evt: gr.SelectData):
140
  fill=point_color,
141
  )
142
  return img_state, image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
144
  cfg = setup_cfg(cfg)
145
  global sam2_model, clip_model, mask_adapter
@@ -234,6 +296,39 @@ with gr.Blocks() as demo:
234
  with gr.Row():
235
  gr.Examples(examples=examples, inputs=[input_image, class_names], outputs=output_image)
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  with gr.TabItem("Point Mode"):
238
  img_state_points = gr.State(value=IMGState())
239
  with gr.Row(): # 水平排列
 
125
  clip_model = None
126
  mask_adapter = None
127
 
128
+ @spaces.GPU
129
+ @torch.no_grad()
130
+ @torch.autocast(device_type="cuda", dtype=torch.float32)
131
+ def inference_box(input_img, img_state,):
132
+
133
+
134
+ mp.set_start_method("spawn", force=True)
135
+
136
+ box_points = img_state.selected_bboxes
137
+ bbox = (
138
+ min(box_points[0][0], box_points[1][0]),
139
+ min(box_points[0][1], box_points[1][1]),
140
+ max(box_points[0][0], box_points[1][0]),
141
+ max(box_points[0][1], box_points[1][1]),
142
+ )
143
+ bbox = np.array(bbox)
144
+ config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
145
+ cfg = setup_cfg(config_file)
146
+
147
+ demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model,mask_adapter)
148
+
149
+ text_features = torch.from_numpy(np.load("./text_embedding/lvis_coco_text_embedding.npy")).cuda()
150
+ _, visualized_output = demo.run_on_image_with_boxes(img_state.img, bbox,text_features)
151
+ return visualized_output
152
+
153
+
154
+
155
  def get_points_with_draw(image, img_state, evt: gr.SelectData):
156
  label = 'Add Mask'
157
 
 
167
  fill=point_color,
168
  )
169
  return img_state, image
170
+
171
+ def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
172
+ x, y = evt.index[0], evt.index[1]
173
+ point_radius, point_color, box_outline = 5, (237, 34, 13), 2
174
+ box_color = (237, 34, 13)
175
+
176
+ if len(img_state.selected_bboxes) in [0, 1]:
177
+ img_state.selected_bboxes.append([x, y])
178
+ elif len(img_state.selected_bboxes) == 2:
179
+ img_state.selected_bboxes = [[x, y]]
180
+ image = Image.fromarray(img_state.img)
181
+ else:
182
+ raise ValueError(f"Cannot be {len(img_state.selected_bboxes)}")
183
+ img_state.set_img(np.array(image), None)
184
+
185
+ draw = ImageDraw.Draw(image)
186
+ draw.ellipse(
187
+ [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
188
+ fill=point_color,
189
+ )
190
+
191
+ if len(img_state.selected_bboxes) == 2:
192
+ box_points = img_state.selected_bboxes
193
+ bbox = (min(box_points[0][0], box_points[1][0]),
194
+ min(box_points[0][1], box_points[1][1]),
195
+ max(box_points[0][0], box_points[1][0]),
196
+ max(box_points[0][1], box_points[1][1]),
197
+ )
198
+ draw.rectangle(
199
+ bbox,
200
+ outline=box_color,
201
+ width=box_outline
202
+ )
203
+ return img_state, image
204
+
205
  def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
206
  cfg = setup_cfg(cfg)
207
  global sam2_model, clip_model, mask_adapter
 
296
  with gr.Row():
297
  gr.Examples(examples=examples, inputs=[input_image, class_names], outputs=output_image)
298
 
299
+ with gr.TabItem("Box Mode"):
300
+ img_state_bbox = gr.State(value=IMGState())
301
+ with gr.Row(): # 水平排列
302
+ with gr.Column(scale=1):
303
+ input_image = gr.Image( label="Input Image", type="pil")
304
+ with gr.Column(scale=1): # 第二列:分割图输出
305
+ output_image_box = gr.Image(type="pil", label='Segmentation Map',interactive=False) # 输出分割图
306
+
307
+ input_image.select(
308
+ get_bbox_with_draw,
309
+ [input_image, img_state_bbox],
310
+ outputs=[img_state_bbox, input_image]
311
+ ).then(
312
+ inference_box,
313
+ inputs=[input_image, img_state_bbox],
314
+ outputs=[output_image_box]
315
+ )
316
+ clear_prompt_button_box = gr.Button("Clean Prompt")
317
+ clear_prompt_button_box.click(
318
+ clean_prompts,
319
+ inputs=[img_state_bbox],
320
+ outputs=[img_state_bbox, input_image, output_image_box]
321
+ )
322
+ clear_button_box = gr.Button("Restart")
323
+ clear_button_box.click(
324
+ clear_everything,
325
+ inputs=[img_state_bbox],
326
+ outputs=[img_state_bbox, input_image, output_image_box]
327
+ )
328
+
329
+ with gr.Row():
330
+ gr.Examples(examples=examples_point, inputs=input_image, outputs=output_image_box,examples_per_page=5)
331
+
332
  with gr.TabItem("Point Mode"):
333
  img_state_points = gr.State(value=IMGState())
334
  with gr.Row(): # 水平排列