jens commited on
Commit
d46e73c
·
1 Parent(s): fb5f1fe

save raw_image and edited image

Browse files
Files changed (1) hide show
  1. app.py +15 -4
app.py CHANGED
@@ -24,6 +24,7 @@ with block:
24
  def point_labels_empty():
25
  return []
26
  raw_image = gr.Image(type='pil', visible=False)
 
27
  point_coords = gr.State(point_coords_empty)
28
  point_labels = gr.State(point_labels_empty)
29
  masks = gr.State([])
@@ -51,7 +52,7 @@ with block:
51
  sam_decode_btn = gr.Button('Predict using points!', variant = 'primary')
52
  depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant = 'primary')
53
  # components
54
- components = {point_coords, point_labels, raw_image, masks, cutout_idx, input_image,
55
  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
56
  sam_decode_btn, depth_reconstruction_btn, masks_annotated_image, n_samples, cube_size}
57
 
@@ -69,7 +70,10 @@ with block:
69
  input_image.upload(on_input_image_upload, [input_image], [raw_image, point_coords, point_labels], queue=False)
70
 
71
  # event - set coords
72
- def on_input_image_select(input_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
 
 
 
73
  x, y = evt.index
74
  color = red if point_label_radio == 0 else blue
75
  img = np.array(input_image)
@@ -77,8 +81,14 @@ with block:
77
  img = Image.fromarray(img)
78
  point_coords.append([x,y])
79
  point_labels.append(point_label_radio)
80
- return img, point_coords, point_labels
81
- input_image.select(on_input_image_select, [input_image, point_coords, point_labels, point_label_radio], [input_image, point_coords, point_labels], queue=False)
 
 
 
 
 
 
82
 
83
  def on_click_sam_encode_btn(inputs):
84
  print("encoding")
@@ -93,6 +103,7 @@ with block:
93
  image = inputs[raw_image]
94
  generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels]))
95
  inputs[masks].append((generated_mask, inputs[text]))
 
96
  return {masks_annotated_image: (image, inputs[masks])}
97
  sam_decode_btn.click(on_click_sam_dencode_btn, components, [masks_annotated_image, masks, cutout_idx], queue=True)
98
 
 
24
  def point_labels_empty():
25
  return []
26
  raw_image = gr.Image(type='pil', visible=False)
27
+ image_edit_trigger = gr.State(True)
28
  point_coords = gr.State(point_coords_empty)
29
  point_labels = gr.State(point_labels_empty)
30
  masks = gr.State([])
 
52
  sam_decode_btn = gr.Button('Predict using points!', variant = 'primary')
53
  depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant = 'primary')
54
  # components
55
+ components = {point_coords, point_labels, raw_image, image_edit_trigger, masks, cutout_idx, input_image,
56
  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
57
  sam_decode_btn, depth_reconstruction_btn, masks_annotated_image, n_samples, cube_size}
58
 
 
70
  input_image.upload(on_input_image_upload, [input_image], [raw_image, point_coords, point_labels], queue=False)
71
 
72
  # event - set coords
73
+ def on_input_image_select(input_image, image_edit_trigger, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
74
+ if image_edit_trigger:
75
+ unedited_image = input_image.copy()
76
+ image_edit_trigger = False
77
  x, y = evt.index
78
  color = red if point_label_radio == 0 else blue
79
  img = np.array(input_image)
 
81
  img = Image.fromarray(img)
82
  point_coords.append([x,y])
83
  point_labels.append(point_label_radio)
84
+
85
+ return {raw_image: unedited_image,
86
+ input_image: img,
87
+ point_coords: point_coords,
88
+ point_labels: point_labels,
89
+ image_edit_trigger: image_edit_trigger}
90
+
91
+ input_image.select(on_input_image_select, [input_image, image_edit_trigger, point_coords, point_labels, point_label_radio], [input_image, raw_image, point_coords, point_labels, image_edit_trigger], queue=False)
92
 
93
  def on_click_sam_encode_btn(inputs):
94
  print("encoding")
 
103
  image = inputs[raw_image]
104
  generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels]))
105
  inputs[masks].append((generated_mask, inputs[text]))
106
+ print(inputs[masks][0])
107
  return {masks_annotated_image: (image, inputs[masks])}
108
  sam_decode_btn.click(on_click_sam_dencode_btn, components, [masks_annotated_image, masks, cutout_idx], queue=True)
109