Spaces: Runtime error
update code
app.py
CHANGED
@@ -88,8 +88,8 @@ def calculate_sigmoid_focal_loss(inputs, targets, num_masks = 1, alpha: float =
 
 def inference(ic_image, ic_mask, image1, image2):
     # in context image and mask
-    ic_image =
-    ic_mask =
+    ic_image = np.array(ic_image.convert("RGB"))
+    ic_mask = np.array(ic_mask.convert("RGB"))
 
     sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
     sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
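
This hunk replaces the truncated assignments with an explicit PIL-to-NumPy conversion, matching the type='pil' component changes further down. A minimal sketch of the same conversion outside the app, assuming only Pillow and NumPy (the file name is illustrative):

import numpy as np
from PIL import Image

# Hypothetical input: any image loaded as a PIL object, as Gradio's
# type='pil' components deliver.
pil_image = Image.open("example.png")

# convert("RGB") normalizes palette/RGBA/grayscale modes to 3 channels,
# so np.array() always yields an H x W x 3 uint8 array, the layout
# SAM's predictor.set_image() expects.
rgb_array = np.array(pil_image.convert("RGB"))
print(rgb_array.shape, rgb_array.dtype)  # e.g. (512, 512, 3) uint8
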
@@ -114,7 +114,7 @@ def inference(ic_image, ic_mask, image1, image2):
 
     for test_image in [image1, image2]:
         print("======> Testing Image" )
-        test_image =
+        test_image = np.array(test_image.convert("RGB"))
 
         # Image feature encoding
         predictor.set_image(test_image)
@@ -188,8 +188,8 @@ def inference_scribble(image, image1, image2):
     # in context image and mask
     ic_image = image["image"]
     ic_mask = image["mask"]
-    ic_image =
-    ic_mask =
+    ic_image = np.array(ic_image.convert("RGB"))
+    ic_mask = np.array(ic_mask.convert("RGB"))
 
     sam_type, sam_ckpt = 'vit_h', 'sam_vit_h_4b8939.pth'
     sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda()
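
In the scribble variant the gr.ImageMask input passes a dict, so the function first unpacks image["image"] and image["mask"] and then applies the same conversion. A hedged sketch of that unpacking, assuming the Gradio 3.x payload shape with type='pil' (the helper name is hypothetical):

import numpy as np

def unpack_scribble(payload):
    # Assumption: in Gradio 3.x, an ImageMask component with type='pil'
    # passes {"image": <PIL.Image>, "mask": <PIL.Image>} to the callback.
    ic_image = np.array(payload["image"].convert("RGB"))
    ic_mask = np.array(payload["mask"].convert("RGB"))
    return ic_image, ic_mask
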
@@ -214,7 +214,7 @@ def inference_scribble(image, image1, image2):
 
     for test_image in [image1, image2]:
         print("======> Testing Image" )
-        test_image =
+        test_image = np.array(test_image.convert("RGB"))
 
         # Image feature encoding
         predictor.set_image(test_image)
@@ -286,8 +286,8 @@ def inference_scribble(image, image1, image2):
 
 def inference_finetune(ic_image, ic_mask, image1, image2):
     # in context image and mask
-    ic_image =
-    ic_mask =
+    ic_image = np.array(ic_image.convert("RGB"))
+    ic_mask = np.array(ic_mask.convert("RGB"))
 
     gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0
     gt_mask = gt_mask.float().unsqueeze(0).flatten(1).cuda()
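
Here the conversion also has to run before the mask is binarized: torch.tensor(ic_mask)[:, :, 0] indexes a channel axis, which exists on an H x W x 3 NumPy array but not on a raw PIL image. A small CPU-only sketch of that step with a synthetic mask:

import numpy as np
import torch

# Synthetic stand-in for the converted in-context mask: white square on black.
ic_mask = np.zeros((64, 64, 3), dtype=np.uint8)
ic_mask[16:48, 16:48] = 255

gt_mask = torch.tensor(ic_mask)[:, :, 0] > 0       # bool H x W from one channel
gt_mask = gt_mask.float().unsqueeze(0).flatten(1)  # (1, H*W); .cuda() omitted here
print(gt_mask.shape)  # torch.Size([1, 4096])
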
@@ -377,7 +377,7 @@ def inference_finetune(ic_image, ic_mask, image1, image2):
     output_image = []
 
     for test_image in [image1, image2]:
-        test_image =
+        test_image = np.array(test_image.convert("RGB"))
 
         # Image feature encoding
         predictor.set_image(test_image)
@@ -466,14 +466,14 @@ description = """
 main = gr.Interface(
     fn=inference,
     inputs=[
-        gr.Image(label="in context image",),
-        gr.Image(label="in context mask"),
-        gr.Image(label="test image1"),
-        gr.Image(label="test image2"),
+        gr.Image(label="in context image", type='pil'),
+        gr.Image(label="in context mask", type='pil'),
+        gr.Image(label="test image1", type='pil'),
+        gr.Image(label="test image2", type='pil'),
     ],
     outputs=[
-        gr.Image(label="output image1").style(height=256, width=256),
-        gr.Image(label="output image2").style(height=256, width=256),
+        gr.Image(label="output image1", type='pil').style(height=256, width=256),
+        gr.Image(label="output image2", type='pil').style(height=256, width=256),
     ],
     allow_flagging="never",
     cache_examples=False,
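
The gr.Interface edits are the other half of the fix: type='pil' makes each image component hand the callback a PIL.Image instead of the default NumPy array, which is what the .convert("RGB") calls above rely on. A self-contained sketch of the same wiring, assuming Gradio 3.x (where the .style(...) calls used in this file are still supported); the echo function and labels are illustrative:

import gradio as gr
import numpy as np

def echo(image1, image2):
    # With type='pil', both arguments arrive as PIL.Image objects.
    return np.array(image1.convert("RGB")), np.array(image2.convert("RGB"))

demo = gr.Interface(
    fn=echo,
    inputs=[
        gr.Image(label="test image1", type='pil'),
        gr.Image(label="test image2", type='pil'),
    ],
    outputs=[
        gr.Image(label="output image1"),
        gr.Image(label="output image2"),
    ],
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
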
@@ -490,13 +490,13 @@ main = gr.Interface(
 main_scribble = gr.Interface(
     fn=inference_scribble,
     inputs=[
-        gr.ImageMask(label="[Stroke] Draw on Image"),
-        gr.Image(label="test image1"),
-        gr.Image(label="test image2"),
+        gr.ImageMask(label="[Stroke] Draw on Image", brush_radius=4, type='pil'),
+        gr.Image(label="test image1", type='pil'),
+        gr.Image(label="test image2", type='pil'),
     ],
     outputs=[
-        gr.Image(label="output image1").style(height=256, width=256),
-        gr.Image(label="output image2").style(height=256, width=256),
+        gr.Image(label="output image1", type='pil').style(height=256, width=256),
+        gr.Image(label="output image2", type='pil').style(height=256, width=256),
     ],
     allow_flagging="never",
     cache_examples=False,
@@ -510,17 +510,18 @@ main_scribble = gr.Interface(
 )
 """
 
+
 main_finetune = gr.Interface(
     fn=inference_finetune,
     inputs=[
-        gr.Image(label="in context image",),
-        gr.Image(label="in context mask"),
-        gr.Image(label="test image1"),
-        gr.Image(label="test image2"),
+        gr.Image(label="in context image", type='pil'),
+        gr.Image(label="in context mask", type='pil'),
+        gr.Image(label="test image1", type='pil'),
+        gr.Image(label="test image2", type='pil'),
     ],
     outputs=[
-        gr.Image(label="output image1").style(height=256, width=256),
-        gr.Image(label="output image2").style(height=256, width=256),
+        gr.Image(label="output image1", type='pil').style(height=256, width=256),
+        gr.Image(label="output image2", type='pil').style(height=256, width=256),
     ],
     allow_flagging="never",
     cache_examples=False,