Spaces:
Runtime error
Runtime error
Haobo Yuan
commited on
Commit
·
1cf72c0
1
Parent(s):
8e42464
img_state goes back
Browse files
main.py
CHANGED
@@ -106,7 +106,7 @@ def get_points_with_draw(image, img_state, evt: gr.SelectData):
|
|
106 |
[(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
|
107 |
fill=point_color,
|
108 |
)
|
109 |
-
return image
|
110 |
|
111 |
|
112 |
def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
|
@@ -142,7 +142,7 @@ def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
|
|
142 |
outline=box_color,
|
143 |
width=box_outline
|
144 |
)
|
145 |
-
return image
|
146 |
|
147 |
|
148 |
def segment_with_points(
|
@@ -192,7 +192,7 @@ def segment_with_points(
|
|
192 |
output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
|
193 |
|
194 |
output_img = Image.fromarray(output_img)
|
195 |
-
return image, output_img, cls_info
|
196 |
|
197 |
|
198 |
def segment_with_bbox(
|
@@ -251,7 +251,7 @@ def segment_with_bbox(
|
|
251 |
output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
|
252 |
|
253 |
output_img = Image.fromarray(output_img)
|
254 |
-
return image, output_img, cls_info
|
255 |
|
256 |
|
257 |
def extract_img_feat(img, img_state):
|
@@ -278,12 +278,12 @@ def extract_img_feat(img, img_state):
|
|
278 |
return None, None, "CUDA OOM, please try again later."
|
279 |
else:
|
280 |
raise
|
281 |
-
return img, None, "Please try to click something."
|
282 |
|
283 |
|
284 |
def clear_everything(img_state):
|
285 |
img_state.clear()
|
286 |
-
return None, None, "Please try to click something."
|
287 |
|
288 |
|
289 |
def clean_prompts(img_state):
|
@@ -291,7 +291,7 @@ def clean_prompts(img_state):
|
|
291 |
if img_state.img is None:
|
292 |
img_state.clear()
|
293 |
return None, None, "Please try to click something."
|
294 |
-
return Image.fromarray(img_state.img), None, "Please try to click something."
|
295 |
|
296 |
|
297 |
def register_point_mode():
|
@@ -325,7 +325,7 @@ def register_point_mode():
|
|
325 |
gr.Examples(
|
326 |
examples=examples,
|
327 |
inputs=[cond_img_p, img_state_points],
|
328 |
-
outputs=[cond_img_p, segm_img_p, cls_info],
|
329 |
examples_per_page=12,
|
330 |
fn=extract_img_feat,
|
331 |
run_on_click=True
|
@@ -355,7 +355,7 @@ def register_point_mode():
|
|
355 |
gr.Examples(
|
356 |
examples=examples,
|
357 |
inputs=[cond_img_bbox, img_state_bbox],
|
358 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox],
|
359 |
examples_per_page=12,
|
360 |
fn=extract_img_feat,
|
361 |
run_on_click=True
|
@@ -365,76 +365,76 @@ def register_point_mode():
|
|
365 |
cond_img_p.upload(
|
366 |
extract_img_feat,
|
367 |
[cond_img_p, img_state_points],
|
368 |
-
outputs=[cond_img_p, segm_img_p, cls_info]
|
369 |
)
|
370 |
cond_img_bbox.upload(
|
371 |
extract_img_feat,
|
372 |
[cond_img_bbox, img_state_bbox],
|
373 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info]
|
374 |
)
|
375 |
|
376 |
# get user added points
|
377 |
cond_img_p.select(
|
378 |
get_points_with_draw,
|
379 |
[cond_img_p, img_state_points],
|
380 |
-
cond_img_p
|
381 |
).then(
|
382 |
segment_with_points,
|
383 |
inputs=[cond_img_p, img_state_points],
|
384 |
-
outputs=[cond_img_p, segm_img_p, cls_info]
|
385 |
)
|
386 |
cond_img_bbox.select(
|
387 |
get_bbox_with_draw,
|
388 |
[cond_img_bbox, img_state_bbox],
|
389 |
-
cond_img_bbox
|
390 |
).then(
|
391 |
segment_with_bbox,
|
392 |
inputs=[cond_img_bbox, img_state_bbox],
|
393 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
394 |
)
|
395 |
|
396 |
# clean prompts
|
397 |
clean_btn_p.click(
|
398 |
clean_prompts,
|
399 |
inputs=[img_state_points],
|
400 |
-
outputs=[cond_img_p, segm_img_p, cls_info]
|
401 |
)
|
402 |
clean_btn_bbox.click(
|
403 |
clean_prompts,
|
404 |
inputs=[img_state_bbox],
|
405 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
406 |
)
|
407 |
|
408 |
# clear
|
409 |
clear_btn_p.click(
|
410 |
clear_everything,
|
411 |
inputs=[img_state_points],
|
412 |
-
outputs=[cond_img_p, segm_img_p, cls_info]
|
413 |
)
|
414 |
cond_img_p.clear(
|
415 |
clear_everything,
|
416 |
inputs=[img_state_points],
|
417 |
-
outputs=[cond_img_p, segm_img_p, cls_info]
|
418 |
)
|
419 |
segm_img_p.clear(
|
420 |
clear_everything,
|
421 |
inputs=[img_state_points],
|
422 |
-
outputs=[cond_img_p, segm_img_p, cls_info]
|
423 |
)
|
424 |
clear_btn_bbox.click(
|
425 |
clear_everything,
|
426 |
inputs=[img_state_bbox],
|
427 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
428 |
)
|
429 |
cond_img_bbox.clear(
|
430 |
clear_everything,
|
431 |
inputs=[img_state_bbox],
|
432 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
433 |
)
|
434 |
segm_img_bbox.clear(
|
435 |
clear_everything,
|
436 |
inputs=[img_state_bbox],
|
437 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
438 |
)
|
439 |
|
440 |
|
|
|
106 |
[(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
|
107 |
fill=point_color,
|
108 |
)
|
109 |
+
return img_state, image
|
110 |
|
111 |
|
112 |
def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
|
|
|
142 |
outline=box_color,
|
143 |
width=box_outline
|
144 |
)
|
145 |
+
return img_state, image
|
146 |
|
147 |
|
148 |
def segment_with_points(
|
|
|
192 |
output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
|
193 |
|
194 |
output_img = Image.fromarray(output_img)
|
195 |
+
return img_state, image, output_img, cls_info
|
196 |
|
197 |
|
198 |
def segment_with_bbox(
|
|
|
251 |
output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
|
252 |
|
253 |
output_img = Image.fromarray(output_img)
|
254 |
+
return img_state, image, output_img, cls_info
|
255 |
|
256 |
|
257 |
def extract_img_feat(img, img_state):
|
|
|
278 |
return None, None, "CUDA OOM, please try again later."
|
279 |
else:
|
280 |
raise
|
281 |
+
return img_state, img, None, "Please try to click something."
|
282 |
|
283 |
|
284 |
def clear_everything(img_state):
|
285 |
img_state.clear()
|
286 |
+
return img_state, None, None, "Please try to click something."
|
287 |
|
288 |
|
289 |
def clean_prompts(img_state):
|
|
|
291 |
if img_state.img is None:
|
292 |
img_state.clear()
|
293 |
return None, None, "Please try to click something."
|
294 |
+
return img_state, Image.fromarray(img_state.img), None, "Please try to click something."
|
295 |
|
296 |
|
297 |
def register_point_mode():
|
|
|
325 |
gr.Examples(
|
326 |
examples=examples,
|
327 |
inputs=[cond_img_p, img_state_points],
|
328 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info],
|
329 |
examples_per_page=12,
|
330 |
fn=extract_img_feat,
|
331 |
run_on_click=True
|
|
|
355 |
gr.Examples(
|
356 |
examples=examples,
|
357 |
inputs=[cond_img_bbox, img_state_bbox],
|
358 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox],
|
359 |
examples_per_page=12,
|
360 |
fn=extract_img_feat,
|
361 |
run_on_click=True
|
|
|
365 |
cond_img_p.upload(
|
366 |
extract_img_feat,
|
367 |
[cond_img_p, img_state_points],
|
368 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
|
369 |
)
|
370 |
cond_img_bbox.upload(
|
371 |
extract_img_feat,
|
372 |
[cond_img_bbox, img_state_bbox],
|
373 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info]
|
374 |
)
|
375 |
|
376 |
# get user added points
|
377 |
cond_img_p.select(
|
378 |
get_points_with_draw,
|
379 |
[cond_img_p, img_state_points],
|
380 |
+
outputs=[img_state_points, cond_img_p]
|
381 |
).then(
|
382 |
segment_with_points,
|
383 |
inputs=[cond_img_p, img_state_points],
|
384 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
|
385 |
)
|
386 |
cond_img_bbox.select(
|
387 |
get_bbox_with_draw,
|
388 |
[cond_img_bbox, img_state_bbox],
|
389 |
+
outputs=[img_state_bbox, cond_img_bbox]
|
390 |
).then(
|
391 |
segment_with_bbox,
|
392 |
inputs=[cond_img_bbox, img_state_bbox],
|
393 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
394 |
)
|
395 |
|
396 |
# clean prompts
|
397 |
clean_btn_p.click(
|
398 |
clean_prompts,
|
399 |
inputs=[img_state_points],
|
400 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
|
401 |
)
|
402 |
clean_btn_bbox.click(
|
403 |
clean_prompts,
|
404 |
inputs=[img_state_bbox],
|
405 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
406 |
)
|
407 |
|
408 |
# clear
|
409 |
clear_btn_p.click(
|
410 |
clear_everything,
|
411 |
inputs=[img_state_points],
|
412 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
|
413 |
)
|
414 |
cond_img_p.clear(
|
415 |
clear_everything,
|
416 |
inputs=[img_state_points],
|
417 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
|
418 |
)
|
419 |
segm_img_p.clear(
|
420 |
clear_everything,
|
421 |
inputs=[img_state_points],
|
422 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
|
423 |
)
|
424 |
clear_btn_bbox.click(
|
425 |
clear_everything,
|
426 |
inputs=[img_state_bbox],
|
427 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
428 |
)
|
429 |
cond_img_bbox.clear(
|
430 |
clear_everything,
|
431 |
inputs=[img_state_bbox],
|
432 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
433 |
)
|
434 |
segm_img_bbox.clear(
|
435 |
clear_everything,
|
436 |
inputs=[img_state_bbox],
|
437 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
438 |
)
|
439 |
|
440 |
|