Spaces:
Runtime error
Runtime error
jens
commited on
Commit
·
d46e73c
1
Parent(s):
fb5f1fe
save raw_image and edited image
Browse files
app.py
CHANGED
@@ -24,6 +24,7 @@ with block:
|
|
24 |
def point_labels_empty():
|
25 |
return []
|
26 |
raw_image = gr.Image(type='pil', visible=False)
|
|
|
27 |
point_coords = gr.State(point_coords_empty)
|
28 |
point_labels = gr.State(point_labels_empty)
|
29 |
masks = gr.State([])
|
@@ -51,7 +52,7 @@ with block:
|
|
51 |
sam_decode_btn = gr.Button('Predict using points!', variant = 'primary')
|
52 |
depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant = 'primary')
|
53 |
# components
|
54 |
-
components = {point_coords, point_labels, raw_image, masks, cutout_idx, input_image,
|
55 |
point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
|
56 |
sam_decode_btn, depth_reconstruction_btn, masks_annotated_image, n_samples, cube_size}
|
57 |
|
@@ -69,7 +70,10 @@ with block:
|
|
69 |
input_image.upload(on_input_image_upload, [input_image], [raw_image, point_coords, point_labels], queue=False)
|
70 |
|
71 |
# event - set coords
|
72 |
-
def on_input_image_select(input_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
|
|
|
|
|
|
|
73 |
x, y = evt.index
|
74 |
color = red if point_label_radio == 0 else blue
|
75 |
img = np.array(input_image)
|
@@ -77,8 +81,14 @@ with block:
|
|
77 |
img = Image.fromarray(img)
|
78 |
point_coords.append([x,y])
|
79 |
point_labels.append(point_label_radio)
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
def on_click_sam_encode_btn(inputs):
|
84 |
print("encoding")
|
@@ -93,6 +103,7 @@ with block:
|
|
93 |
image = inputs[raw_image]
|
94 |
generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels]))
|
95 |
inputs[masks].append((generated_mask, inputs[text]))
|
|
|
96 |
return {masks_annotated_image: (image, inputs[masks])}
|
97 |
sam_decode_btn.click(on_click_sam_dencode_btn, components, [masks_annotated_image, masks, cutout_idx], queue=True)
|
98 |
|
|
|
24 |
def point_labels_empty():
|
25 |
return []
|
26 |
raw_image = gr.Image(type='pil', visible=False)
|
27 |
+
image_edit_trigger = gr.State(True)
|
28 |
point_coords = gr.State(point_coords_empty)
|
29 |
point_labels = gr.State(point_labels_empty)
|
30 |
masks = gr.State([])
|
|
|
52 |
sam_decode_btn = gr.Button('Predict using points!', variant = 'primary')
|
53 |
depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant = 'primary')
|
54 |
# components
|
55 |
+
components = {point_coords, point_labels, raw_image, image_edit_trigger, masks, cutout_idx, input_image,
|
56 |
point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
|
57 |
sam_decode_btn, depth_reconstruction_btn, masks_annotated_image, n_samples, cube_size}
|
58 |
|
|
|
70 |
input_image.upload(on_input_image_upload, [input_image], [raw_image, point_coords, point_labels], queue=False)
|
71 |
|
72 |
# event - set coords
|
73 |
+
def on_input_image_select(input_image, image_edit_trigger, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
|
74 |
+
if image_edit_trigger:
|
75 |
+
unedited_image = input_image.copy()
|
76 |
+
image_edit_trigger = False
|
77 |
x, y = evt.index
|
78 |
color = red if point_label_radio == 0 else blue
|
79 |
img = np.array(input_image)
|
|
|
81 |
img = Image.fromarray(img)
|
82 |
point_coords.append([x,y])
|
83 |
point_labels.append(point_label_radio)
|
84 |
+
|
85 |
+
return {raw_image: unedited_image,
|
86 |
+
input_image: img,
|
87 |
+
point_coords: point_coords,
|
88 |
+
point_labels: point_labels,
|
89 |
+
image_edit_trigger: image_edit_trigger}
|
90 |
+
|
91 |
+
input_image.select(on_input_image_select, [input_image, image_edit_trigger, point_coords, point_labels, point_label_radio], [input_image, raw_image, point_coords, point_labels, image_edit_trigger], queue=False)
|
92 |
|
93 |
def on_click_sam_encode_btn(inputs):
|
94 |
print("encoding")
|
|
|
103 |
image = inputs[raw_image]
|
104 |
generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels]))
|
105 |
inputs[masks].append((generated_mask, inputs[text]))
|
106 |
+
print(inputs[masks][0])
|
107 |
return {masks_annotated_image: (image, inputs[masks])}
|
108 |
sam_decode_btn.click(on_click_sam_dencode_btn, components, [masks_annotated_image, masks, cutout_idx], queue=True)
|
109 |
|