saim1309 committed on
Commit a4b3c40 · verified · 1 Parent(s): 8a9238b

Upload 10 files

CellAware.cpython-312.pyc ADDED
Binary file (3.84 kB).
 
CellAware.py ADDED
@@ -0,0 +1,88 @@
+ import numpy as np
+ import copy
+
+ from monai.transforms import RandScaleIntensity, Compose
+ from monai.transforms.compose import MapTransform
+ from skimage.segmentation import find_boundaries
+
+
+ __all__ = ["BoundaryExclusion", "IntensityDiversification"]
+
+
+ class BoundaryExclusion(MapTransform):
+     """Map the cell boundary pixel labels to the background class (0)."""
+
+     def __init__(self, keys=["label"], allow_missing_keys=False):
+         super(BoundaryExclusion, self).__init__(keys, allow_missing_keys)
+
+     def __call__(self, data):
+         # Find the cell boundaries and exclude them (set to background).
+         label_original = data["label"]
+         label = copy.deepcopy(label_original)
+         boundary = find_boundaries(label, connectivity=1, mode="thick")
+         label[boundary] = 0
+
+         new_label = copy.deepcopy(label_original)
+         new_label[label == 0] = 0
+
+         # Do not exclude the boundary if the cell is too small (< 14x14 = 196 pixels).
+         cell_idx, cell_counts = np.unique(label_original, return_counts=True)
+
+         for k in range(len(cell_counts)):
+             if cell_counts[k] < 196:
+                 new_label[label_original == cell_idx[k]] = cell_idx[k]
+
+         # Do not exclude pixels that lie on the image border (outermost 2-pixel frame).
+         _, W, H = label_original.shape
+         bd = np.ones_like(label_original, dtype=label.dtype)
+         bd[:, 2 : W - 2, 2 : H - 2] = 0
+         new_label += label_original * bd
+
+         # Assign the transformed label
+         data["label"] = new_label
+
+         return data
+
+
+ class IntensityDiversification(MapTransform):
+     """Randomly rescale the intensity of a random subset of cells."""
+
+     def __init__(
+         self,
+         keys=["img"],
+         change_cell_ratio=0.4,
+         scale_factors=[0, 0.7],
+         allow_missing_keys=False,
+     ):
+         super(IntensityDiversification, self).__init__(keys, allow_missing_keys)
+
+         self.change_cell_ratio = change_cell_ratio
+         self.randscale_intensity = Compose(
+             [RandScaleIntensity(prob=1.0, factors=scale_factors)]
+         )
+
+     def __call__(self, data):
+         # Select the cells to be transformed; cell IDs run from 1 to cell_count,
+         # so shift the sampled indices by 1.
+         cell_count = int(data["label"].max())
+         change_cell_count = int(cell_count * self.change_cell_ratio)
+         change_cell = (
+             np.random.choice(cell_count, change_cell_count, replace=False) + 1
+         )
+
+         # Build a binary mask covering only the selected cells.
+         mask = copy.deepcopy(data["label"])
+
+         for i in range(cell_count):
+             cell_id = i + 1
+
+             if cell_id not in change_cell:
+                 mask[mask == cell_id] = 0
+
+         mask[mask > 0] = 1
+
+         # Conduct the intensity transformation for the selected cells only.
+         img_original = copy.deepcopy((1 - mask) * data["img"])
+         img_transformed = copy.deepcopy(mask * data["img"])
+         img_transformed = self.randscale_intensity(img_transformed)
+
+         # Assign the transformed image
+         data["img"] = img_original + img_transformed
+
+         return data
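As a quick illustration of how these two transforms behave, here is a minimal sketch on synthetic data; the array shapes, label values, and the direct `CellAware` import path are assumptions for the example, not part of the commit:

    import numpy as np
    from CellAware import BoundaryExclusion, IntensityDiversification

    # Synthetic instance-label map: background 0, two square cells labeled 1 and 2,
    # in the channel-first (1, H, W) layout that BoundaryExclusion indexes.
    label = np.zeros((1, 64, 64), dtype=np.int32)
    label[:, 5:30, 5:30] = 1
    label[:, 35:60, 35:60] = 2

    # Matching single-channel image with constant foreground intensity.
    img = (label > 0).astype(np.float32) * 0.5

    data = {"img": img, "label": label}

    # Zero out thick cell boundaries (small cells and image-border pixels are kept).
    data = BoundaryExclusion(keys=["label"])(data)

    # Randomly rescale the intensity of roughly half of the cells.
    data = IntensityDiversification(keys=["img"], change_cell_ratio=0.5)(data)

    print(data["label"].max(), data["img"].shape)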
LoadImage.cpython-312.pyc ADDED
Binary file (7 kB).
 
LoadImage.py ADDED
@@ -0,0 +1,161 @@
+ import numpy as np
+ import tifffile as tif
+ import skimage.io as io
+
+ from typing import Optional, Sequence, Union
+
+ from monai.config import DtypeLike, PathLike, KeysCollection
+ from monai.utils import ensure_tuple
+ from monai.data.utils import is_supported_format, optional_import, ensure_tuple_rep
+ from monai.data.image_reader import ImageReader, NumpyReader
+ from monai.transforms import LoadImage, LoadImaged
+ from monai.utils.enums import PostFix
+
+ DEFAULT_POST_FIX = PostFix.meta()
+ itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
+
+ __all__ = [
+     "CustomLoadImaged",
+     "CustomLoadImageD",
+     "CustomLoadImageDict",
+     "CustomLoadImage",
+ ]
+
+
+ class CustomLoadImage(LoadImage):
+     """
+     Load an image file or files from the provided path(s) based on a reader.
+     If no reader is specified, this class automatically chooses readers
+     based on the supported suffixes, in the following order:
+
+         - User-specified reader at runtime when calling this loader.
+         - User-specified reader in the constructor of `LoadImage`.
+         - Readers from the last to the first in the registered list.
+         - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader),
+           (npz, npy -> NumpyReader), (nrrd -> NrrdReader), (DICOM file -> ITKReader).
+
+     Caution: this override replaces the registered readers with the custom `UnifiedITKReader`.
+     """
+
+     def __init__(
+         self,
+         reader=None,
+         image_only: bool = False,
+         dtype: DtypeLike = np.float32,
+         ensure_channel_first: bool = False,
+         *args,
+         **kwargs,
+     ) -> None:
+         super(CustomLoadImage, self).__init__(
+             reader, image_only, dtype, ensure_channel_first, *args, **kwargs
+         )
+
+         # Register UnifiedITKReader as the only reader. Although the ITK reader
+         # supports ".tiff" files, it sometimes fails to load them.
+         self.readers = []
+         self.register(UnifiedITKReader(*args, **kwargs))
+
+
+ class CustomLoadImaged(LoadImaged):
+     """
+     Dictionary-based wrapper of `CustomLoadImage`.
+     """
+
+     def __init__(
+         self,
+         keys: KeysCollection,
+         reader: Optional[Union[ImageReader, str]] = None,
+         dtype: DtypeLike = np.float32,
+         meta_keys: Optional[KeysCollection] = None,
+         meta_key_postfix: str = DEFAULT_POST_FIX,
+         overwriting: bool = False,
+         image_only: bool = False,
+         ensure_channel_first: bool = False,
+         simple_keys=False,
+         allow_missing_keys: bool = False,
+         *args,
+         **kwargs,
+     ) -> None:
+         super(CustomLoadImaged, self).__init__(
+             keys,
+             reader,
+             dtype,
+             meta_keys,
+             meta_key_postfix,
+             overwriting,
+             image_only,
+             ensure_channel_first,
+             simple_keys,
+             allow_missing_keys,
+             *args,
+             **kwargs,
+         )
+
+         # Replace the default loader with CustomLoadImage.
+         self._loader = CustomLoadImage(
+             reader, image_only, dtype, ensure_channel_first, *args, **kwargs
+         )
+         if not isinstance(meta_key_postfix, str):
+             raise TypeError(
+                 f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}."
+             )
+         self.meta_keys = (
+             ensure_tuple_rep(None, len(self.keys))
+             if meta_keys is None
+             else ensure_tuple(meta_keys)
+         )
+         if len(self.keys) != len(self.meta_keys):
+             raise ValueError("meta_keys should have the same length as keys.")
+         self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
+         self.overwriting = overwriting
+
+
+ class UnifiedITKReader(NumpyReader):
+     """
+     Unified reader for ".tif" and ".tiff" files (and common 2-D image formats).
+     Since tifffile reads images as NumPy arrays, it inherits from `NumpyReader`.
+     """
+
+     def __init__(
+         self, channel_dim: Optional[int] = None, **kwargs,
+     ):
+         super(UnifiedITKReader, self).__init__(channel_dim=channel_dim, **kwargs)
+         self.kwargs = kwargs
+         self.channel_dim = channel_dim
+
+     def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool:
+         """Verify whether the file format is supported by this reader."""
+         suffixes: Sequence[str] = ["tif", "tiff", "png", "jpg", "bmp", "jpeg"]
+         return has_itk or is_supported_format(filename, suffixes)
+
+     def read(self, data: Union[Sequence[PathLike], PathLike], **kwargs):
+         """Read images from the file(s)."""
+         img_ = []
+
+         filenames: Sequence[PathLike] = ensure_tuple(data)
+         kwargs_ = self.kwargs.copy()
+         kwargs_.update(kwargs)
+
+         for name in filenames:
+             name = f"{name}"
+
+             # Read TIFFs with tifffile; otherwise try ITK, falling back to scikit-image.
+             if name.endswith(".tif") or name.endswith(".tiff"):
+                 _obj = tif.imread(name)
+             else:
+                 try:
+                     _obj = itk.imread(name, **kwargs_)
+                     _obj = itk.array_view_from_image(_obj, keep_axes=False)
+                 except Exception:
+                     _obj = io.imread(name)
+
+             # Force a 3-channel (H, W, 3) layout: repeat grayscale images and
+             # drop extra channels (e.g. alpha).
+             if len(_obj.shape) == 2:
+                 _obj = np.repeat(np.expand_dims(_obj, axis=-1), 3, axis=-1)
+             elif len(_obj.shape) == 3 and _obj.shape[-1] > 3:
+                 _obj = _obj[:, :, :3]
+
+             img_.append(_obj)
+
+         return img_ if len(filenames) > 1 else img_[0]
+
+
+ CustomLoadImageD = CustomLoadImageDict = CustomLoadImaged
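A brief usage sketch of the two loaders (the file paths are hypothetical placeholders, and the shapes assume 2-D TIFF inputs that UnifiedITKReader expands to three channels):

    from LoadImage import CustomLoadImage, CustomLoadImaged

    # Array-style loading, as used in get_pred_transforms().
    loader = CustomLoadImage(image_only=True)
    img = loader("cell_0001.tif")  # (H, W, 3) after the reader's channel handling
    print(img.shape)

    # Dictionary-style loading, as used in the Compose pipelines in transforms.py.
    dict_loader = CustomLoadImaged(keys=["img", "label"], image_only=True)
    sample = dict_loader({"img": "cell_0001.tif", "label": "cell_0001_label.tiff"})
    print(sample["img"].shape, sample["label"].shape)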
NormalizeImage.cpython-312.pyc ADDED
Binary file (3.82 kB).
 
NormalizeImage.py ADDED
@@ -0,0 +1,77 @@
+ import numpy as np
+ from skimage import exposure
+ from monai.config import KeysCollection
+
+ from monai.transforms.transform import Transform
+ from monai.transforms.compose import MapTransform
+
+ from typing import Dict, Hashable, Mapping
+
+
+ __all__ = [
+     "CustomNormalizeImage",
+     "CustomNormalizeImageD",
+     "CustomNormalizeImageDict",
+     "CustomNormalizeImaged",
+ ]
+
+
+ class CustomNormalizeImage(Transform):
+     """Percentile-based intensity normalization to the uint8 range."""
+
+     def __init__(self, percentiles=[0, 99.5], channel_wise=False):
+         self.lower, self.upper = percentiles
+         self.channel_wise = channel_wise
+
+     def _normalize(self, img) -> np.ndarray:
+         # Compute the percentile range over nonzero pixels only, then rescale to uint8.
+         non_zero_vals = img[np.nonzero(img)]
+         percentiles = np.percentile(non_zero_vals, [self.lower, self.upper])
+         img_norm = exposure.rescale_intensity(
+             img, in_range=(percentiles[0], percentiles[1]), out_range="uint8"
+         )
+
+         return img_norm.astype(np.uint8)
+
+     def __call__(self, img: np.ndarray) -> np.ndarray:
+         if self.channel_wise:
+             # Normalize each channel independently, skipping all-zero channels.
+             pre_img_data = np.zeros(img.shape, dtype=np.uint8)
+             for i in range(img.shape[-1]):
+                 img_channel_i = img[:, :, i]
+
+                 if len(img_channel_i[np.nonzero(img_channel_i)]) > 0:
+                     pre_img_data[:, :, i] = self._normalize(img_channel_i)
+
+             img = pre_img_data
+
+         else:
+             img = self._normalize(img)
+
+         return img
+
+
+ class CustomNormalizeImaged(MapTransform):
+     """Dictionary-based wrapper of `CustomNormalizeImage`."""
+
+     def __init__(
+         self,
+         keys: KeysCollection,
+         percentiles=[1, 99],
+         channel_wise: bool = False,
+         allow_missing_keys: bool = False,
+     ):
+         super(CustomNormalizeImaged, self).__init__(keys, allow_missing_keys)
+         self.normalizer = CustomNormalizeImage(percentiles, channel_wise)
+
+     def __call__(
+         self, data: Mapping[Hashable, np.ndarray]
+     ) -> Dict[Hashable, np.ndarray]:
+         d = dict(data)
+
+         for key in self.keys:
+             d[key] = self.normalizer(d[key])
+
+         return d
+
+
+ CustomNormalizeImageD = CustomNormalizeImageDict = CustomNormalizeImaged
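A minimal sketch of the percentile normalization on a synthetic array (the random image, shapes, and percentile choices are illustrative only):

    import numpy as np
    from NormalizeImage import CustomNormalizeImage, CustomNormalizeImaged

    # Heavy-tailed 3-channel image so the 99.5th-percentile clipping is visible.
    img = np.random.exponential(scale=50.0, size=(256, 256, 3)).astype(np.float32)

    # Whole-image normalization: rescale the [0, 99.5] percentile range of
    # nonzero pixels to the uint8 range.
    norm = CustomNormalizeImage(percentiles=[0, 99.5], channel_wise=False)
    img_norm = norm(img)
    print(img_norm.dtype, img_norm.min(), img_norm.max())  # uint8, 0, 255

    # Dictionary version, normalizing each channel independently.
    norm_d = CustomNormalizeImaged(keys=["img"], percentiles=[1, 99], channel_wise=True)
    print(norm_d({"img": img})["img"].dtype)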
__init__.cpython-312.pyc ADDED
Binary file (252 Bytes).
 
__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .LoadImage import *
+ from .NormalizeImage import *
+ from .CellAware import *
modalities.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b50fb364519e5eafb29d7b2861c23a3420652f0ac55f28f0737307da39177c4
+ size 3762
transforms.py ADDED
@@ -0,0 +1,150 @@
+ from .CellAware import BoundaryExclusion, IntensityDiversification
+ from .LoadImage import CustomLoadImaged, CustomLoadImageD, CustomLoadImageDict, CustomLoadImage
+ from .NormalizeImage import CustomNormalizeImage, CustomNormalizeImageD, CustomNormalizeImageDict, CustomNormalizeImaged
+
+ from monai.transforms import *
+
+ __all__ = [
+     "train_transforms",
+     "public_transforms",
+     "valid_transforms",
+     "tuning_transforms",
+     "unlabeled_transforms",
+ ]
+
+ train_transforms = Compose(
+     [
+         # >>> Load and refine data --- img: (H, W, 3); label: (H, W)
+         CustomLoadImaged(keys=["img", "label"], image_only=True),
+         CustomNormalizeImaged(
+             keys=["img"],
+             allow_missing_keys=True,
+             channel_wise=False,
+             percentiles=[0.0, 99.5],
+         ),
+         EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1),
+         RemoveRepeatedChanneld(keys=["label"], repeats=3),  # label: (1, H, W)
+         ScaleIntensityd(keys=["img"], allow_missing_keys=True),  # Do not scale label
+         # >>> Spatial transforms
+         RandZoomd(
+             keys=["img", "label"],
+             prob=0.5,
+             min_zoom=0.25,
+             max_zoom=1.5,
+             mode=["area", "nearest"],
+             keep_size=False,
+         ),
+         SpatialPadd(keys=["img", "label"], spatial_size=512),
+         RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False),
+         RandAxisFlipd(keys=["img", "label"], prob=0.5),
+         RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
+         IntensityDiversification(keys=["img", "label"], allow_missing_keys=True),
+         # >>> Intensity transforms
+         RandGaussianNoised(keys=["img"], prob=0.25, mean=0, std=0.1),
+         RandAdjustContrastd(keys=["img"], prob=0.25, gamma=(1, 2)),
+         RandGaussianSmoothd(keys=["img"], prob=0.25, sigma_x=(1, 2)),
+         RandHistogramShiftd(keys=["img"], prob=0.25, num_control_points=3),
+         RandGaussianSharpend(keys=["img"], prob=0.25),
+         EnsureTyped(keys=["img", "label"]),
+     ]
+ )
+
+
+ public_transforms = Compose(
+     [
+         CustomLoadImaged(keys=["img", "label"], image_only=True),
+         BoundaryExclusion(keys=["label"]),
+         CustomNormalizeImaged(
+             keys=["img"],
+             allow_missing_keys=True,
+             channel_wise=False,
+             percentiles=[0.0, 99.5],
+         ),
+         EnsureChannelFirstd(keys=["img", "label"], channel_dim=-1),
+         RemoveRepeatedChanneld(keys=["label"], repeats=3),  # label: (1, H, W)
+         ScaleIntensityd(keys=["img"], allow_missing_keys=True),  # Do not scale label
+         # >>> Spatial transforms
+         SpatialPadd(keys=["img", "label"], spatial_size=512),
+         RandSpatialCropd(keys=["img", "label"], roi_size=512, random_size=False),
+         RandAxisFlipd(keys=["img", "label"], prob=0.5),
+         RandRotate90d(keys=["img", "label"], prob=0.5, spatial_axes=[0, 1]),
+         Rotate90d(k=1, keys=["label"], spatial_axes=(0, 1)),
+         Flipd(keys=["label"], spatial_axis=0),
+         EnsureTyped(keys=["img", "label"]),
+     ]
+ )
+
+
+ valid_transforms = Compose(
+     [
+         CustomLoadImaged(keys=["img", "label"], allow_missing_keys=True, image_only=True),
+         CustomNormalizeImaged(
+             keys=["img"],
+             allow_missing_keys=True,
+             channel_wise=False,
+             percentiles=[0.0, 99.5],
+         ),
+         EnsureChannelFirstd(keys=["img", "label"], allow_missing_keys=True, channel_dim=-1),
+         RemoveRepeatedChanneld(keys=["label"], repeats=3),
+         ScaleIntensityd(keys=["img"], allow_missing_keys=True),
+         EnsureTyped(keys=["img", "label"], allow_missing_keys=True),
+     ]
+ )
+
+
+ tuning_transforms = Compose(
+     [
+         CustomLoadImaged(keys=["img"], image_only=True),
+         CustomNormalizeImaged(
+             keys=["img"],
+             allow_missing_keys=True,
+             channel_wise=False,
+             percentiles=[0.0, 99.5],
+         ),
+         EnsureChannelFirstd(keys=["img"], channel_dim=-1),
+         ScaleIntensityd(keys=["img"]),
+         EnsureTyped(keys=["img"]),
+     ]
+ )
+
+
+ unlabeled_transforms = Compose(
+     [
+         # >>> Load and refine data --- img: (H, W, 3)
+         CustomLoadImaged(keys=["img"], image_only=True),
+         CustomNormalizeImaged(
+             keys=["img"],
+             allow_missing_keys=True,
+             channel_wise=False,
+             percentiles=[0.0, 99.5],
+         ),
+         EnsureChannelFirstd(keys=["img"], channel_dim=-1),
+         RandZoomd(
+             keys=["img"],
+             prob=0.5,
+             min_zoom=0.25,
+             max_zoom=1.25,
+             mode=["area"],
+             keep_size=False,
+         ),
+         ScaleIntensityd(keys=["img"], allow_missing_keys=True),
+         # >>> Spatial transforms
+         SpatialPadd(keys=["img"], spatial_size=512),
+         RandSpatialCropd(keys=["img"], roi_size=512, random_size=False),
+         EnsureTyped(keys=["img"]),
+     ]
+ )
+
+
+ def get_pred_transforms():
+     """Prediction preprocessing"""
+     pred_transforms = Compose(
+         [
+             # >>> Load and refine data
+             CustomLoadImage(image_only=True),
+             CustomNormalizeImage(channel_wise=False, percentiles=[0.0, 99.5]),
+             EnsureChannelFirst(channel_dim=-1),  # image: (3, H, W)
+             ScaleIntensity(),
+             EnsureType(data_type="tensor"),
+         ]
+     )
+
+     return pred_transforms
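Finally, a sketch of how these pipelines would typically be wired into standard MONAI datasets; the import path, file paths, batch size, and printed shapes are assumptions for illustration:

    from monai.data import DataLoader, Dataset
    from transforms import get_pred_transforms, train_transforms

    # Each training sample is a dict pointing at an image and its instance-label map.
    train_files = [
        {"img": "images/cell_0001.tif", "label": "labels/cell_0001_label.tiff"},
        {"img": "images/cell_0002.tif", "label": "labels/cell_0002_label.tiff"},
    ]

    train_ds = Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2)

    for batch in train_loader:
        # Roughly (B, 3, 512, 512) images and (B, 1, 512, 512) labels after the pipeline.
        print(batch["img"].shape, batch["label"].shape)
        break

    # Prediction preprocessing operates on bare file paths (array transforms, no dict keys).
    pred_transforms = get_pred_transforms()
    pred_input = pred_transforms("images/cell_0003.tif")  # tensor of shape (3, H, W)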