jens commited on
Commit
0a54901
·
1 Parent(s): 5f53c0c

Select image from everything annotation

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -48,6 +48,7 @@ with block:
48
  n_samples = gr.Slider(minimum=1e3, maximum=1e6, step=1e3, default=1e3, label='Number of Samples')
49
  cube_size = gr.Slider(minimum=0.000001, maximum=0.001, step=0.000001, default=0.00001, label='Cube size')
50
  with gr.Row():
 
51
  with gr.Column(scale=1):
52
  with gr.Row():
53
  point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
@@ -59,22 +60,13 @@ with block:
59
  # components
60
  components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image,
61
  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
62
- sam_decode_btn, depth_reconstruction_btn, prompt_image, prompt_lbl_image, n_samples, cube_size}
63
 
64
  # event - init coords
65
  def on_reset_btn_click(input_image):
66
  return input_image, point_coords_empty(), point_labels_empty(), None, []
67
  reset_btn.click(on_reset_btn_click, [input_image], [input_image, point_coords, point_labels], queue=False)
68
 
69
- #def on_input_image_upload(input_image):
70
- # print("encoding")
71
- # # encode image on upload
72
- ## sam.encode(input_image)
73
- # print("encoding done")
74
- # return input_image, point_coords_empty(), point_labels_empty(), None
75
- #input_image.upload(on_input_image_upload, [input_image], [input_image, point_coords, point_labels], queue=False)
76
-
77
- # event - set coords
78
  def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
79
  x, y = evt.index
80
  color = red if point_label_radio == 0 else blue
@@ -94,6 +86,18 @@ with block:
94
  prompt_image.select(on_prompt_image_select,
95
  [input_image, prompt_image, point_coords, point_labels, point_label_radio],
96
  [prompt_image, prompt_lbl_image, point_coords, point_labels], queue=False)
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  def on_click_sam_encode_btn(inputs):
99
  print("encoding")
 
48
  n_samples = gr.Slider(minimum=1e3, maximum=1e6, step=1e3, default=1e3, label='Number of Samples')
49
  cube_size = gr.Slider(minimum=0.000001, maximum=0.001, step=0.000001, default=0.00001, label='Cube size')
50
  with gr.Row():
51
+ selected_masks_image = gr.AnnotatedImage(label='Selected Masks')
52
  with gr.Column(scale=1):
53
  with gr.Row():
54
  point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
 
60
  # components
61
  components = {point_coords, point_labels, image_edit_trigger, masks, cutout_idx, input_image,
62
  point_label_radio, text, reset_btn, sam_sgmt_everything_btn,
63
+ sam_decode_btn, depth_reconstruction_btn, prompt_image, prompt_lbl_image, n_samples, cube_size, selected_masks_image}
64
 
65
  # event - init coords
66
  def on_reset_btn_click(input_image):
67
  return input_image, point_coords_empty(), point_labels_empty(), None, []
68
  reset_btn.click(on_reset_btn_click, [input_image], [input_image, point_coords, point_labels], queue=False)
69
 
 
 
 
 
 
 
 
 
 
70
  def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
71
  x, y = evt.index
72
  color = red if point_label_radio == 0 else blue
 
86
  prompt_image.select(on_prompt_image_select,
87
  [input_image, prompt_image, point_coords, point_labels, point_label_radio],
88
  [prompt_image, prompt_lbl_image, point_coords, point_labels], queue=False)
89
+
90
+ def on_everything_image_select(input_image ,everything_image, masks, text, evt: gr.SelectData):
91
+ i = evt.index
92
+ mask = everything_image[1][i]
93
+ masks.append((mask, text))
94
+ anno = (input_image, masks)
95
+ return [ masks, anno]
96
+
97
+ prompt_image.select(on_everything_image_select,
98
+ [input_image, everything_image, masks, text],
99
+ [masks, selected_masks_image], queue=False)
100
+
101
 
102
  def on_click_sam_encode_btn(inputs):
103
  print("encoding")