jens commited on
Commit
185ceb1
·
1 Parent(s): 8687af5
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -29,6 +29,7 @@ with block:
29
  masks = gr.State([])
30
  cutout_idx = gr.State(set())
31
  everything_masks = gr.State([])
 
32
 
33
  # UI
34
  with gr.Column():
@@ -68,7 +69,7 @@ with block:
68
  return input_image, point_coords_empty(), point_labels_empty(), None, []
69
  reset_btn.click(on_reset_btn_click, [input_image], [input_image, point_coords, point_labels], queue=False)
70
 
71
- def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, evt: gr.SelectData):
72
  x, y = evt.index
73
  color = red if point_label_radio == 0 else blue
74
  if prompt_image is None:
@@ -78,15 +79,16 @@ with block:
78
  point_coords.append([x,y])
79
  point_labels.append(point_label_radio)
80
  generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
81
-
82
  return [ prompt_image,
83
  (input_image, [(generated_mask, "Mask")]),
84
  point_coords,
85
- point_labels ]
 
86
 
87
  prompt_image.select(on_prompt_image_select,
88
- [input_image, prompt_image, point_coords, point_labels, point_label_radio],
89
- [prompt_image, prompt_lbl_image, point_coords, point_labels], queue=False)
90
 
91
  def on_everything_image_select(input_image, everything_masks, masks, text, evt: gr.SelectData):
92
  i = evt.index
@@ -100,6 +102,9 @@ with block:
100
  everything_image.select(on_everything_image_select,
101
  [input_image, everything_masks, masks, text],
102
  [masks, selected_masks_image], queue=False)
 
 
 
103
 
104
 
105
  def on_click_sam_encode_btn(inputs):
 
29
  masks = gr.State([])
30
  cutout_idx = gr.State(set())
31
  everything_masks = gr.State([])
32
+ prompt_masks = gr.State([])
33
 
34
  # UI
35
  with gr.Column():
 
69
  return input_image, point_coords_empty(), point_labels_empty(), None, []
70
  reset_btn.click(on_reset_btn_click, [input_image], [input_image, point_coords, point_labels], queue=False)
71
 
72
+ def on_prompt_image_select(input_image, prompt_image, point_coords, point_labels, point_label_radio, text, prompt_masks, evt: gr.SelectData):
73
  x, y = evt.index
74
  color = red if point_label_radio == 0 else blue
75
  if prompt_image is None:
 
79
  point_coords.append([x,y])
80
  point_labels.append(point_label_radio)
81
  generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
82
+ prompt_masks.append((generated_mask, text))
83
  return [ prompt_image,
84
  (input_image, [(generated_mask, "Mask")]),
85
  point_coords,
86
+ point_labels,
87
+ prompt_masks ]
88
 
89
  prompt_image.select(on_prompt_image_select,
90
+ [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, prompt_masks],
91
+ [prompt_image, prompt_lbl_image, point_coords, point_labels, prompt_masks], queue=False)
92
 
93
  def on_everything_image_select(input_image, everything_masks, masks, text, evt: gr.SelectData):
94
  i = evt.index
 
102
  everything_image.select(on_everything_image_select,
103
  [input_image, everything_masks, masks, text],
104
  [masks, selected_masks_image], queue=False)
105
+ prompt_lbl_image.select(on_everything_image_select,
106
+ [input_image, prompt_masks, masks, text],
107
+ [masks, selected_masks_image], queue=False)
108
 
109
 
110
  def on_click_sam_encode_btn(inputs):