jens commited on
Commit
fe0db59
·
1 Parent(s): f5cfe4e
Files changed (1) hide show
  1. app.py +7 -10
app.py CHANGED
@@ -37,12 +37,9 @@ with block:
37
  sam_encode_btn = gr.Button('Encode', variant='primary')
38
  sam_encode_status = gr.Label('Not encoded yet')
39
  with gr.Row():
40
- with gr.Tab("Select with points"):
41
- with gr.Column():
42
- prompt_image = gr.Image(label='Segments')
43
- prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
44
 
45
-
46
  with gr.Tab("Select from segmented map"):
47
  everything_image = gr.AnnotatedImage(label='Everything')
48
 
@@ -64,7 +61,7 @@ with block:
64
  # components
65
  components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image,
66
  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
67
- sam_decode_btn, depth_reconstruction_btn, prompt_image, n_samples, cube_size}
68
 
69
  # event - init coords
70
  def on_reset_btn_click(input_image):
@@ -83,15 +80,15 @@ with block:
83
  def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
84
  x, y = evt.index
85
  color = red if point_label_radio == 0 else blue
86
- if prompt is None:
87
- prompt = np.array(input_image.copy())
88
 
89
- cv2.circle(prompt, (x, y), 5, color, -1)
90
  point_coords.append([x,y])
91
  point_labels.append(point_label_radio)
92
  generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
93
 
94
- return [ prompt,
95
  (input_image, [(generated_mask, "Mask")]),
96
  point_coords,
97
  point_labels ]
 
37
  sam_encode_btn = gr.Button('Encode', variant='primary')
38
  sam_encode_status = gr.Label('Not encoded yet')
39
  with gr.Row():
40
+ prompt_image = gr.Image(label='Segments')
41
+ prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
 
 
42
 
 
43
  with gr.Tab("Select from segmented map"):
44
  everything_image = gr.AnnotatedImage(label='Everything')
45
 
 
61
  # components
62
  components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image,
63
  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
64
+ sam_decode_btn, depth_reconstruction_btn, prompt_image, prompt_lbl_image, n_samples, cube_size}
65
 
66
  # event - init coords
67
  def on_reset_btn_click(input_image):
 
80
  def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
81
  x, y = evt.index
82
  color = red if point_label_radio == 0 else blue
83
+ if prompt_image is None:
84
+ prompt_image = np.array(input_image.copy())
85
 
86
+ cv2.circle(prompt_image, (x, y), 5, color, -1)
87
  point_coords.append([x,y])
88
  point_labels.append(point_label_radio)
89
  generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
90
 
91
+ return [ prompt_image,
92
  (input_image, [(generated_mask, "Mask")]),
93
  point_coords,
94
  point_labels ]