ianpan commited on
Commit
f18c485
·
verified ·
1 Parent(s): 61b5bfd

Upload model

Browse files
Files changed (1) hide show
  1. modeling.py +8 -1
modeling.py CHANGED
@@ -349,6 +349,7 @@ class CTCropModel(PreTrainedModel):
349
  mode: str,
350
  device: str | None = None,
351
  raw_hu: bool = False,
 
352
  add_buffer: float | tuple[float, float] | None = None,
353
  ) -> np.ndarray:
354
  assert mode in ["2d", "3d"]
@@ -375,7 +376,8 @@ class CTCropModel(PreTrainedModel):
375
  coords = self.forward(x, img_shape=img_shapes, add_buffer=add_buffer)
376
  # get the union of all slice-wise bounding boxes
377
  # exclude empty boxes
378
- coords = coords[coords.sum(dim=1) != 0]
 
379
  # if all empty, return original input
380
  if coords.shape[0] == 0:
381
  print("no foreground detected, returning original input ...")
@@ -386,4 +388,9 @@ class CTCropModel(PreTrainedModel):
386
  x1, y1 = x1.min().item(), y1.min().item()
387
  x2, y2 = x2.max().item(), y2.max().item()
388
  cropped = x0[:, y1:y2, x1:x2] if mode == "3d" else x0[y1:y2, x1:x2]
 
 
 
 
 
389
  return cropped
 
349
  mode: str,
350
  device: str | None = None,
351
  raw_hu: bool = False,
352
+ remove_empty_slices: bool = False,
353
  add_buffer: float | tuple[float, float] | None = None,
354
  ) -> np.ndarray:
355
  assert mode in ["2d", "3d"]
 
376
  coords = self.forward(x, img_shape=img_shapes, add_buffer=add_buffer)
377
  # get the union of all slice-wise bounding boxes
378
  # exclude empty boxes
379
+ empty = coords.sum(dim=1) == 0
380
+ coords = coords[~empty]
381
  # if all empty, return original input
382
  if coords.shape[0] == 0:
383
  print("no foreground detected, returning original input ...")
 
388
  x1, y1 = x1.min().item(), y1.min().item()
389
  x2, y2 = x2.max().item(), y2.max().item()
390
  cropped = x0[:, y1:y2, x1:x2] if mode == "3d" else x0[y1:y2, x1:x2]
391
+ if remove_empty_slices and empty.sum() > 0:
392
+ empty_indices = list(torch.where(empty)[0].cpu().numpy())
393
+ print(f"removing {empty.sum()} empty slices ...")
394
+ cropped = cropped[~empty]
395
+ return cropped, empty_indices
396
  return cropped