jens commited on
Commit
43dcd18
·
1 Parent(s): 95eb778

fix tab and sort mask by size

Browse files
Files changed (2) hide show
  1. app.py +2 -1
  2. inference.py +3 -6
app.py CHANGED
@@ -41,8 +41,9 @@ with block:
41
  )
42
  with gr.Row():
43
  with gr.Column():
44
- with gr.Tab():
45
  input_image = gr.Image(label='Input', type='pil', tool=None) # mirror_webcam = False
 
46
  input_image = gr.Image(label='Input', type='pil', tool=None, source="webcam") # mirror_webcam = False
47
  with gr.Row():
48
  sam_encode_btn = gr.Button('Encode', variant='primary')
 
41
  )
42
  with gr.Row():
43
  with gr.Column():
44
+ with gr.Tab("Upload Image"):
45
  input_image = gr.Image(label='Input', type='pil', tool=None) # mirror_webcam = False
46
+ with gr.Tab("Webcam"):
47
  input_image = gr.Image(label='Input', type='pil', tool=None, source="webcam") # mirror_webcam = False
48
  with gr.Row():
49
  sam_encode_btn = gr.Button('Encode', variant='primary')
inference.py CHANGED
@@ -184,18 +184,15 @@ class SegmentPredictor:
184
  def cond_pred(self, pts, lbls):
185
  lbls = np.array(lbls)
186
  pts = np.array(pts)
187
-
188
- print(pts)
189
- print(lbls)
190
-
191
  masks, _, _ = self.conditioned_pred.predict(
192
  point_coords=pts,
193
  point_labels=lbls,
194
  multimask_output=True
195
  )
 
196
  sam_masks = []
197
- for i,mask in enumerate(masks):
198
- sam_masks.append((mask, str(i)))
199
  return sam_masks
200
 
201
 
 
184
  def cond_pred(self, pts, lbls):
185
  lbls = np.array(lbls)
186
  pts = np.array(pts)
 
 
 
 
187
  masks, _, _ = self.conditioned_pred.predict(
188
  point_coords=pts,
189
  point_labels=lbls,
190
  multimask_output=True
191
  )
192
+ idxs = np.argsort(masks.sum(axis=(1,2)))
193
  sam_masks = []
194
+ for n,i in enumerate(idxs):
195
+ sam_masks.append((masks[i], str(n)))
196
  return sam_masks
197
 
198