Upload model
Browse files- modeling.py +8 -1
modeling.py
CHANGED
@@ -349,6 +349,7 @@ class CTCropModel(PreTrainedModel):
|
|
349 |
mode: str,
|
350 |
device: str | None = None,
|
351 |
raw_hu: bool = False,
|
|
|
352 |
add_buffer: float | tuple[float, float] | None = None,
|
353 |
) -> np.ndarray:
|
354 |
assert mode in ["2d", "3d"]
|
@@ -375,7 +376,8 @@ class CTCropModel(PreTrainedModel):
|
|
375 |
coords = self.forward(x, img_shape=img_shapes, add_buffer=add_buffer)
|
376 |
# get the union of all slice-wise bounding boxes
|
377 |
# exclude empty boxes
|
378 |
-
|
|
|
379 |
# if all empty, return original input
|
380 |
if coords.shape[0] == 0:
|
381 |
print("no foreground detected, returning original input ...")
|
@@ -386,4 +388,9 @@ class CTCropModel(PreTrainedModel):
|
|
386 |
x1, y1 = x1.min().item(), y1.min().item()
|
387 |
x2, y2 = x2.max().item(), y2.max().item()
|
388 |
cropped = x0[:, y1:y2, x1:x2] if mode == "3d" else x0[y1:y2, x1:x2]
|
|
|
|
|
|
|
|
|
|
|
389 |
return cropped
|
|
|
349 |
mode: str,
|
350 |
device: str | None = None,
|
351 |
raw_hu: bool = False,
|
352 |
+
remove_empty_slices: bool = False,
|
353 |
add_buffer: float | tuple[float, float] | None = None,
|
354 |
) -> np.ndarray:
|
355 |
assert mode in ["2d", "3d"]
|
|
|
376 |
coords = self.forward(x, img_shape=img_shapes, add_buffer=add_buffer)
|
377 |
# get the union of all slice-wise bounding boxes
|
378 |
# exclude empty boxes
|
379 |
+
empty = coords.sum(dim=1) == 0
|
380 |
+
coords = coords[~empty]
|
381 |
# if all empty, return original input
|
382 |
if coords.shape[0] == 0:
|
383 |
print("no foreground detected, returning original input ...")
|
|
|
388 |
x1, y1 = x1.min().item(), y1.min().item()
|
389 |
x2, y2 = x2.max().item(), y2.max().item()
|
390 |
cropped = x0[:, y1:y2, x1:x2] if mode == "3d" else x0[y1:y2, x1:x2]
|
391 |
+
if remove_empty_slices and empty.sum() > 0:
|
392 |
+
empty_indices = list(torch.where(empty)[0].cpu().numpy())
|
393 |
+
print(f"removing {empty.sum()} empty slices ...")
|
394 |
+
cropped = cropped[~empty]
|
395 |
+
return cropped, empty_indices
|
396 |
return cropped
|