Spaces:
Runtime error
Runtime error
jens
commited on
Commit
·
711582a
1
Parent(s):
6046fb8
multi output
Browse files- app.py +3 -4
- inference.py +4 -1
app.py
CHANGED
@@ -80,13 +80,12 @@ with block:
|
|
80 |
cv2.circle(prompt_image, (x, y), 5, color, -1)
|
81 |
point_coords.append([x,y])
|
82 |
point_labels.append(point_label_radio)
|
83 |
-
|
84 |
-
pred_masks = [(generated_mask, text)]
|
85 |
return [ prompt_image,
|
86 |
-
(input_image,
|
87 |
point_coords,
|
88 |
point_labels,
|
89 |
-
|
90 |
|
91 |
prompt_image.select(on_prompt_image_select,
|
92 |
[input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks],
|
|
|
80 |
cv2.circle(prompt_image, (x, y), 5, color, -1)
|
81 |
point_coords.append([x,y])
|
82 |
point_labels.append(point_label_radio)
|
83 |
+
sam_masks = sam.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels))
|
|
|
84 |
return [ prompt_image,
|
85 |
+
(input_image, sam_masks),
|
86 |
point_coords,
|
87 |
point_labels,
|
88 |
+
sam_masks ]
|
89 |
|
90 |
prompt_image.select(on_prompt_image_select,
|
91 |
[input_image, prompt_image, point_coords, point_labels, point_label_radio, text, pred_masks],
|
inference.py
CHANGED
@@ -151,7 +151,10 @@ class SegmentPredictor:
|
|
151 |
point_labels=lbls,
|
152 |
multimask_output=True
|
153 |
)
|
154 |
-
|
|
|
|
|
|
|
155 |
|
156 |
|
157 |
def segment_everything(self, image):
|
|
|
151 |
point_labels=lbls,
|
152 |
multimask_output=True
|
153 |
)
|
154 |
+
sam_masks = []
|
155 |
+
for i,mask in enumerate(masks):
|
156 |
+
sam_masks.append((mask["segmentation"], str(i)))
|
157 |
+
return sam_masks
|
158 |
|
159 |
|
160 |
def segment_everything(self, image):
|