Update app.py
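Add a box-prompt mode: a GPU-decorated inference_box entry point that turns two recorded corner clicks into a bounding box and segments it, a get_bbox_with_draw select handler that records and draws the box, and a new "Box Mode" tab in the Gradio UI.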
app.py CHANGED
@@ -125,6 +125,33 @@ sam2_model = None
 clip_model = None
 mask_adapter = None

+@spaces.GPU
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.float32)
+def inference_box(input_img, img_state):
+
+
+    mp.set_start_method("spawn", force=True)
+
+    box_points = img_state.selected_bboxes
+    bbox = (
+        min(box_points[0][0], box_points[1][0]),
+        min(box_points[0][1], box_points[1][1]),
+        max(box_points[0][0], box_points[1][0]),
+        max(box_points[0][1], box_points[1][1]),
+    )
+    bbox = np.array(bbox)
+    config_file = './configs/ground-truth-warmup/mask-adapter/mask_adapter_convnext_large_cocopan_eval_ade20k.yaml'
+    cfg = setup_cfg(config_file)
+
+    demo = SAMPointVisualizationDemo(cfg, 0.8, sam2_model, clip_model, mask_adapter)
+
+    text_features = torch.from_numpy(np.load("./text_embedding/lvis_coco_text_embedding.npy")).cuda()
+    _, visualized_output = demo.run_on_image_with_boxes(img_state.img, bbox, text_features)
+    return visualized_output
+
+
+
 def get_points_with_draw(image, img_state, evt: gr.SelectData):
     label = 'Add Mask'

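Note (editorial, not part of the diff): the min/max shuffle in `inference_box` exists because the two stored clicks may arrive in any order, so neither one is guaranteed to be the top-left corner. A minimal, self-contained sketch of that normalization, using hypothetical click coordinates:

```python
import numpy as np

def clicks_to_xyxy(p0, p1):
    """Normalize two corner clicks, given in any order, to (x0, y0, x1, y1)."""
    return np.array([
        min(p0[0], p1[0]),  # left edge
        min(p0[1], p1[1]),  # top edge
        max(p0[0], p1[0]),  # right edge
        max(p0[1], p1[1]),  # bottom edge
    ])

# A drag from bottom-right to top-left still yields a well-formed box.
print(clicks_to_xyxy((320, 240), (100, 80)))  # [100  80 320 240]
```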
@@ -140,6 +167,41 @@ def get_points_with_draw(image, img_state, evt: gr.SelectData):
         fill=point_color,
     )
     return img_state, image
+
+def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
+    x, y = evt.index[0], evt.index[1]
+    point_radius, point_color, box_outline = 5, (237, 34, 13), 2
+    box_color = (237, 34, 13)
+
+    if len(img_state.selected_bboxes) in [0, 1]:
+        img_state.selected_bboxes.append([x, y])
+    elif len(img_state.selected_bboxes) == 2:
+        img_state.selected_bboxes = [[x, y]]
+        image = Image.fromarray(img_state.img)
+    else:
+        raise ValueError(f"Unexpected number of stored box corners: {len(img_state.selected_bboxes)}")
+    img_state.set_img(np.array(image), None)
+
+    draw = ImageDraw.Draw(image)
+    draw.ellipse(
+        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
+        fill=point_color,
+    )
+
+    if len(img_state.selected_bboxes) == 2:
+        box_points = img_state.selected_bboxes
+        bbox = (min(box_points[0][0], box_points[1][0]),
+                min(box_points[0][1], box_points[1][1]),
+                max(box_points[0][0], box_points[1][0]),
+                max(box_points[0][1], box_points[1][1]),
+                )
+        draw.rectangle(
+            bbox,
+            outline=box_color,
+            width=box_outline
+        )
+    return img_state, image
+
 def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
     cfg = setup_cfg(cfg)
     global sam2_model, clip_model, mask_adapter
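Note (editorial, not part of the diff): `get_bbox_with_draw` implements a small click cycle — the first two clicks become the box corners, and a third click discards the old box and starts a new one. The same cycle, reduced to plain Python with hypothetical clicks:

```python
def record_corner(corners, xy):
    """First two clicks define a box; a third click starts a fresh one."""
    if len(corners) < 2:
        return corners + [xy]  # still collecting corners
    return [xy]                # box already complete: reset to a new first corner

corners = []
for click in [(10, 10), (50, 40), (200, 180)]:
    corners = record_corner(corners, click)
print(corners)  # [(200, 180)] -- the third click began a new box
```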
@@ -234,6 +296,39 @@ with gr.Blocks() as demo:
         with gr.Row():
             gr.Examples(examples=examples, inputs=[input_image, class_names], outputs=output_image)

+    with gr.TabItem("Box Mode"):
+        img_state_bbox = gr.State(value=IMGState())
+        with gr.Row():  # laid out side by side
+            with gr.Column(scale=1):
+                input_image = gr.Image(label="Input Image", type="pil")
+            with gr.Column(scale=1):  # second column: segmentation output
+                output_image_box = gr.Image(type="pil", label='Segmentation Map', interactive=False)  # segmentation map output
+
+        input_image.select(
+            get_bbox_with_draw,
+            [input_image, img_state_bbox],
+            outputs=[img_state_bbox, input_image]
+        ).then(
+            inference_box,
+            inputs=[input_image, img_state_bbox],
+            outputs=[output_image_box]
+        )
+        clear_prompt_button_box = gr.Button("Clean Prompt")
+        clear_prompt_button_box.click(
+            clean_prompts,
+            inputs=[img_state_bbox],
+            outputs=[img_state_bbox, input_image, output_image_box]
+        )
+        clear_button_box = gr.Button("Restart")
+        clear_button_box.click(
+            clear_everything,
+            inputs=[img_state_bbox],
+            outputs=[img_state_bbox, input_image, output_image_box]
+        )
+
+        with gr.Row():
+            gr.Examples(examples=examples_point, inputs=input_image, outputs=output_image_box, examples_per_page=5)
+
     with gr.TabItem("Point Mode"):
         img_state_points = gr.State(value=IMGState())
         with gr.Row():  # laid out side by side
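Note (editorial, not part of the diff): the Box Mode tab relies on Gradio's event chaining — `.select(...)` fires on each click, and `.then(...)` runs only after that handler returns, which is what lets drawing and inference happen in sequence. A minimal, hypothetical sketch of the pattern, independent of this app's models:

```python
import gradio as gr

def on_click(evt: gr.SelectData):
    # .select passes the clicked pixel via evt.index
    return f"clicked at {evt.index}"

def after_click(msg):
    # .then runs only once on_click has finished
    return msg.upper()

with gr.Blocks() as demo:
    img = gr.Image(label="Input Image", type="pil")
    log = gr.Textbox(label="Event log")
    img.select(on_click, None, log).then(after_click, log, log)

demo.launch()
```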