jens commited on
Commit
0579ca3
·
1 Parent(s): f76bf44

UI update + removed "raw_image"

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -23,7 +23,6 @@ with block:
23
  return []
24
  def point_labels_empty():
25
  return []
26
- raw_image = gr.Image(type='pil', visible=False)
27
  image_edit_trigger = gr.State(True)
28
  point_coords = gr.State(point_coords_empty)
29
  point_labels = gr.State(point_labels_empty)
@@ -38,10 +37,12 @@ with block:
38
  sam_encode_btn = gr.Button('Encode', variant='primary')
39
  sam_encode_status = gr.Label('Not encoded yet')
40
  prompt_image = gr.Image(label='Segments')
 
41
  with gr.Column():
42
  pcl_figure = gr.Model3D(label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
43
- n_samples = gr.Slider(minimum=1e3, maximum=1e6, step=1e3, default=1e3, label='Number of Samples')
44
- cube_size = gr.Slider(minimum=0.000001, maximum=0.001, step=0.000001, default=0.00001, label='Cube size')
 
45
  with gr.Row():
46
  with gr.Column(scale=1):
47
  with gr.Row():
@@ -52,14 +53,14 @@ with block:
52
  sam_decode_btn = gr.Button('Predict using points!', variant = 'primary')
53
  depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant = 'primary')
54
  # components
55
- components = {point_coords, point_labels, raw_image, image_edit_trigger, masks, cutout_idx, input_image,
56
  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
57
  sam_decode_btn, depth_reconstruction_btn, prompt_image, n_samples, cube_size}
58
 
59
  # event - init coords
60
- def on_reset_btn_click(raw_image):
61
- return raw_image, point_coords_empty(), point_labels_empty(), None, []
62
- reset_btn.click(on_reset_btn_click, [raw_image], [input_image, point_coords, point_labels], queue=False)
63
 
64
  def on_input_image_upload(input_image):
65
  print("encoding")
@@ -67,7 +68,7 @@ with block:
67
  sam.encode(input_image)
68
  print("encoding done")
69
  return input_image, point_coords_empty(), point_labels_empty(), None
70
- input_image.upload(on_input_image_upload, [input_image], [raw_image, point_coords, point_labels], queue=False)
71
 
72
  # event - set coords
73
  def on_prompt_image_select(input_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
@@ -100,7 +101,7 @@ with block:
100
 
101
  def on_click_sam_dencode_btn(inputs):
102
  print("inferencing")
103
- image = inputs[raw_image]
104
  generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels]))
105
  inputs[masks].append((generated_mask, inputs[text]))
106
  print(inputs[masks][0])
@@ -109,7 +110,7 @@ with block:
109
 
110
  def on_depth_reconstruction_btn_click(inputs):
111
  print("depth reconstruction")
112
- image = inputs[raw_image]
113
  path = dpt.generate_obj_masks(image=image, n_samples=inputs[n_samples], cube_size=inputs[cube_size], masks=inputs[masks])
114
  return {pcl_figure: path}
115
  depth_reconstruction_btn.click(on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False)
 
23
  return []
24
  def point_labels_empty():
25
  return []
 
26
  image_edit_trigger = gr.State(True)
27
  point_coords = gr.State(point_coords_empty)
28
  point_labels = gr.State(point_labels_empty)
 
37
  sam_encode_btn = gr.Button('Encode', variant='primary')
38
  sam_encode_status = gr.Label('Not encoded yet')
39
  prompt_image = gr.Image(label='Segments')
40
+ with gr.Row():
41
  with gr.Column():
42
  pcl_figure = gr.Model3D(label="3-D Reconstruction", clear_color=[1.0, 1.0, 1.0, 1.0])
43
+ with gr.Row():
44
+ n_samples = gr.Slider(minimum=1e3, maximum=1e6, step=1e3, default=1e3, label='Number of Samples')
45
+ cube_size = gr.Slider(minimum=0.000001, maximum=0.001, step=0.000001, default=0.00001, label='Cube size')
46
  with gr.Row():
47
  with gr.Column(scale=1):
48
  with gr.Row():
 
53
  sam_decode_btn = gr.Button('Predict using points!', variant = 'primary')
54
  depth_reconstruction_btn = gr.Button('Depth Reconstruction', variant = 'primary')
55
  # components
56
+ components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image,
57
  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
58
  sam_decode_btn, depth_reconstruction_btn, prompt_image, n_samples, cube_size}
59
 
60
  # event - init coords
61
+ def on_reset_btn_click(input_image):
62
+ return input_image, point_coords_empty(), point_labels_empty(), None, []
63
+ reset_btn.click(on_reset_btn_click, [input_image], [input_image, point_coords, point_labels], queue=False)
64
 
65
  def on_input_image_upload(input_image):
66
  print("encoding")
 
68
  sam.encode(input_image)
69
  print("encoding done")
70
  return input_image, point_coords_empty(), point_labels_empty(), None
71
+ input_image.upload(on_input_image_upload, [input_image], [input_image, point_coords, point_labels], queue=False)
72
 
73
  # event - set coords
74
  def on_prompt_image_select(input_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
 
101
 
102
  def on_click_sam_dencode_btn(inputs):
103
  print("inferencing")
104
+ image = inputs[input_image]
105
  generated_mask, _, _ = sam.cond_pred(pts=np.array(inputs[point_coords]), lbls=np.array(inputs[point_labels]))
106
  inputs[masks].append((generated_mask, inputs[text]))
107
  print(inputs[masks][0])
 
110
 
111
  def on_depth_reconstruction_btn_click(inputs):
112
  print("depth reconstruction")
113
+ image = inputs[input_image]
114
  path = dpt.generate_obj_masks(image=image, n_samples=inputs[n_samples], cube_size=inputs[cube_size], masks=inputs[masks])
115
  return {pcl_figure: path}
116
  depth_reconstruction_btn.click(on_depth_reconstruction_btn_click, components, [pcl_figure], queue=False)