Spaces:
Runtime error
Runtime error
jens
commited on
Commit
·
185ceb1
1
Parent(s):
8687af5
fix
Browse files
app.py
CHANGED
@@ -29,6 +29,7 @@ with block:
|
|
29 |
masks = gr.State([])
|
30 |
cutout_idx = gr.State(set())
|
31 |
everything_masks = gr.State([])
|
|
|
32 |
|
33 |
# UI
|
34 |
with gr.Column():
|
@@ -68,7 +69,7 @@ with block:
|
|
68 |
return input_image, point_coords_empty(), point_labels_empty(), None, []
|
69 |
reset_btn.click(on_reset_btn_click, [input_image], [input_image, point_coords, point_labels], queue=False)
|
70 |
|
71 |
-
def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
|
72 |
x, y = evt.index
|
73 |
color = red if point_label_radio == 0 else blue
|
74 |
if prompt_image is None:
|
@@ -78,15 +79,16 @@ with block:
|
|
78 |
point_coords.append([x,y])
|
79 |
point_labels.append(point_label_radio)
|
80 |
generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
|
81 |
-
|
82 |
return [ prompt_image,
|
83 |
(input_image, [(generated_mask, "Mask")]),
|
84 |
point_coords,
|
85 |
-
point_labels
|
|
|
86 |
|
87 |
prompt_image.select(on_prompt_image_select,
|
88 |
-
[input_image, prompt_image, point_coords, point_labels, point_label_radio],
|
89 |
-
[prompt_image, prompt_lbl_image, point_coords, point_labels], queue=False)
|
90 |
|
91 |
def on_everything_image_select(input_image, everything_masks, masks, text, evt: gr.SelectData):
|
92 |
i = evt.index
|
@@ -100,6 +102,9 @@ with block:
|
|
100 |
everything_image.select(on_everything_image_select,
|
101 |
[input_image, everything_masks, masks, text],
|
102 |
[masks, selected_masks_image], queue=False)
|
|
|
|
|
|
|
103 |
|
104 |
|
105 |
def on_click_sam_encode_btn(inputs):
|
|
|
29 |
masks = gr.State([])
|
30 |
cutout_idx = gr.State(set())
|
31 |
everything_masks = gr.State([])
|
32 |
+
prompt_masks = gr.State([])
|
33 |
|
34 |
# UI
|
35 |
with gr.Column():
|
|
|
69 |
return input_image, point_coords_empty(), point_labels_empty(), None, []
|
70 |
reset_btn.click(on_reset_btn_click, [input_image], [input_image, point_coords, point_labels], queue=False)
|
71 |
|
72 |
+
def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, text, prompt_masks, evt: gr.SelectData):
|
73 |
x, y = evt.index
|
74 |
color = red if point_label_radio == 0 else blue
|
75 |
if prompt_image is None:
|
|
|
79 |
point_coords.append([x,y])
|
80 |
point_labels.append(point_label_radio)
|
81 |
generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
|
82 |
+
prompt_masks.append((generated_mask, text))
|
83 |
return [ prompt_image,
|
84 |
(input_image, [(generated_mask, "Mask")]),
|
85 |
point_coords,
|
86 |
+
point_labels,
|
87 |
+
prompt_masks ]
|
88 |
|
89 |
prompt_image.select(on_prompt_image_select,
|
90 |
+
[input_image, prompt_image, point_coords, point_labels, point_label_radio, text, prompt_masks],
|
91 |
+
[prompt_image, prompt_lbl_image, point_coords, point_labels, prompt_masks], queue=False)
|
92 |
|
93 |
def on_everything_image_select(input_image, everything_masks, masks, text, evt: gr.SelectData):
|
94 |
i = evt.index
|
|
|
102 |
everything_image.select(on_everything_image_select,
|
103 |
[input_image, everything_masks, masks, text],
|
104 |
[masks, selected_masks_image], queue=False)
|
105 |
+
prompt_lbl_image.select(on_everything_image_select,
|
106 |
+
[input_image, prompt_masks, masks, text],
|
107 |
+
[masks, selected_masks_image], queue=False)
|
108 |
|
109 |
|
110 |
def on_click_sam_encode_btn(inputs):
|