jens commited on
Commit
1c4f487
·
1 Parent(s): 8c91851

Image instead of anno image

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -37,7 +37,7 @@ with block:
37
  input_image = gr.Image(label='Input', type='pil', tool=None) # mirror_webcam = False
38
  sam_encode_btn = gr.Button('Encode', variant='primary')
39
  sam_encode_status = gr.Label('Not encoded yet')
40
- masks_annotated_image = gr.AnnotatedImage(label='Segments')
41
  with gr.Column():
42
  pcl_figure = gr.Model3D(label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
43
  n_samples = gr.Slider(minimum=1e3, maximum=1e6, step=1e3, default=1e3, label='Number of Samples')
@@ -54,7 +54,7 @@ with block:
54
  # components
55
  components = {point_coords, point_labels, raw_image, image_edit_trigger, masks, cutout_idx, input_image,
56
  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
57
- sam_decode_btn, depth_reconstruction_btn, masks_annotated_image, n_samples, cube_size}
58
 
59
  # event - init coords
60
  def on_reset_btn_click(raw_image):
@@ -86,16 +86,16 @@ with block:
86
  point_coords.append([x,y])
87
  point_labels.append(point_label_radio)
88
  generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
89
- return [(img, [(generated_mask, text)]),
90
  unedited_image,
91
  img,
92
  point_coords,
93
  point_labels,
94
  image_edit_trigger]
95
 
96
- masks_annotated_image.select(on_input_image_select,
97
  [raw_image, input_image, image_edit_trigger, point_coords, point_labels, point_label_radio],
98
- [masks_annotated_image, raw_image, input_image, point_coords, point_labels, image_edit_trigger], queue=False)
99
 
100
  def on_click_sam_encode_btn(inputs):
101
  print("encoding")
@@ -103,8 +103,8 @@ with block:
103
  sam.encode(inputs[input_image])
104
  print("encoding done")
105
  return {sam_encode_status: 'Image Encoded!',
106
- masks_annotated_image: (inputs[input_image], [])}
107
- sam_encode_btn.click(on_click_sam_encode_btn, components, [sam_encode_status, masks_annotated_image], queue=False)
108
 
109
  def on_click_sam_dencode_btn(inputs):
110
  print("inferencing")
@@ -112,8 +112,8 @@ with block:
112
  generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels]))
113
  inputs[masks].append((generated_mask, inputs[text]))
114
  print(inputs[masks][0])
115
- return {masks_annotated_image: (image, inputs[masks])}
116
- sam_decode_btn.click(on_click_sam_dencode_btn, components, [masks_annotated_image, masks, cutout_idx], queue=True)
117
 
118
  def on_depth_reconstruction_btn_click(inputs):
119
  print("depth reconstruction")
 
37
  input_image = gr.Image(label='Input', type='pil', tool=None) # mirror_webcam = False
38
  sam_encode_btn = gr.Button('Encode', variant='primary')
39
  sam_encode_status = gr.Label('Not encoded yet')
40
+ prompt_image = gr.AnnotatedImage(label='Segments')
41
  with gr.Column():
42
  pcl_figure = gr.Model3D(label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
43
  n_samples = gr.Slider(minimum=1e3, maximum=1e6, step=1e3, default=1e3, label='Number of Samples')
 
54
  # components
55
  components = {point_coords, point_labels, raw_image, image_edit_trigger, masks, cutout_idx, input_image,
56
  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
57
+ sam_decode_btn, depth_reconstruction_btn, prompt_image, n_samples, cube_size}
58
 
59
  # event - init coords
60
  def on_reset_btn_click(raw_image):
 
86
  point_coords.append([x,y])
87
  point_labels.append(point_label_radio)
88
  generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
89
+ return [ img,
90
  unedited_image,
91
  img,
92
  point_coords,
93
  point_labels,
94
  image_edit_trigger]
95
 
96
+ prompt_image.select(on_input_image_select,
97
  [raw_image, input_image, image_edit_trigger, point_coords, point_labels, point_label_radio],
98
+ [prompt_image, raw_image, input_image, point_coords, point_labels, image_edit_trigger], queue=False)
99
 
100
  def on_click_sam_encode_btn(inputs):
101
  print("encoding")
 
103
  sam.encode(inputs[input_image])
104
  print("encoding done")
105
  return {sam_encode_status: 'Image Encoded!',
106
+ prompt_image: inputs[input_image]}
107
+ sam_encode_btn.click(on_click_sam_encode_btn, components, [sam_encode_status, prompt_image], queue=False)
108
 
109
  def on_click_sam_dencode_btn(inputs):
110
  print("inferencing")
 
112
  generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels]))
113
  inputs[masks].append((generated_mask, inputs[text]))
114
  print(inputs[masks][0])
115
+ return {prompt_image: (image, inputs[masks])}
116
+ sam_decode_btn.click(on_click_sam_dencode_btn, components, [prompt_image, masks, cutout_idx], queue=True)
117
 
118
  def on_depth_reconstruction_btn_click(inputs):
119
  print("depth reconstruction")