jens commited on
Commit
711582a
·
1 Parent(s): 6046fb8

multi output

Browse files
Files changed (2) hide show
  1. app.py +3 -4
  2. inference.py +4 -1
app.py CHANGED
@@ -80,13 +80,12 @@ with block:
80
  cv2.circle(prompt_image, (x, y), 5, color, -1)
81
  point_coords.append([x,y])
82
  point_labels.append(point_label_radio)
83
- generated_mask, _, _ = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
84
- pred_masks = [(generated_mask, text)]
85
  return [ prompt_image,
86
- (input_image, [(generated_mask, "Mask")]),
87
  point_coords,
88
  point_labels,
89
- pred_masks ]
90
 
91
  prompt_image.select(on_prompt_image_select,
92
  [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks],
 
80
  cv2.circle(prompt_image, (x, y), 5, color, -1)
81
  point_coords.append([x,y])
82
  point_labels.append(point_label_radio)
83
+ sam_masks = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
 
84
  return [ prompt_image,
85
+ (input_image, sam_masks),
86
  point_coords,
87
  point_labels,
88
+ sam_masks ]
89
 
90
  prompt_image.select(on_prompt_image_select,
91
  [input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks],
inference.py CHANGED
@@ -151,7 +151,10 @@ class SegmentPredictor:
151
  point_labels=lbls,
152
  multimask_output=True
153
  )
154
- return masks
 
 
 
155
 
156
 
157
  def segment_everything(self, image):
 
151
  point_labels=lbls,
152
  multimask_output=True
153
  )
154
+ sam_masks = []
155
+ for i,mask in enumerate(masks):
156
+ sam_masks.append((mask["segmentation"], str(i)))
157
+ return sam_masks
158
 
159
 
160
  def segment_everything(self, image):