s194649 committed on
Commit cf01ea3 · 1 Parent(s): 5f453af
Files changed (1)
  1. inference.py +21 -7
inference.py CHANGED
@@ -202,18 +202,32 @@ class CustomSamPredictor(SamPredictor):
     ) -> None:
         super().__init__(sam_model)
 
-    def encode_image(self, image: np.ndarray, image_format: str = "RGB") -> torch.Tensor:
+    def encode_image(
+        self,
+        image: np.ndarray,
+        image_format: str = "RGB",
+    ) -> None:
         """
-        Encodes the image and returns its embedding.
+        Calculates the image embeddings for the provided image, allowing
+        masks to be predicted with the 'predict' method.
 
         Arguments:
-          image (np.ndarray): The image for which to calculate the embedding.
+          image (np.ndarray): The image for calculating masks. Expects an
+            image in HWC uint8 format, with pixel values in [0, 255].
           image_format (str): The color format of the image, in ['RGB', 'BGR'].
-
-        Returns:
-          torch.Tensor: The image embedding with shape 1xCxHxW.
         """
-        self.set_image(image, image_format)
+        assert image_format in [
+            "RGB",
+            "BGR",
+        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+        if image_format != self.model.image_format:
+            image = image[..., ::-1]
+
+        # Transform the image to the form expected by the model
+        input_image = self.transform.apply_image(image)
+        input_image_torch = torch.as_tensor(input_image, device=self.device)
+        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
+        self.set_torch_image(input_image_torch, image.shape[:2])
         return self.get_image_embedding()
 
     def decode_and_predict(
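
For context, a minimal usage sketch of the reworked encode_image, assuming CustomSamPredictor is constructed with a SAM model from segment-anything's sam_model_registry; the checkpoint path, model type, and the import line for this file are placeholders, not taken from this repo:

import numpy as np
import torch
from segment_anything import sam_model_registry

from inference import CustomSamPredictor  # assumed import path for inference.py

# Hypothetical checkpoint path and model type; substitute your own.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to("cuda" if torch.cuda.is_available() else "cpu")

predictor = CustomSamPredictor(sam)

# encode_image now checks the color format, applies the model's transform,
# and sets the image features itself before returning the embedding.
image = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC uint8, values in [0, 255]
embedding = predictor.encode_image(image, image_format="RGB")
print(embedding.shape)  # 1 x C x H x W, per the previous docstring

The new body essentially inlines the parent's set_image steps, so the format assertion, channel flip, and resize are visible in this class rather than hidden behind the removed self.set_image call.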