Spaces:
Runtime error
Runtime error
s194649
commited on
Commit
·
cf01ea3
1
Parent(s):
5f453af
fix
Browse files- inference.py +21 -7
inference.py
CHANGED
@@ -202,18 +202,32 @@ class CustomSamPredictor(SamPredictor):
|
|
202 |
) -> None:
|
203 |
super().__init__(sam_model)
|
204 |
|
205 |
-
def encode_image(
|
|
|
|
|
|
|
|
|
206 |
"""
|
207 |
-
|
|
|
208 |
|
209 |
Arguments:
|
210 |
-
image (np.ndarray): The image for
|
|
|
211 |
image_format (str): The color format of the image, in ['RGB', 'BGR'].
|
212 |
-
|
213 |
-
Returns:
|
214 |
-
torch.Tensor: The image embedding with shape 1xCxHxW.
|
215 |
"""
|
216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
return self.get_image_embedding()
|
218 |
|
219 |
def decode_and_predict(
|
|
|
202 |
) -> None:
|
203 |
super().__init__(sam_model)
|
204 |
|
205 |
+
def encode_image(
|
206 |
+
self,
|
207 |
+
image: np.ndarray,
|
208 |
+
image_format: str = "RGB",
|
209 |
+
) -> None:
|
210 |
"""
|
211 |
+
Calculates the image embeddings for the provided image, allowing
|
212 |
+
masks to be predicted with the 'predict' method.
|
213 |
|
214 |
Arguments:
|
215 |
+
image (np.ndarray): The image for calculating masks. Expects an
|
216 |
+
image in HWC uint8 format, with pixel values in [0, 255].
|
217 |
image_format (str): The color format of the image, in ['RGB', 'BGR'].
|
|
|
|
|
|
|
218 |
"""
|
219 |
+
assert image_format in [
|
220 |
+
"RGB",
|
221 |
+
"BGR",
|
222 |
+
], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
|
223 |
+
if image_format != self.model.image_format:
|
224 |
+
image = image[..., ::-1]
|
225 |
+
|
226 |
+
# Transform the image to the form expected by the model
|
227 |
+
input_image = self.transform.apply_image(image)
|
228 |
+
input_image_torch = torch.as_tensor(input_image, device=self.device)
|
229 |
+
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
|
230 |
+
self.set_torch_image(input_image_torch, image.shape[:2])
|
231 |
return self.get_image_embedding()
|
232 |
|
233 |
def decode_and_predict(
|