Spaces:
Runtime error
Runtime error
Haobo Yuan
commited on
Commit
·
1cf72c0
1
Parent(s):
8e42464
img_state goes back
Browse files
main.py
CHANGED
|
@@ -106,7 +106,7 @@ def get_points_with_draw(image, img_state, evt: gr.SelectData):
|
|
| 106 |
[(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
|
| 107 |
fill=point_color,
|
| 108 |
)
|
| 109 |
-
return image
|
| 110 |
|
| 111 |
|
| 112 |
def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
|
|
@@ -142,7 +142,7 @@ def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
|
|
| 142 |
outline=box_color,
|
| 143 |
width=box_outline
|
| 144 |
)
|
| 145 |
-
return image
|
| 146 |
|
| 147 |
|
| 148 |
def segment_with_points(
|
|
@@ -192,7 +192,7 @@ def segment_with_points(
|
|
| 192 |
output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
|
| 193 |
|
| 194 |
output_img = Image.fromarray(output_img)
|
| 195 |
-
return image, output_img, cls_info
|
| 196 |
|
| 197 |
|
| 198 |
def segment_with_bbox(
|
|
@@ -251,7 +251,7 @@ def segment_with_bbox(
|
|
| 251 |
output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
|
| 252 |
|
| 253 |
output_img = Image.fromarray(output_img)
|
| 254 |
-
return image, output_img, cls_info
|
| 255 |
|
| 256 |
|
| 257 |
def extract_img_feat(img, img_state):
|
|
@@ -278,12 +278,12 @@ def extract_img_feat(img, img_state):
|
|
| 278 |
return None, None, "CUDA OOM, please try again later."
|
| 279 |
else:
|
| 280 |
raise
|
| 281 |
-
return img, None, "Please try to click something."
|
| 282 |
|
| 283 |
|
| 284 |
def clear_everything(img_state):
|
| 285 |
img_state.clear()
|
| 286 |
-
return None, None, "Please try to click something."
|
| 287 |
|
| 288 |
|
| 289 |
def clean_prompts(img_state):
|
|
@@ -291,7 +291,7 @@ def clean_prompts(img_state):
|
|
| 291 |
if img_state.img is None:
|
| 292 |
img_state.clear()
|
| 293 |
return None, None, "Please try to click something."
|
| 294 |
-
return Image.fromarray(img_state.img), None, "Please try to click something."
|
| 295 |
|
| 296 |
|
| 297 |
def register_point_mode():
|
|
@@ -325,7 +325,7 @@ def register_point_mode():
|
|
| 325 |
gr.Examples(
|
| 326 |
examples=examples,
|
| 327 |
inputs=[cond_img_p, img_state_points],
|
| 328 |
-
outputs=[cond_img_p, segm_img_p, cls_info],
|
| 329 |
examples_per_page=12,
|
| 330 |
fn=extract_img_feat,
|
| 331 |
run_on_click=True
|
|
@@ -355,7 +355,7 @@ def register_point_mode():
|
|
| 355 |
gr.Examples(
|
| 356 |
examples=examples,
|
| 357 |
inputs=[cond_img_bbox, img_state_bbox],
|
| 358 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox],
|
| 359 |
examples_per_page=12,
|
| 360 |
fn=extract_img_feat,
|
| 361 |
run_on_click=True
|
|
@@ -365,76 +365,76 @@ def register_point_mode():
|
|
| 365 |
cond_img_p.upload(
|
| 366 |
extract_img_feat,
|
| 367 |
[cond_img_p, img_state_points],
|
| 368 |
-
outputs=[cond_img_p, segm_img_p, cls_info]
|
| 369 |
)
|
| 370 |
cond_img_bbox.upload(
|
| 371 |
extract_img_feat,
|
| 372 |
[cond_img_bbox, img_state_bbox],
|
| 373 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info]
|
| 374 |
)
|
| 375 |
|
| 376 |
# get user added points
|
| 377 |
cond_img_p.select(
|
| 378 |
get_points_with_draw,
|
| 379 |
[cond_img_p, img_state_points],
|
| 380 |
-
cond_img_p
|
| 381 |
).then(
|
| 382 |
segment_with_points,
|
| 383 |
inputs=[cond_img_p, img_state_points],
|
| 384 |
-
outputs=[cond_img_p, segm_img_p, cls_info]
|
| 385 |
)
|
| 386 |
cond_img_bbox.select(
|
| 387 |
get_bbox_with_draw,
|
| 388 |
[cond_img_bbox, img_state_bbox],
|
| 389 |
-
cond_img_bbox
|
| 390 |
).then(
|
| 391 |
segment_with_bbox,
|
| 392 |
inputs=[cond_img_bbox, img_state_bbox],
|
| 393 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
| 394 |
)
|
| 395 |
|
| 396 |
# clean prompts
|
| 397 |
clean_btn_p.click(
|
| 398 |
clean_prompts,
|
| 399 |
inputs=[img_state_points],
|
| 400 |
-
outputs=[cond_img_p, segm_img_p, cls_info]
|
| 401 |
)
|
| 402 |
clean_btn_bbox.click(
|
| 403 |
clean_prompts,
|
| 404 |
inputs=[img_state_bbox],
|
| 405 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
| 406 |
)
|
| 407 |
|
| 408 |
# clear
|
| 409 |
clear_btn_p.click(
|
| 410 |
clear_everything,
|
| 411 |
inputs=[img_state_points],
|
| 412 |
-
outputs=[cond_img_p, segm_img_p, cls_info]
|
| 413 |
)
|
| 414 |
cond_img_p.clear(
|
| 415 |
clear_everything,
|
| 416 |
inputs=[img_state_points],
|
| 417 |
-
outputs=[cond_img_p, segm_img_p, cls_info]
|
| 418 |
)
|
| 419 |
segm_img_p.clear(
|
| 420 |
clear_everything,
|
| 421 |
inputs=[img_state_points],
|
| 422 |
-
outputs=[cond_img_p, segm_img_p, cls_info]
|
| 423 |
)
|
| 424 |
clear_btn_bbox.click(
|
| 425 |
clear_everything,
|
| 426 |
inputs=[img_state_bbox],
|
| 427 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
| 428 |
)
|
| 429 |
cond_img_bbox.clear(
|
| 430 |
clear_everything,
|
| 431 |
inputs=[img_state_bbox],
|
| 432 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
| 433 |
)
|
| 434 |
segm_img_bbox.clear(
|
| 435 |
clear_everything,
|
| 436 |
inputs=[img_state_bbox],
|
| 437 |
-
outputs=[cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
| 438 |
)
|
| 439 |
|
| 440 |
|
|
|
|
| 106 |
[(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
|
| 107 |
fill=point_color,
|
| 108 |
)
|
| 109 |
+
return img_state, image
|
| 110 |
|
| 111 |
|
| 112 |
def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
|
|
|
|
| 142 |
outline=box_color,
|
| 143 |
width=box_outline
|
| 144 |
)
|
| 145 |
+
return img_state, image
|
| 146 |
|
| 147 |
|
| 148 |
def segment_with_points(
|
|
|
|
| 192 |
output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
|
| 193 |
|
| 194 |
output_img = Image.fromarray(output_img)
|
| 195 |
+
return img_state, image, output_img, cls_info
|
| 196 |
|
| 197 |
|
| 198 |
def segment_with_bbox(
|
|
|
|
| 251 |
output_img = (output_img * 0.7 + color * 0.3).astype(np.uint8)
|
| 252 |
|
| 253 |
output_img = Image.fromarray(output_img)
|
| 254 |
+
return img_state, image, output_img, cls_info
|
| 255 |
|
| 256 |
|
| 257 |
def extract_img_feat(img, img_state):
|
|
|
|
| 278 |
return None, None, "CUDA OOM, please try again later."
|
| 279 |
else:
|
| 280 |
raise
|
| 281 |
+
return img_state, img, None, "Please try to click something."
|
| 282 |
|
| 283 |
|
| 284 |
def clear_everything(img_state):
|
| 285 |
img_state.clear()
|
| 286 |
+
return img_state, None, None, "Please try to click something."
|
| 287 |
|
| 288 |
|
| 289 |
def clean_prompts(img_state):
|
|
|
|
| 291 |
if img_state.img is None:
|
| 292 |
img_state.clear()
|
| 293 |
return None, None, "Please try to click something."
|
| 294 |
+
return img_state, Image.fromarray(img_state.img), None, "Please try to click something."
|
| 295 |
|
| 296 |
|
| 297 |
def register_point_mode():
|
|
|
|
| 325 |
gr.Examples(
|
| 326 |
examples=examples,
|
| 327 |
inputs=[cond_img_p, img_state_points],
|
| 328 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info],
|
| 329 |
examples_per_page=12,
|
| 330 |
fn=extract_img_feat,
|
| 331 |
run_on_click=True
|
|
|
|
| 355 |
gr.Examples(
|
| 356 |
examples=examples,
|
| 357 |
inputs=[cond_img_bbox, img_state_bbox],
|
| 358 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox],
|
| 359 |
examples_per_page=12,
|
| 360 |
fn=extract_img_feat,
|
| 361 |
run_on_click=True
|
|
|
|
| 365 |
cond_img_p.upload(
|
| 366 |
extract_img_feat,
|
| 367 |
[cond_img_p, img_state_points],
|
| 368 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
|
| 369 |
)
|
| 370 |
cond_img_bbox.upload(
|
| 371 |
extract_img_feat,
|
| 372 |
[cond_img_bbox, img_state_bbox],
|
| 373 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info]
|
| 374 |
)
|
| 375 |
|
| 376 |
# get user added points
|
| 377 |
cond_img_p.select(
|
| 378 |
get_points_with_draw,
|
| 379 |
[cond_img_p, img_state_points],
|
| 380 |
+
outputs=[img_state_points, cond_img_p]
|
| 381 |
).then(
|
| 382 |
segment_with_points,
|
| 383 |
inputs=[cond_img_p, img_state_points],
|
| 384 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
|
| 385 |
)
|
| 386 |
cond_img_bbox.select(
|
| 387 |
get_bbox_with_draw,
|
| 388 |
[cond_img_bbox, img_state_bbox],
|
| 389 |
+
outputs=[img_state_bbox, cond_img_bbox]
|
| 390 |
).then(
|
| 391 |
segment_with_bbox,
|
| 392 |
inputs=[cond_img_bbox, img_state_bbox],
|
| 393 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
| 394 |
)
|
| 395 |
|
| 396 |
# clean prompts
|
| 397 |
clean_btn_p.click(
|
| 398 |
clean_prompts,
|
| 399 |
inputs=[img_state_points],
|
| 400 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
|
| 401 |
)
|
| 402 |
clean_btn_bbox.click(
|
| 403 |
clean_prompts,
|
| 404 |
inputs=[img_state_bbox],
|
| 405 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
| 406 |
)
|
| 407 |
|
| 408 |
# clear
|
| 409 |
clear_btn_p.click(
|
| 410 |
clear_everything,
|
| 411 |
inputs=[img_state_points],
|
| 412 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
|
| 413 |
)
|
| 414 |
cond_img_p.clear(
|
| 415 |
clear_everything,
|
| 416 |
inputs=[img_state_points],
|
| 417 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
|
| 418 |
)
|
| 419 |
segm_img_p.clear(
|
| 420 |
clear_everything,
|
| 421 |
inputs=[img_state_points],
|
| 422 |
+
outputs=[img_state_points, cond_img_p, segm_img_p, cls_info]
|
| 423 |
)
|
| 424 |
clear_btn_bbox.click(
|
| 425 |
clear_everything,
|
| 426 |
inputs=[img_state_bbox],
|
| 427 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
| 428 |
)
|
| 429 |
cond_img_bbox.clear(
|
| 430 |
clear_everything,
|
| 431 |
inputs=[img_state_bbox],
|
| 432 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
| 433 |
)
|
| 434 |
segm_img_bbox.clear(
|
| 435 |
clear_everything,
|
| 436 |
inputs=[img_state_bbox],
|
| 437 |
+
outputs=[img_state_bbox, cond_img_bbox, segm_img_bbox, cls_info_bbox]
|
| 438 |
)
|
| 439 |
|
| 440 |
|