Spaces:
Runtime error
Runtime error
normalisation for patch-infer fixed
Browse files
app.py
CHANGED
|
@@ -48,8 +48,18 @@ def infer_full_vol(tensor, model):
|
|
| 48 |
|
| 49 |
def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_width=32, stride_depth=16, batch_size=10, num_worker=2):
|
| 50 |
test_subject = tio.Subject(img = tio.ScalarImage(tensor=tensor.unsqueeze(0))) # adding channel dim while creating the TorchIO subject
|
|
|
|
| 51 |
overlap = np.subtract(patch_size, (stride_length, stride_width, stride_depth))
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
with torch.no_grad():
|
| 54 |
grid_sampler = tio.inference.GridSampler(
|
| 55 |
test_subject,
|
|
@@ -63,7 +73,7 @@ def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_wid
|
|
| 63 |
for i, patches_batch in enumerate(patch_loader):
|
| 64 |
st.text(f"Processing batch {i + 1} of {total_batches}...")
|
| 65 |
|
| 66 |
-
local_batch = patches_batch['img'][tio.DATA].float()
|
| 67 |
local_batch = local_batch / local_batch.max()
|
| 68 |
locations = patches_batch[tio.LOCATION]
|
| 69 |
|
|
|
|
| 48 |
|
| 49 |
def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_width=32, stride_depth=16, batch_size=10, num_worker=2):
|
| 50 |
test_subject = tio.Subject(img = tio.ScalarImage(tensor=tensor.unsqueeze(0))) # adding channel dim while creating the TorchIO subject
|
| 51 |
+
|
| 52 |
overlap = np.subtract(patch_size, (stride_length, stride_width, stride_depth))
|
| 53 |
|
| 54 |
+
def normaliser(batch):
|
| 55 |
+
"""
|
| 56 |
+
Purpose: Normalise pixel intensities of each patch using the max values in the 3D patch
|
| 57 |
+
:param batch: 5D array (batch_size x channel x width x depth x height)
|
| 58 |
+
"""
|
| 59 |
+
for i in range(batch.shape[0]):
|
| 60 |
+
batch[i] = batch[i] / batch[i].max()
|
| 61 |
+
return batch
|
| 62 |
+
|
| 63 |
with torch.no_grad():
|
| 64 |
grid_sampler = tio.inference.GridSampler(
|
| 65 |
test_subject,
|
|
|
|
| 73 |
for i, patches_batch in enumerate(patch_loader):
|
| 74 |
st.text(f"Processing batch {i + 1} of {total_batches}...")
|
| 75 |
|
| 76 |
+
local_batch = normaliser(patches_batch['img'][tio.DATA].float())
|
| 77 |
local_batch = local_batch / local_batch.max()
|
| 78 |
locations = patches_batch[tio.LOCATION]
|
| 79 |
|