jens commited on
Commit
4d6f971
·
1 Parent(s): 60edd6a
Files changed (1) hide show
  1. app.py +24 -23
app.py CHANGED
@@ -28,7 +28,7 @@ with block:
28
  point_labels = gr.State(point_labels_empty)
29
  masks = gr.State([])
30
  cutout_idx = gr.State(set())
31
- everything_masks = gr.State([])
32
  prompt_masks = gr.State([])
33
 
34
  # UI
@@ -36,19 +36,20 @@ with block:
36
  with gr.Row():
37
  with gr.Column():
38
  input_image = gr.Image(label='Input', type='pil', tool=None) # mirror_webcam = False
39
- sam_encode_btn = gr.Button('Encode', variant='primary')
 
 
40
  #sam_encode_status = gr.Label('Not encoded yet')
41
  with gr.Row():
42
- with gr.Column():
43
- prompt_image = gr.Image(label='Segments')
44
- with gr.Row():
45
- point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
46
- text = gr.Textbox(label='Mask Name')
47
- reset_btn = gr.Button('New Mask')
48
  prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
49
- everything_image = gr.AnnotatedImage(label='Everything')
 
 
 
50
  selected_masks_image = gr.AnnotatedImage(label='Selected Masks')
51
- sam_sgmt_everything_btn = gr.Button('Segment Everything!', variant = 'primary')
 
52
 
53
  with gr.Row():
54
  with gr.Column():
@@ -72,7 +73,7 @@ with block:
72
  return input_image, point_coords_empty(), point_labels_empty(), None, []
73
  reset_btn.click(on_reset_btn_click, [input_image], [input_image, point_coords, point_labels], queue=False)
74
 
75
- def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, text, prompt_masks, evt: gr.SelectData):
76
  x, y = evt.index
77
  color = red if point_label_radio == 0 else blue
78
  if prompt_image is None:
@@ -82,32 +83,32 @@ with block:
82
  point_coords.append([x,y])
83
  point_labels.append(point_label_radio)
84
  generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
85
- prompt_masks.append((generated_mask, text))
86
  return [ prompt_image,
87
  (input_image, [(generated_mask, "Mask")]),
88
  point_coords,
89
  point_labels,
90
- prompt_masks ]
91
 
92
  prompt_image.select(on_prompt_image_select,
93
- [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, prompt_masks],
94
- [prompt_image, prompt_lbl_image, point_coords, point_labels, prompt_masks], queue=False)
95
 
96
- def on_everything_image_select(input_image, everything_masks, masks, text, evt: gr.SelectData):
97
  i = evt.index
98
- mask = everything_masks[i][0]
99
  print(mask)
100
  print(type(mask))
101
  masks.append((mask, text))
102
  anno = (input_image, masks)
103
  return [masks, anno]
104
 
105
- everything_image.select(on_everything_image_select,
106
- [input_image, everything_masks, masks, text],
107
- [masks, selected_masks_image], queue=False)
108
- prompt_lbl_image.select(on_everything_image_select,
109
- [input_image, prompt_masks, masks, text],
110
  [masks, selected_masks_image], queue=False)
 
 
 
111
 
112
 
113
  def on_click_sam_encode_btn(inputs):
@@ -141,7 +142,7 @@ with block:
141
  print(image)
142
  print(sam_masks)
143
  return [(image, sam_masks), sam_masks]
144
- sam_sgmt_everything_btn.click(on_sam_sgmt_everything_btn_click, components, [everything_image, everything_masks], queue=False)
145
 
146
 
147
  if __name__ == '__main__':
 
28
  point_labels = gr.State(point_labels_empty)
29
  masks = gr.State([])
30
  cutout_idx = gr.State(set())
31
+ pred_masks = gr.State([])
32
  prompt_masks = gr.State([])
33
 
34
  # UI
 
36
  with gr.Row():
37
  with gr.Column():
38
  input_image = gr.Image(label='Input', type='pil', tool=None) # mirror_webcam = False
39
+ with gr.Row():
40
+ sam_encode_btn = gr.Button('Encode', variant='primary')
41
+ sam_sgmt_everything_btn = gr.Button('Segment Everything!', variant = 'primary')
42
  #sam_encode_status = gr.Label('Not encoded yet')
43
  with gr.Row():
44
+ prompt_image = gr.Image(label='Segments')
 
 
 
 
 
45
  prompt_lbl_image = gr.AnnotatedImage(label='Segment Labels')
46
+ lbl_image = gr.AnnotatedImage(label='Everything')
47
+ point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
48
+ text = gr.Textbox(label='Mask Name')
49
+ reset_btn = gr.Button('New Mask')
50
  selected_masks_image = gr.AnnotatedImage(label='Selected Masks')
51
+
52
+
53
 
54
  with gr.Row():
55
  with gr.Column():
 
73
  return input_image, point_coords_empty(), point_labels_empty(), None, []
74
  reset_btn.click(on_reset_btn_click, [input_image], [input_image, point_coords, point_labels], queue=False)
75
 
76
+ def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks, evt: gr.SelectData):
77
  x, y = evt.index
78
  color = red if point_label_radio == 0 else blue
79
  if prompt_image is None:
 
83
  point_coords.append([x,y])
84
  point_labels.append(point_label_radio)
85
  generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
86
+ pred_masks = [(generated_mask, text)]
87
  return [ prompt_image,
88
  (input_image, [(generated_mask, "Mask")]),
89
  point_coords,
90
  point_labels,
91
+ pred_masks ]
92
 
93
  prompt_image.select(on_prompt_image_select,
94
+ [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks],
95
+ [prompt_image, lbl_image, point_coords, point_labels, pred_masks], queue=False)
96
 
97
+ def on_everything_image_select(input_image, pred_masks, masks, text, evt: gr.SelectData):
98
  i = evt.index
99
+ mask = pred_masks[i][0]
100
  print(mask)
101
  print(type(mask))
102
  masks.append((mask, text))
103
  anno = (input_image, masks)
104
  return [masks, anno]
105
 
106
+ lbl_image.select(on_everything_image_select,
107
+ [input_image, pred_masks, masks, text],
 
 
 
108
  [masks, selected_masks_image], queue=False)
109
+ #prompt_lbl_image.select(on_everything_image_select,
110
+ # [input_image, prompt_masks, masks, text],
111
+ # [masks, selected_masks_image], queue=False)
112
 
113
 
114
  def on_click_sam_encode_btn(inputs):
 
142
  print(image)
143
  print(sam_masks)
144
  return [(image, sam_masks), sam_masks]
145
+ sam_sgmt_everything_btn.click(on_sam_sgmt_everything_btn_click, components, [lbl_image, pred_masks], queue=False)
146
 
147
 
148
  if __name__ == '__main__':