Update app.py
Browse files
app.py
CHANGED
|
@@ -124,15 +124,16 @@ def rmbg(image=None, url=None):
|
|
| 124 |
|
| 125 |
def mask_generation(image=None, d=None):
|
| 126 |
d = eval(d) # convert this to dictionary
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
| 136 |
sorted_ind = np.argsort(scores)[::-1]
|
| 137 |
masks = masks[sorted_ind]
|
| 138 |
scores = scores[sorted_ind]
|
|
|
|
| 124 |
|
| 125 |
def mask_generation(image=None, d=None):
|
| 126 |
d = eval(d) # convert this to dictionary
|
| 127 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
| 128 |
+
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
|
| 129 |
+
predictor.set_image(image)
|
| 130 |
+
input_point = np.array(d["input_points"])
|
| 131 |
+
input_label = np.array(d["input_labels"])
|
| 132 |
+
masks, scores, logits = predictor.predict(
|
| 133 |
+
point_coords=input_point,
|
| 134 |
+
point_labels=input_label,
|
| 135 |
+
multimask_output=True,
|
| 136 |
+
)
|
| 137 |
sorted_ind = np.argsort(scores)[::-1]
|
| 138 |
masks = masks[sorted_ind]
|
| 139 |
scores = scores[sorted_ind]
|