File size: 2,756 Bytes
92f0e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
from typing import List, Tuple
from monai.transforms import Compose, AddChannelD, MaskIntensityD, DeleteItemsD, CropForegroundD, ResizeD

class SelectMaskByLevelD:
    """
    Dictionary transform that reduces a mask to the segment of a single level.

    The mask stored under ``mask_key`` is compared element-wise against the
    value stored under ``level_idx_key``. Elements equal to that value become
    1, all other elements become 0. May also be applied to a single channel.
    """

    def __init__(self, mask_key: str, level_idx_key: str):
        # Keys under which the mask and the level index are looked up per sample.
        self.mask_key, self.level_idx_key = mask_key, level_idx_key

    def __call__(self, data):
        # Work on a shallow copy so the caller's mapping is not modified.
        sample = dict(data)
        labels = sample[self.mask_key]

        # Start from an all-zero mask of the same shape/dtype, then mark the
        # elements that match the requested level.
        selected = np.zeros_like(labels)
        selected[labels == sample[self.level_idx_key]] = 1

        sample[self.mask_key] = selected
        return sample

def get_mask_transform(hparams, loaded_keys: List[str], level_idx_key='level_idx') -> Tuple[Compose, List[str]]:
    """
    Build the mask-handling transform selected by ``hparams.mask``.

    Depending on the configuration value for 'MASK', the returned transform does one of the following:
        - nothing ('none')
        - applies the mask of the critical vertebra to the image ('apply')
        - applies the mask of all visible vertebrae to the image ('apply_all')
        - loads the mask into the 'mask' key s.t. it will later be stacked with the image ('channel')
        - crop the image to the critical vertebra and upsample it ('crop')

    Args:
        hparams: configuration object; ``hparams.mask`` selects the mode. For
            'crop', ``hparams.input_size`` and ``hparams.input_dim`` define the
            spatial size the cropped image is resized to.
        loaded_keys: keys present in the sample. For every mode except 'none'
            this must be exactly ``[image_key, mask_key]``.
        level_idx_key: key under which the level index of the vertebra to
            select is stored in the sample.

    Returns:
        A tuple ``(transform, keys)`` where ``keys`` lists the keys that are
        still present after the transform has been applied.

    Raises:
        AssertionError: if a mode other than 'none' is requested and
            ``loaded_keys`` does not contain exactly two entries.
        ValueError: if ``hparams.mask`` is not one of the supported modes.
    """

    if hparams.mask == 'none':
        return Compose([]), loaded_keys

    assert len(loaded_keys) == 2, \
        f"expected loaded_keys to be [image_key, mask_key], got {loaded_keys!r}"
    image_key, mask_key = loaded_keys

    if hparams.mask == 'apply':
        return Compose([
            # only select relevant vertebra
            SelectMaskByLevelD(mask_key=mask_key, level_idx_key=level_idx_key),
            # apply mask
            MaskIntensityD(keys=image_key, mask_key=mask_key),
            # once the mask is applied, release it
            DeleteItemsD(keys=mask_key),
        ]), [image_key]

    elif hparams.mask == 'apply_all':
        return Compose([
            # keeps all vertebra in the mask
            # apply mask
            MaskIntensityD(keys=image_key, mask_key=mask_key),
            # once the mask is applied, release it
            DeleteItemsD(keys=mask_key),
        ]), [image_key]

    elif hparams.mask == 'channel':
        # the mask stays in the sample, so both keys are reported back
        return Compose([
            SelectMaskByLevelD(mask_key=mask_key, level_idx_key=level_idx_key),
        ]), loaded_keys

    elif hparams.mask == 'crop':
        # TODO CropForegroundD ignores one spatial dimension, thus not truly cropping
        # NOTE(review): presumably because the channel axis is only added by
        # AddChannelD *after* the crop, so the first axis is treated as the
        # channel -- confirm before changing the order.
        return Compose([
            SelectMaskByLevelD(mask_key=mask_key, level_idx_key=level_idx_key),
            CropForegroundD(keys=image_key, source_key=mask_key, margin=2),
            DeleteItemsD(keys=mask_key),
            AddChannelD(keys=image_key),
            ResizeD(keys=image_key, spatial_size=[hparams.input_size] * hparams.input_dim, mode='trilinear'),
        ]), [image_key]

    # Previously an unknown mode fell through and implicitly returned None,
    # which only surfaced later as an unpacking error at the call site.
    raise ValueError(
        f"Unknown mask mode {hparams.mask!r}; expected one of "
        "'none', 'apply', 'apply_all', 'channel', 'crop'"
    )