s194649 commited on
Commit
39f3339
·
1 Parent(s): bcdfff1
Files changed (2) hide show
  1. app.py +1 -0
  2. inference.py +61 -0
app.py CHANGED
@@ -142,6 +142,7 @@ with block:
142
  print("encoding")
143
  # encode image on click
144
  embedding = sam.encode(inputs[input_image]).cpu()
 
145
  print("encoding done")
146
  return [inputs[input_image], embedding]
147
  sam_encode_btn.click(on_click_sam_encode_btn, components, [prompt_image, embedding], queue=False)
 
142
  print("encoding")
143
  # encode image on click
144
  embedding = sam.encode(inputs[input_image]).cpu()
145
+ sam_cpu.dummy_encode(inputs[input_image])
146
  print("encoding done")
147
  return [inputs[input_image], embedding]
148
  sam_encode_btn.click(on_click_sam_encode_btn, components, [prompt_image, embedding], queue=False)
inference.py CHANGED
@@ -263,6 +263,63 @@ class CustomSamPredictor(SamPredictor):
263
  return_logits=return_logits,
264
  )
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
  class SegmentPredictor:
268
  def __init__(self, device=None):
@@ -281,6 +338,10 @@ class SegmentPredictor:
281
  def encode(self, image):
282
  image = np.array(image)
283
  return self.conditioned_pred.encode_image(image)
 
 
 
 
284
 
285
  def cond_pred(self, embedding, pts, lbls):
286
  lbls = np.array(lbls)
 
263
  return_logits=return_logits,
264
  )
265
 
266
+ def dummy_set_torch_image(
267
+ self,
268
+ transformed_image: torch.Tensor,
269
+ original_image_size: Tuple[int, ...],
270
+ ) -> None:
271
+ """
272
+ Calculates the image embeddings for the provided image, allowing
273
+ masks to be predicted with the 'predict' method. Expects the input
274
+ image to be already transformed to the format expected by the model.
275
+
276
+ Arguments:
277
+ transformed_image (torch.Tensor): The input image, with shape
278
+ 1x3xHxW, which has been transformed with ResizeLongestSide.
279
+ original_image_size (tuple(int, int)): The size of the image
280
+ before transformation, in (H, W) format.
281
+ """
282
+ assert (
283
+ len(transformed_image.shape) == 4
284
+ and transformed_image.shape[1] == 3
285
+ and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
286
+ ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
287
+ self.reset_image()
288
+
289
+ self.original_size = original_image_size
290
+ self.input_size = tuple(transformed_image.shape[-2:])
291
+ input_image = self.model.preprocess(transformed_image)
292
+ # The following line is commented out to avoid encoding on cpu
293
+ #self.features = self.model.image_encoder(input_image)
294
+ self.is_image_set = True
295
+
296
+ def dummy_set_image(
297
+ self,
298
+ image: np.ndarray,
299
+ image_format: str = "RGB",
300
+ ) -> None:
301
+ """
302
+ Calculates the image embeddings for the provided image, allowing
303
+ masks to be predicted with the 'predict' method.
304
+
305
+ Arguments:
306
+ image (np.ndarray): The image for calculating masks. Expects an
307
+ image in HWC uint8 format, with pixel values in [0, 255].
308
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
309
+ """
310
+ assert image_format in [
311
+ "RGB",
312
+ "BGR",
313
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
314
+ if image_format != self.model.image_format:
315
+ image = image[..., ::-1]
316
+
317
+ # Transform the image to the form expected by the model
318
+ input_image = self.transform.apply_image(image)
319
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
320
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
321
+
322
+ self.dummy_set_torch_image(input_image_torch, image.shape[:2])
323
 
324
  class SegmentPredictor:
325
  def __init__(self, device=None):
 
338
  def encode(self, image):
339
  image = np.array(image)
340
  return self.conditioned_pred.encode_image(image)
341
+
342
+ def dummy_encode(self, image):
343
+ image = np.array(image)
344
+ self.conditioned_pred.dummy_set_image(image)
345
 
346
  def cond_pred(self, embedding, pts, lbls):
347
  lbls = np.array(lbls)