jens commited on
Commit
9c32f6d
·
1 Parent(s): cf01ea3
Files changed (2) hide show
  1. app.py +1 -1
  2. inference.py +2 -1
app.py CHANGED
@@ -99,7 +99,7 @@ with block:
99
  cv2.circle(prompt_image, (x, y), 5, color, -1)
100
  point_coords.append([x,y])
101
  point_labels.append(point_label_radio)
102
- sam_masks = sam_cpu.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels), embedding=embedding)
103
  return [ prompt_image,
104
  (input_image, sam_masks),
105
  point_coords,
 
99
  cv2.circle(prompt_image, (x, y), 5, color, -1)
100
  point_coords.append([x,y])
101
  point_labels.append(point_label_radio)
102
+ sam_masks = sam_cpu.cond_pred(pts=np.array(point_coords), lbls=np.array(point_labels), embedding=embedding, image_size=input_image.size)
103
  return [ prompt_image,
104
  (input_image, sam_masks),
105
  point_coords,
inference.py CHANGED
@@ -232,6 +232,7 @@ class CustomSamPredictor(SamPredictor):
232
 
233
  def decode_and_predict(
234
  self,
 
235
  embedding: torch.Tensor,
236
  point_coords: Optional[np.ndarray] = None,
237
  point_labels: Optional[np.ndarray] = None,
@@ -252,7 +253,7 @@ class CustomSamPredictor(SamPredictor):
252
  (np.ndarray): An array of quality predictions for each mask.
253
  (np.ndarray): Low resolution mask logits for subsequent iterations.
254
  """
255
- self.set_torch_image(embedding, (embedding.shape[-2], embedding.shape[-1]))
256
  return self.predict(
257
  point_coords=point_coords,
258
  point_labels=point_labels,
 
232
 
233
  def decode_and_predict(
234
  self,
235
+ image_size: Tuple[int, int],
236
  embedding: torch.Tensor,
237
  point_coords: Optional[np.ndarray] = None,
238
  point_labels: Optional[np.ndarray] = None,
 
253
  (np.ndarray): An array of quality predictions for each mask.
254
  (np.ndarray): Low resolution mask logits for subsequent iterations.
255
  """
256
+ self.features = embedding
257
  return self.predict(
258
  point_coords=point_coords,
259
  point_labels=point_labels,