Update app.py
Browse files
app.py
CHANGED
@@ -124,15 +124,16 @@ def rmbg(image=None, url=None):
|
|
124 |
|
125 |
def mask_generation(image=None, d=None):
|
126 |
d = eval(d) # convert this to dictionary
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
|
|
136 |
sorted_ind = np.argsort(scores)[::-1]
|
137 |
masks = masks[sorted_ind]
|
138 |
scores = scores[sorted_ind]
|
|
|
124 |
|
125 |
def mask_generation(image=None, d=None):
|
126 |
d = eval(d) # convert this to dictionary
|
127 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
128 |
+
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
|
129 |
+
predictor.set_image(image)
|
130 |
+
input_point = np.array(d["input_points"])
|
131 |
+
input_label = np.array(d["input_labels"])
|
132 |
+
masks, scores, logits = predictor.predict(
|
133 |
+
point_coords=input_point,
|
134 |
+
point_labels=input_label,
|
135 |
+
multimask_output=True,
|
136 |
+
)
|
137 |
sorted_ind = np.argsort(scores)[::-1]
|
138 |
masks = masks[sorted_ind]
|
139 |
scores = scores[sorted_ind]
|