not-lain commited on
Commit
a5944b1
·
verified ·
1 Parent(s): d272a54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -124,15 +124,16 @@ def rmbg(image=None, url=None):
124
 
125
  def mask_generation(image=None, d=None):
126
  d = eval(d) # convert this to dictionary
127
- predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
128
- predictor.set_image(image)
129
- input_point = np.array(d["input_points"])
130
- input_label = np.array(d["input_labels"])
131
- masks, scores, logits = predictor.predict(
132
- point_coords=input_point,
133
- point_labels=input_label,
134
- multimask_output=True,
135
- )
 
136
  sorted_ind = np.argsort(scores)[::-1]
137
  masks = masks[sorted_ind]
138
  scores = scores[sorted_ind]
 
124
 
125
  def mask_generation(image=None, d=None):
126
  d = eval(d) # convert this to dictionary
127
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
128
+ predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-large")
129
+ predictor.set_image(image)
130
+ input_point = np.array(d["input_points"])
131
+ input_label = np.array(d["input_labels"])
132
+ masks, scores, logits = predictor.predict(
133
+ point_coords=input_point,
134
+ point_labels=input_label,
135
+ multimask_output=True,
136
+ )
137
  sorted_ind = np.argsort(scores)[::-1]
138
  masks = masks[sorted_ind]
139
  scores = scores[sorted_ind]