willis committed
Commit 290ca27 · 0 parents

Initial commit

.gitattributes ADDED
@@ -0,0 +1 @@
1
+ *.ipynb filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 aiaudit.org
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,129 @@
1
+ [![MIT License](https://img.shields.io/apm/l/atomic-design-ui.svg?)](https://github.com/tterb/atomic-design-ui/blob/master/LICENSEs) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5235536.svg)](https://doi.org/10.5281/zenodo.5235536)
2
+
3
+ # From Lens to Logit - Addressing Camera Hardware-Drift Using Raw Sensor Data
4
+
5
+ *This repository hosts the code for the project ["From Lens to Logit: Addressing Camera Hardware-Drift Using Raw Sensor Data"](https://openreview.net/forum?id=DRAywM1BhU), submitted to the NeurIPS 2021 Datasets and Benchmarks Track.*
6
+
7
+ In order to address camera hardware-drift we require two ingredients: raw sensor data and an image processing model. This code repository contains the materials for the second ingredient, the image processing model, as well as scripts to load data and run experiments. For a conceptual overview of the project we recommend the [project site](https://aiaudit.org/lens2logit/) or the [full paper](https://openreview.net/forum?id=DRAywM1BhU).
8
+
9
+ ## A short introduction
10
+ ![L2L Overview](https://user-images.githubusercontent.com/38631399/131536063-585cf9b0-e76e-4e41-a05e-2fcf4902f539.png)
11
+
12
+
13
+ To create an image, raw sensor data traverses complex image signal processing pipelines. These pipelines are used by cameras and scientific instruments to produce the images fed into machine learning systems. The processing pipelines vary by device, influencing the resulting image statistics and ultimately contributing to what is known as hardware-drift. However, this processing is rarely considered in machine learning modelling, because available benchmark data sets are generally not in raw format. Here we show that pairing qualified raw sensor data with an explicit, differentiable model of the image processing pipeline makes it possible to tackle camera hardware-drift.
14
+
15
+ Specifically, we demonstrate
16
+ 1. the **controlled synthesis of hardware-drift test cases**
17
+ 2. modular **hardware-drift forensics**, as well as
18
+ 3. **image processing customization**.
19
+
20
+ We make available two data sets.
21
+ 1. **Raw-Microscopy**, contains
22
+ * **940 raw bright-field microscopy images** of human blood smear slides for leukocyte classification alongside
23
+ * **5,640 variations measured at six different intensities** and twelve additional sets totalling
24
+ * **11,280 images of the raw sensor data processed through different pipelines**.
25
+ 2. **Raw-Drone**, contains
26
+ * **548 raw drone camera images** for car segmentation, alongside
27
+ * **3,288 variations measured at six different intensities** and also twelve additional sets totalling
28
+ * **6,576 images of the raw sensor data processed through different pipelines**.
29
+ ## Data access
30
+ If you use our code you can rely on the convenient cloud storage integration: data will be loaded automatically from a cloud storage bucket and stored on your working machine. You can find the code snippet doing that [here](https://github.com/aiaudit-org/lens2logit/blob/f8a165a0c094456f68086167f0bef14c3b311a4e/utils/base.py#L130).
31
+
32
+ ```python
33
+ def get_b2_bucket():
34
+ bucket_name = 'perturbed-minds'
35
+ application_key_id = '003d6b042de536a0000000008'
36
+ application_key = 'K003HMNxnoa91Dy9c0V8JVCKNUnwR9U'
37
+ info = InMemoryAccountInfo()
38
+ b2_api = B2Api(info)
39
+ b2_api.authorize_account('production', application_key_id, application_key)
40
+ bucket = b2_api.get_bucket_by_name(bucket_name)
41
+ return bucket
42
+ ```
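+ For example, the dataset loaders in `dataset.py` fetch whole remote folders through the `b2_download_folder` helper in `utils/base.py`, which builds on `get_b2_bucket` above. A minimal sketch of that call, mirroring `download_drone_dataset` in `dataset.py`:
+ ```python
+ from utils.base import b2_download_folder
+
+ # Mirrors download_drone_dataset in dataset.py: fetch the raw drone images
+ # and masks from the bucket into the local data directory.
+ b2_download_folder('drone/images', 'data/drone/images_full', force_download=False)
+ b2_download_folder('drone/masks', 'data/drone/masks_full', force_download=False)
+ ```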
43
+ We also maintain a copy of the entire dataset with a permanent identifier at Zenodo, which you can find here [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5235536.svg)](https://doi.org/10.5281/zenodo.5235536).
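+ Once the dependencies below are installed, the data can also be fetched and wrapped in a ready-to-use dataset object through `dataset.py`; a minimal sketch (the dataset name and the `I_ratio` argument follow `get_dataset` in `dataset.py`):
+ ```python
+ from dataset import get_dataset
+
+ # Triggers the automatic download on first use and returns a torch Dataset
+ # of raw microscopy images and leukocyte class labels.
+ dataset = get_dataset('Microscopy', I_ratio=1.0)
+ img, label = dataset[0]
+ ```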
44
+ ## Code
45
+ ### Dependencies
46
+ #### Conda environment and dependencies
47
+ To run this code out-of-the-box, you can install the latest project conda environment stored in `environment.yml`:
48
+ ```console
49
+ $ conda env create -f environment.yml
50
+ ```
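+ The environment defined in `environment.yml` is named `perturbed`, so after creation you can activate it with
+ ```console
+ $ conda activate perturbed
+ ```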
51
+ #### segmentation_models_pytorch newest version
52
+ We noticed that the PyPI package for `segmentation_models_pytorch` is sometimes behind the project's GitHub repository. If you encounter `smp`-related problems, we recommend installing directly from the `smp` repository via
53
+ ```console
54
+ $ python -m pip install git+https://github.com/qubvel/segmentation_models.pytorch
55
+ ```
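+ A quick sanity check that the installed package imports correctly (assuming recent `smp` releases expose `__version__`):
+ ```console
+ $ python -c "import segmentation_models_pytorch as smp; print(smp.__version__)"
+ ```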
56
+ ### Recreate experiments
57
+ The central file for using the **Lens2Logit** framework for experiments as in the paper is `train.py`, which provides a rich set of arguments to experiment with raw image data, different image processing models, and task models for regression or classification. Below we provide three example commands for the types of experiments reported in the [paper](https://openreview.net/forum?id=DRAywM1BhU).
58
+ #### Controlled synthesis of hardware-drift test cases
59
+ ```console
60
+ $ python train.py \
61
+ --experiment_name YOUR-EXPERIMENT-NAME \
62
+ --run_name YOUR-RUN-NAME \
63
+ --dataset Microscopy \
64
+ --lr 1e-5 \
65
+ --n_splits 5 \
66
+ --epochs 5 \
67
+ --classifier_pretrained \
68
+ --processing_mode static \
69
+ --augmentation weak \
70
+ --log_model True \
71
+ --iso 0.01 \
72
+ --freeze_processor \
73
+ --processor_uri "$processor_uri" \
74
+ --track_processing \
75
+ --track_every_epoch \
76
+ --track_predictions \
77
+ --track_processing_gradients \
78
+ --track_save_tensors
79
+ ```
80
+ #### Modular hardware-drift forensics
81
+ ```console
82
+ $ python train.py \
83
+ --experiment_name YOUR-EXPERIMENT-NAME \
84
+ --run_name YOUR-RUN-NAME \
85
+ --dataset Microscopy \
86
+ --adv_training \
87
+ --lr 1e-5 \
88
+ --n_splits 5 \
89
+ --epochs 5 \
90
+ --classifier_pretrained \
91
+ --processing_mode parametrized \
92
+ --augmentation weak \
93
+ --log_model True \
94
+ --iso 0.01 \
95
+ --track_processing \
96
+ --track_every_epoch \
97
+ --track_predictions \
98
+ --track_processing_gradients \
99
+ --track_save_tensors
100
+ ```
101
+ #### Image processing customization
102
+ ```console
103
+ $ python train.py \
104
+ --experiment_name YOUR-EXPERIMENT-NAME \
105
+ --run_name YOUR-RUN-NAME \
106
+ --dataset Microscopy \
107
+ --lr 1e-5 \
108
+ --n_splits 5 \
109
+ --epochs 5 \
110
+ --classifier_pretrained \
111
+ --processing_mode parametrized \
112
+ --augmentation weak \
113
+ --log_model True \
114
+ --iso 0.01 \
115
+ --track_processing \
116
+ --track_every_epoch \
117
+ --track_predictions \
118
+ --track_processing_gradients \
119
+ --track_save_tensors
120
+ ```
121
+ ## Virtual lab log
122
+ We maintain a collaborative virtual lab log at [this address](http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com/#/). There you can browse experiment runs, analyze results through SQL queries and download trained processing and task models.
123
+ ![mlflow](https://user-images.githubusercontent.com/38631399/131536233-f6b6e0ae-35f2-4ee0-a5e2-d04f8efb8d73.png)
124
+
125
+
126
+ ### Review our experiments
127
+ Experiments are listed in the left column. You can select individual runs or compare metrics and parameters across different runs. For runs where we tracked images of intermediate processing steps, and of the gradients at these steps, you can find them at the bottom of the run page in the *results* folder for each epoch.
128
+ ### Use our trained models
129
+ When you select a run for which a model was saved, you can find the model files, the state dictionary, and loading instructions at the bottom of the run page under *models*. In the menu bar at the top of the virtual lab log you can also access models via the *Model Registry*. Our code is well integrated with *mlflow*'s autologging and model loading for PyTorch, so when using our code you can simply specify the *model uri* as an argument and models will be fetched from the model registry automatically.
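+ A minimal sketch of loading a model programmatically, assuming the tracking URI of the virtual lab log above and a placeholder model URI (substitute a name from the *Model Registry*, or a `runs:/<run_id>/model` URI):
+ ```python
+ import mlflow
+ import mlflow.pytorch
+
+ # Point mlflow at the virtual lab log (tracking server).
+ mlflow.set_tracking_uri("http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com")
+
+ # Placeholder model URI: replace with an entry from the Model Registry.
+ model = mlflow.pytorch.load_model("models:/YOUR-MODEL-NAME/1")
+ model.eval()
+ ```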
dataset.py ADDED
@@ -0,0 +1,573 @@
1
+ import os
2
+ import shutil
3
+ import rawpy
4
+ import random
5
+ from PIL import Image
6
+ import tifffile as tiff
7
+ import zipfile
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
13
+ from sklearn.model_selection import StratifiedShuffleSplit
14
+
15
+ if not os.path.exists('README.md'): # set pwd to root
16
+ os.chdir('..')
17
+
18
+ from utils.dataset_utils import split_img, list_images_in_dir, load_image
19
+ from utils.base import np2torch, torch2np, b2_download_folder
20
+
21
+ IMAGE_FILE_TYPES = ['dng', 'png', 'tif', 'tiff']
22
+
23
+
24
+ def get_dataset(name, I_ratio=1.0):
25
+ # DroneDataset
26
+ if name in ('DC', 'Drone', 'DroneClassification', 'DroneDatasetClassificationTiled'):
27
+ return DroneDatasetClassificationTiled(I_ratio=I_ratio)
28
+ if name in ('DS', 'DroneSegmentation', 'DroneDatasetSegmentationTiled'):
29
+ return DroneDatasetSegmentationTiled(I_ratio=I_ratio)
30
+
31
+ # MicroscopyDataset
32
+ if name in ('M', 'Microscopy', 'MicroscopyDataset'):
33
+ return MicroscopyDataset(I_ratio=I_ratio)
34
+
35
+ # for testing
36
+ if name in ('DSF', 'DroneDatasetSegmentationFull'):
37
+ return DroneDatasetSegmentationFull(I_ratio=I_ratio)
38
+ if name in ('MRGB', 'MicroscopyRGB', 'MicroscopyDatasetRGB'):
39
+ return MicroscopyDatasetRGB(I_ratio=I_ratio)
40
+
41
+ raise ValueError(name)
42
+
43
+
44
+ class ImageFolderDataset(Dataset):
45
+ """Creates a classification dataset of the images in img_dir with corresponding class labels.
46
+ Labels are expected to be aligned with the images in img_dir (same length, same order).
47
+ Files are expected to be of the same filetype.
48
+
49
+ Args:
50
+ img_dir (str): path to image folder
51
+ labels (list): class labels, one per image in img_dir
52
+ transform (callable, optional): transformation to apply to the image
53
+ bits (int, optional): normalize image by dividing by 2^bits - 1
54
+ """
55
+
56
+ task = 'classification'
57
+
58
+ def __init__(self, img_dir, labels, transform=None, bits=1):
59
+
60
+ self.img_dir = img_dir
61
+ self.labels = labels
62
+
63
+ self.images = list_images_in_dir(img_dir)
64
+
65
+ assert len(self.images) == len(self.labels)
66
+
67
+ self.transform = transform
68
+ self.bits = bits
69
+
70
+ def __repr__(self):
71
+ rep = f"{type(self).__name__}: ImageFolderDataset[{len(self.images)}]"
72
+ for n, (img, label) in enumerate(zip(self.images, self.labels)):
73
+ rep += f'\nimage: {img}\tlabel: {label}'
74
+ if n > 10:
75
+ rep += '\n...'
76
+ break
77
+ return rep
78
+
79
+ def __len__(self):
80
+ return len(self.images)
81
+
82
+ def __getitem__(self, idx):
83
+
84
+ label = self.labels[idx]
85
+
86
+ img = load_image(self.images[idx])
87
+ img = img / (2**self.bits - 1)
88
+ if self.transform is not None:
89
+ img = self.transform(img)
90
+
91
+ if len(img.shape) == 2:
92
+ assert img.shape == (256, 256), f"Invalid size for {self.images[idx]}"
93
+ else:
94
+ assert img.shape == (3, 256, 256), f"Invalid size for {self.images[idx]}"
95
+
96
+ return img, label
97
+
98
+
99
+ class ImageFolderDatasetSegmentation(Dataset):
100
+ """Creates a dataset of images in `img_dir` and corresponding masks in `mask_dir`.
101
+ Corresponding mask files need to contain the filename of the image.
102
+ Files are expected to be of the same filetype.
103
+
104
+ Args:
105
+ img_dir (str): path to image folder
106
+ mask_dir (str): path to mask folder
107
+ transform (callable, optional): transformation to apply to image and mask
108
+ bits (int, optional): normalize image by dividing by 2^bits - 1
109
+ """
110
+
111
+ task = 'segmentation'
112
+
113
+ def __init__(self, img_dir, mask_dir, transform=None, bits=1):
114
+
115
+ self.img_dir = img_dir
116
+ self.mask_dir = mask_dir
117
+
118
+ self.images = list_images_in_dir(img_dir)
119
+ self.masks = list_images_in_dir(mask_dir)
120
+
121
+ check_image_folder_consistency(self.images, self.masks)
122
+
123
+ self.transform = transform
124
+ self.bits = bits
125
+
126
+ def __repr__(self):
127
+ rep = f"{type(self).__name__}: ImageFolderDatasetSegmentation[{len(self.images)}]"
128
+ for n, (img, mask) in enumerate(zip(self.images, self.masks)):
129
+ rep += f'\nimage: {img}\tmask: {mask}'
130
+ if n > 10:
131
+ rep += '\n...'
132
+ break
133
+ return rep
134
+
135
+ def __len__(self):
136
+ return len(self.images)
137
+
138
+ def __getitem__(self, idx):
139
+
140
+ img = load_image(self.images[idx])
141
+ mask = load_image(self.masks[idx])
142
+
143
+ img = img / (2**self.bits - 1)
144
+ mask = (mask > 0).astype(np.float32)
145
+
146
+ if self.transform is not None:
147
+ img = self.transform(img)
148
+
149
+ return img, mask
150
+
151
+
152
+ class MultiIntensity(Dataset):
153
+ """Wrap datasets with different intensities
154
+
155
+ Args:
156
+ datasets (list): list of datasets to wrap
157
+ """
158
+
159
+ def __init__(self, datasets):
160
+ self.dataset = datasets[0]
+ self.transform = None  # checked in __getitem__; no extra transform by default
161
+
162
+ for d in range(1, len(datasets)):
163
+ self.dataset.images = self.dataset.images + datasets[d].images
164
+ self.dataset.labels = self.dataset.labels + datasets[d].labels
165
+
166
+ def __len__(self):
167
+ return len(self.dataset)
168
+
169
+ def __repr__(self):
170
+ return f"Subset [{len(self.dataset)}] of " + repr(self.dataset)
171
+
172
+ def __getitem__(self, idx):
173
+ x, y = self.dataset[idx]
174
+ if self.transform is not None:
175
+ x = self.transform(x)
176
+ return x, y
177
+
178
+
179
+ class Subset(Dataset):
180
+ """Define a subset of a dataset by only selecting given indices.
181
+
182
+ Args:
183
+ dataset (Dataset): full dataset
184
+ indices (list): subset indices
185
+ """
186
+
187
+ def __init__(self, dataset, indices=None, transform=None):
188
+ self.dataset = dataset
189
+ self.indices = indices if indices is not None else range(len(dataset))
190
+ self.transform = transform
191
+
192
+ def __len__(self):
193
+ return len(self.indices)
194
+
195
+ def __repr__(self):
196
+ return f"Subset [{len(self)}] of " + repr(self.dataset)
197
+
198
+ def __getitem__(self, idx):
199
+ x, y = self.dataset[self.indices[idx]]
200
+ if self.transform is not None:
201
+ x = self.transform(x)
202
+ return x, y
203
+
204
+
205
+ class DroneDatasetSegmentationFull(ImageFolderDatasetSegmentation):
206
+ """Dataset consisting of full-sized numpy images and masks. Images are normalized to range [0, 1].
207
+ """
208
+
209
+ black_level = [0.0625, 0.0626, 0.0625, 0.0626]
210
+ white_balance = [2.86653646, 1., 1.73079425]
211
+ colour_matrix = [1.50768983, -0.33571374, -0.17197604, -0.23048614,
212
+ 1.70698738, -0.47650126, -0.03119153, -0.32803956, 1.35923111]
213
+ camera_parameters = black_level, white_balance, colour_matrix
214
+
215
+ def __init__(self, I_ratio=1.0, transform=None, force_download=False, bits=16):
216
+
217
+ assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
218
+
219
+ img_dir = f'data/drone/images_full/raw_scale{int(I_ratio*100):03d}'
220
+ mask_dir = 'data/drone/masks_full'
221
+
222
+ download_drone_dataset(force_download) # XXX: zip files and add checksum? date?
223
+
224
+ super().__init__(img_dir=img_dir, mask_dir=mask_dir, transform=transform, bits=bits)
225
+
226
+
227
+ class DroneDatasetSegmentationTiled(ImageFolderDatasetSegmentation):
228
+ """Dataset consisting of tiled numpy images and masks. Images are in range [0, 1]
229
+ Args:
230
+ tile_size (int, optional): size of the tiled images. Defaults to 256.
231
+ """
232
+
233
+ camera_parameters = DroneDatasetSegmentationFull.camera_parameters
234
+
235
+ def __init__(self, I_ratio=1.0, transform=None):
236
+
237
+ tile_size = 256
238
+
239
+ img_dir = f'data/drone/images_tiles_{tile_size}/raw_scale{int(I_ratio*100):03d}'
240
+ mask_dir = f'data/drone/masks_tiles_{tile_size}'
241
+
242
+ if not os.path.exists(img_dir) or not os.path.exists(mask_dir):
243
+ dataset_full = DroneDatasetSegmentationFull(I_ratio=I_ratio, bits=1)
244
+ print("tiling dataset..")
245
+ create_tiles_dataset(dataset_full, img_dir, mask_dir, tile_size=tile_size)
246
+
247
+ super().__init__(img_dir=img_dir, mask_dir=mask_dir, transform=transform, bits=16)
248
+
249
+
250
+ class DroneDatasetClassificationTiled(ImageFolderDataset):
251
+
252
+ camera_parameters = DroneDatasetSegmentationFull.camera_parameters
253
+
254
+ def __init__(self, I_ratio=1.0, transform=None):
255
+
256
+ random_state = 72
257
+ tile_size = 256
258
+ thr = 0.01
259
+
260
+ img_dir = f'data/drone/classification/images_tiles_{tile_size}/raw_scale{int(I_ratio*100):03d}_thr_{thr}'
261
+ mask_dir = f'data/drone/classification/masks_tiles_{tile_size}_thr_{thr}'
262
+ df_path = f'data/drone/classification/dataset_tiles_{tile_size}_{random_state}_{thr}.csv'
263
+
264
+ if not os.path.exists(img_dir) or not os.path.exists(mask_dir):
265
+ dataset_full = DroneDatasetSegmentationFull(I_ratio=I_ratio, bits=1)
266
+ print("tiling dataset..")
267
+ create_tiles_dataset_binary(dataset_full, img_dir, mask_dir, random_state, thr, tile_size=tile_size)
268
+
269
+ self.classes = ['car', 'no car']
270
+ self.df = pd.read_csv(df_path)
271
+ labels = self.df['label'].to_list()
272
+
273
+ super().__init__(img_dir=img_dir, labels=labels, transform=transform, bits=16)
274
+
275
+ images, class_labels = read_label_csv(self.df)
276
+ self.images = [os.path.join(self.img_dir, image) for image in images]
277
+ self.labels = class_labels
278
+
279
+
280
+ class MicroscopyDataset(ImageFolderDataset):
281
+ """MicroscopyDataset raw images
282
+
283
+ Args:
284
+ I_ratio (float): Original image rescaled by this factor, possible values [0.01,0.05,0.1,0.25,0.5,0.75,1.0]
285
+ raw (bool): Select rgb dataset or raw dataset
286
+ transform (callable, optional): transformation to apply to image and mask
287
+ bits (int, optional): normalize image by dividing by 2^bits - 1
288
+ """
289
+
290
+ black_level = [9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06]
291
+ white_balance = [-0.6567, 1.9673, 3.5304]
292
+ colour_matrix = [-2.0338, 0.0933, 0.4157, -0.0286, 2.6464, -0.0574, -0.5516, -0.0947, 2.9308]
293
+
294
+ camera_parameters = black_level, white_balance, colour_matrix
295
+
296
+ dataset_mean = [0.91, 0.84, 0.94]
297
+ dataset_std = [0.08, 0.12, 0.05]
298
+
299
+ def __init__(self, I_ratio=1.0, transform=None, bits=16, force_download=False):
300
+
301
+ assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
302
+
303
+ download_microscopy_dataset(force_download=force_download)
304
+
305
+ self.img_dir = f'data/microscopy/images/raw_scale{int(I_ratio*100):03d}'
306
+ self.transform = transform
307
+ self.bits = bits
308
+
309
+ self.label_file = 'data/microscopy/labels/Ma190c_annotations.dat'
310
+
311
+ self.valid_classes = ['BAS', 'EBO', 'EOS', 'KSC', 'LYA', 'LYT', 'MMZ', 'MOB',
312
+ 'MON', 'MYB', 'MYO', 'NGB', 'NGS', 'PMB', 'PMO', 'UNC']
313
+
314
+ self.invalid_files = ['Ma190c_lame3_zone13_composite_Mcropped_2.tiff', ]
315
+
316
+ images, class_labels = read_label_file(self.label_file)
317
+
318
+ # filter classes with low appearance
319
+ self.valid_classes = [class_label for class_label in self.valid_classes
320
+ if class_labels.count(class_label) > 4]
321
+
322
+ # remove invalid classes and invalid files from (images, class_labels)
323
+ images, class_labels = list(zip(*[
324
+ (image, class_label)
325
+ for image, class_label in zip(images, class_labels)
326
+ if class_label in self.valid_classes and image not in self.invalid_files
327
+ ]))
328
+
329
+ self.classes = list(sorted({*class_labels}))
330
+
331
+ # store full path
332
+ self.images = [os.path.join(self.img_dir, image) for image in images]
333
+
334
+ # reindex labels
335
+ self.labels = [self.classes.index(class_label) for class_label in class_labels]
336
+
337
+
338
+ class MicroscopyDatasetRGB(MicroscopyDataset):
339
+ """MicroscopyDataset RGB images
340
+
341
+ Args:
342
+ I_ratio (float): Original image rescaled by this factor, possible values [0.01,0.05,0.1,0.25,0.5,0.75,1.0]
343
+ raw (bool): Select rgb dataset or raw dataset
344
+ transform (callable, optional): transformation to apply to image and mask
345
+ bits (int, optional): normalize image by dividing by 2^bits - 1
346
+ """
347
+ camera_parameters = None
348
+
349
+ dataset_mean = None
350
+ dataset_std = None
351
+
352
+ def __init__(self, I_ratio=1.0, transform=None, bits=16, force_download=False):
353
+ super().__init__(I_ratio=I_ratio, transform=transform, bits=bits, force_download=force_download)
354
+ self.images = [image.replace('raw', 'rgb') for image in self.images] # XXX: hack
355
+
356
+
357
+ def read_label_file(label_file_path):
358
+
359
+ images = []
360
+ class_labels = []
361
+
362
+ with open(label_file_path, "rb") as data:
363
+ for line in data:
364
+ file_name, class_label = line.decode("utf-8").split()
365
+ image = file_name + '.tiff'
366
+ images.append(image)
367
+ class_labels.append(class_label)
368
+
369
+ return images, class_labels
370
+
371
+
372
+ def read_label_csv(df):
373
+
374
+ images = []
375
+ class_labels = []
376
+
377
+ for file_name, label in zip(df['file name'], df['label']):
378
+ image = file_name + '.tif'
379
+ images.append(image)
380
+ class_labels.append(int(label))
381
+ return images, class_labels
382
+
383
+
384
+ def download_drone_dataset(force_download):
385
+ b2_download_folder('drone/images', 'data/drone/images_full', force_download=force_download)
386
+ b2_download_folder('drone/masks', 'data/drone/masks_full', force_download=force_download)
387
+ unzip_drone_images()
388
+
389
+
390
+ def download_microscopy_dataset(force_download):
391
+ b2_download_folder('Data histopathology/WhiteCellsImages',
392
+ 'data/microscopy/images', force_download=force_download)
393
+ b2_download_folder('Data histopathology/WhiteCellsLabels',
394
+ 'data/microscopy/labels', force_download=force_download)
395
+ unzip_microscopy_images()
396
+
397
+
398
+ def unzip_microscopy_images():
399
+
400
+ if os.path.isfile('data/microscopy/labels/.bzEmpty'):
401
+ os.remove('data/microscopy/labels/.bzEmpty')
402
+
403
+ for file in os.listdir('data/microscopy/images'):
404
+ if file.endswith(".zip"):
405
+ zip = zipfile.ZipFile(os.path.join('data/microscopy/images', file))
406
+ zip.extractall('data/microscopy/images')
407
+ os.remove(os.path.join('data/microscopy/images', file))
408
+
409
+
410
+ def unzip_drone_images():
411
+
412
+ if os.path.isfile('data/drone/masks_full/.bzEmpty'):
413
+ os.remove('data/drone/masks_full/.bzEmpty')
414
+
415
+ for file in os.listdir('data/drone/images_full'):
416
+ if file.endswith(".zip"):
417
+ zip = zipfile.ZipFile(os.path.join('data/drone/images_full', file))
418
+ zip.extractall('data/drone/images_full')
419
+ os.remove(os.path.join('data/drone/images_full', file))
420
+
421
+
422
+ def create_tiles_dataset(dataset, img_dir, mask_dir, tile_size=256):
423
+ for folder in [img_dir, mask_dir]:
424
+ if not os.path.exists(folder):
425
+ os.makedirs(folder)
426
+ for n, (img, mask) in enumerate(dataset):
427
+ tiled_img = split_img(img, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
428
+ tiled_mask = split_img(mask, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
429
+ tiled_img, tiled_mask = class_detection(tiled_img, tiled_mask) # Remove images without cars in it
430
+ for i, (sub_img, sub_mask) in enumerate(zip(tiled_img, tiled_mask)):
431
+ tile_id = f"{n:02d}_{i:05d}"
432
+ Image.fromarray(sub_img).save(os.path.join(img_dir, tile_id + '.tif'))
433
+ Image.fromarray(sub_mask > 0).save(os.path.join(mask_dir, tile_id + '.png'))
434
+
435
+
436
+ def create_tiles_dataset_binary(dataset, img_dir, mask_dir, random_state, thr, tile_size=256):
437
+
438
+ for folder in [img_dir, mask_dir]:
439
+ if not os.path.exists(folder):
440
+ os.makedirs(folder)
441
+
442
+ ids = []
443
+ labels = []
444
+
445
+ for n, (img, mask) in enumerate(dataset):
446
+ tiled_img = split_img(img, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
447
+ tiled_mask = split_img(mask, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
448
+
449
+ X_with, X_without, Y_with, Y_without = binary_class_detection(
450
+ tiled_img, tiled_mask, random_state, thr) # creates balanced arrays with class and without class
451
+
452
+ for i, (sub_X_with, sub_Y_with) in enumerate(zip(X_with, Y_with)):
453
+ tile_id = f"{n:02d}_{i:05d}"
454
+ ids.append(tile_id)
455
+ labels.append(0)
456
+ Image.fromarray(sub_X_with).save(os.path.join(img_dir, tile_id + '.tif'))
457
+ Image.fromarray(sub_Y_with > 0).save(os.path.join(mask_dir, tile_id + '.png'))
458
+ for j, (sub_X_without, sub_Y_without) in enumerate(zip(X_without, Y_without)):
459
+ tile_id = f"{n:02d}_{i+1+j:05d}"
460
+ ids.append(tile_id)
461
+ labels.append(1)
462
+ Image.fromarray(sub_X_without).save(os.path.join(img_dir, tile_id + '.tif'))
463
+ Image.fromarray(sub_Y_without > 0).save(os.path.join(mask_dir, tile_id + '.png'))
464
+ # Image.fromarray(sub_mask).save(os.path.join(mask_dir, tile_id + '.png'))
465
+
466
+ df = pd.DataFrame({'file name': ids, 'label': labels})
467
+
468
+ df_loc = f'data/drone/classification/dataset_tiles_{tile_size}_{random_state}_{thr}.csv'
469
+ df.to_csv(df_loc)
470
+
471
+ return
472
+
473
+
474
+ def class_detection(X, Y):
475
+ """Keep only the images whose target segmentation map contains the class.
476
+
477
+ Args:
478
+ X (ndarray): input image
479
+ Y (ndarray): target with segmentation map (images with {0,1} values where it is 1 when there is the class)
480
+ Returns:
481
+ X_with_class (ndarray): input regions with the selected class
482
+ Y_with_class (ndarray): target regions with the selected class
483
485
+ """
486
+
487
+ with_class = []
488
+ without_class = []
489
+ for i, img in enumerate(Y):
490
+ if img.mean() == 0:
491
+ without_class.append(i)
492
+ else:
493
+ with_class.append(i)
494
+
495
+ X_with_class = np.delete(X, without_class, 0)
496
+ Y_with_class = np.delete(Y, without_class, 0)
497
+
498
+ return X_with_class, Y_with_class
499
+
500
+
501
+ def binary_class_detection(X, Y, random_seed, thr):
502
+ """Splits the subimages into those with and those without the selected class by thresholding the mean of each submask; subimages with 0 < submask.mean() <= thr are disregarded.
503
+
504
+
505
+
506
+ Args:
507
+ X (ndarray): input image
508
+ Y (ndarray): target with segmentation map (images with {0,1} values where it is 1 when there is the class)
509
+ thr (float): sub images are not considered if 0 < sub_target.mean() <= thr
510
+ random_seed (None or int): seed used to subsample the larger class so both classes are balanced
512
+ Returns:
513
+ X_with_class (ndarray): input regions with the selected class
514
+ Y_with_class (ndarray): target regions with the selected class
515
+ X_without_class (ndarray): input regions without the selected class
516
+ Y_without_class (ndarray): target regions without the selected class
517
+ """
518
+
519
+ with_class = []
520
+ without_class = []
521
+ no_class = []
522
+
523
+ for i, img in enumerate(Y):
524
+ m = img.mean()
525
+ if m == 0:
526
+ without_class.append(i)
527
+ else:
528
+ if m > thr:
529
+ with_class.append(i)
530
+ else:
531
+ no_class.append(i)
532
+
533
+ N = len(with_class)
534
+ M = len(without_class)
535
+ random.seed(random_seed)
536
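+ # Indices appended below end up in both index lists, so they are excluded from both
+ # X_with_class and X_without_class further down, which balances the two classes.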
+ if N <= M:
537
+ random.shuffle(without_class)
538
+ with_class.extend(without_class[:M - N])
539
+ else:
540
+ random.shuffle(with_class)
541
+ without_class.extend(with_class[:N - M])
542
+
543
+ X_with_class = np.delete(X, without_class + no_class, 0)
544
+ X_without_class = np.delete(X, with_class + no_class, 0)
545
+ Y_with_class = np.delete(Y, without_class + no_class, 0)
546
+ Y_without_class = np.delete(Y, with_class + no_class, 0)
547
+
548
+ return X_with_class, X_without_class, Y_with_class, Y_without_class
549
+
550
+
551
+ def make_dataloader(dataset, batch_size, shuffle=True):
552
+
553
+ X, Y = dataset
554
+
555
+ X, Y = np2torch(X), np2torch(Y)
556
+
557
+ dataset = TensorDataset(X, Y)
558
+ dataset = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
559
+
560
+ return dataset
561
+
562
+
563
+ def check_image_folder_consistency(images, masks):
564
+ file_type_images = images[0].split('.')[-1].lower()
565
+ file_type_masks = masks[0].split('.')[-1].lower()
566
+ assert len(images) == len(masks), "images / masks length mismatch"
567
+ for img_file, mask_file in zip(images, masks):
568
+ img_name = img_file.split('/')[-1].split('.')[0]
569
+ assert img_name in mask_file, f"image {img_file} corresponds to {mask_file}?"
570
+ assert img_file.split('.')[-1].lower() == file_type_images, \
571
+ f"image file {img_file} file type mismatch. Should be: {file_type_images}"
572
+ assert mask_file.split('.')[-1].lower() == file_type_masks, \
573
+ f"image file {mask_file} file type mismatch. Should be: {file_type_masks}"
demo-files/car.png ADDED
demo-files/micro.png ADDED
demo.py ADDED
@@ -0,0 +1,54 @@
1
+ import gradio as gr
2
+ #import tensorflow as tf
3
+ import numpy as np
4
+ import json
5
+ from os.path import dirname, realpath, join
6
+ import processing.pipeline_numpy as ppn
7
+
8
+
9
+ # Resolve the directory of this demo script.
10
+ current_dir = dirname(realpath(__file__))
11
+
12
+
13
+ def process(RawImage, CameraParameters, Debayer, Sharpening, Denoising):
14
+ raw_img = RawImage
15
+ if CameraParameters == "Microscope":
16
+ black_level = [9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06]
17
+ white_balance = [-0.6567, 1.9673, 3.5304]
18
+ colour_matrix = [-2.0338, 0.0933, 0.4157, -0.0286, 2.6464, -0.0574, -0.5516, -0.0947, 2.9308]
19
+ elif CameraParameters == "Drone":
20
+ #drone
21
+ black_level = [0.0625, 0.0626, 0.0625, 0.0626]
22
+ white_balance = [2.86653646, 1., 1.73079425]
23
+ colour_matrix = [1.50768983, -0.33571374, -0.17197604, -0.23048614,
24
+ 1.70698738, -0.47650126, -0.03119153, -0.32803956, 1.35923111]
25
+ else:
26
+ print("No valid camera parameter")
27
+ debayer = Debayer
28
+ sharpening = Sharpening
29
+ denoising = Denoising
30
+ print(np.max(raw_img))
31
+ raw_img = (raw_img[:,:,0].astype(np.float64)/255.)
32
+ img = ppn.processing(raw_img, black_level, white_balance, colour_matrix,
33
+ debayer=debayer, sharpening=sharpening, denoising=denoising)
34
+ print(np.max(img))
35
+ return img
36
+
37
+
38
+ iface = gr.Interface(
39
+ process,
40
+ [gr.inputs.Image(),gr.inputs.Radio(["Microscope", "Drone"]),gr.inputs.Dropdown(["bilinear", "malvar2004", "menon2007"]),
41
+ gr.inputs.Dropdown(["sharpening_filter", "unsharp_masking"]),
42
+ gr.inputs.Dropdown(["gaussian_denoising", "median_denoising"])],
43
+ "image",
44
+ capture_session=True,
45
+ examples=[
46
+ ["demo-files/car.png"],
47
+ ["demo-files/micro.png"]
48
+ ],
49
+ title="Lens2Logit - Static processing demo",
50
+ description="You can select a sample raw image, the camera parameters and the pipeline configuration to process the raw image.",)
51
+
52
+ if __name__ == "__main__":
53
+ iface.launch(share=True)
54
+
environment.yml ADDED
@@ -0,0 +1,363 @@
1
+ name: perturbed
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _ipyw_jlab_nb_ext_conf=0.1.0=py37_0
6
+ - _libgcc_mutex=0.1=main
7
+ - alabaster=0.7.12=py37_0
8
+ - anaconda=2019.10=py37_0
9
+ - anaconda-client=1.7.2=py37_0
10
+ - anaconda-navigator=1.9.7=py37_0
11
+ - anaconda-project=0.8.3=py_0
12
+ - asn1crypto=1.0.1=py37_0
13
+ - astroid=2.3.1=py37_0
14
+ - astropy=3.2.2=py37h7b6447c_0
15
+ - atomicwrites=1.3.0=py37_1
16
+ - attrs=19.2.0=py_0
17
+ - babel=2.7.0=py_0
18
+ - backcall=0.1.0=py37_0
19
+ - backports=1.0=py_2
20
+ - backports.functools_lru_cache=1.5=py_2
21
+ - backports.os=0.1.1=py37_0
22
+ - backports.shutil_get_terminal_size=1.0.0=py37_2
23
+ - backports.tempfile=1.0=py_1
24
+ - backports.weakref=1.0.post1=py_1
25
+ - beautifulsoup4=4.8.0=py37_0
26
+ - bitarray=1.0.1=py37h7b6447c_0
27
+ - bkcharts=0.2=py37_0
28
+ - blas=1.0=mkl
29
+ - bleach=3.1.0=py37_0
30
+ - blosc=1.16.3=hd408876_0
31
+ - bokeh=1.3.4=py37_0
32
+ - boto=2.49.0=py37_0
33
+ - bottleneck=1.2.1=py37h035aef0_1
34
+ - bzip2=1.0.8=h7b6447c_0
35
+ - ca-certificates=2019.8.28=0
36
+ - cairo=1.14.12=h8948797_3
37
+ - certifi=2019.9.11=py37_0
38
+ - cffi=1.12.3=py37h2e261b9_0
39
+ - chardet=3.0.4=py37_1003
40
+ - click=7.0=py37_0
41
+ - cloudpickle=1.2.2=py_0
42
+ - clyent=1.2.2=py37_1
43
+ - colorama=0.4.1=py37_0
44
+ - conda-package-handling=1.6.0=py37h7b6447c_0
45
+ - conda-verify=3.4.2=py_1
46
+ - contextlib2=0.6.0=py_0
47
+ - cryptography=2.7=py37h1ba5d50_0
48
+ - curl=7.65.3=hbc83047_0
49
+ - cycler=0.10.0=py37_0
50
+ - cython=0.29.13=py37he6710b0_0
51
+ - cytoolz=0.10.0=py37h7b6447c_0
52
+ - dask=2.5.2=py_0
53
+ - dask-core=2.5.2=py_0
54
+ - dbus=1.13.6=h746ee38_0
55
+ - decorator=4.4.0=py37_1
56
+ - defusedxml=0.6.0=py_0
57
+ - distributed=2.5.2=py_0
58
+ - docutils=0.15.2=py37_0
59
+ - entrypoints=0.3=py37_0
60
+ - et_xmlfile=1.0.1=py37_0
61
+ - expat=2.2.6=he6710b0_0
62
+ - fastcache=1.1.0=py37h7b6447c_0
63
+ - filelock=3.0.12=py_0
64
+ - flask=1.1.1=py_0
65
+ - fontconfig=2.13.0=h9420a91_0
66
+ - freetype=2.9.1=h8a8886c_1
67
+ - fribidi=1.0.5=h7b6447c_0
68
+ - future=0.17.1=py37_0
69
+ - get_terminal_size=1.0.0=haa9412d_0
70
+ - gevent=1.4.0=py37h7b6447c_0
71
+ - glib=2.56.2=hd408876_0
72
+ - glob2=0.7=py_0
73
+ - gmp=6.1.2=h6c8ec71_1
74
+ - gmpy2=2.0.8=py37h10f8cd9_2
75
+ - graphite2=1.3.13=h23475e2_0
76
+ - greenlet=0.4.15=py37h7b6447c_0
77
+ - gst-plugins-base=1.14.0=hbbd80ab_1
78
+ - gstreamer=1.14.0=hb453b48_1
79
+ - h5py=2.9.0=py37h7918eee_0
80
+ - harfbuzz=1.8.8=hffaf4a1_0
81
+ - hdf5=1.10.4=hb1b8bf9_0
82
+ - heapdict=1.0.1=py_0
83
+ - html5lib=1.0.1=py37_0
84
+ - icu=58.2=h9c2bf20_1
85
+ - idna=2.8=py37_0
86
+ - imageio=2.6.0=py37_0
87
+ - imagesize=1.1.0=py37_0
88
+ - intel-openmp=2019.4=243
89
+ - ipykernel=5.1.2=py37h39e3cac_0
90
+ - ipython=7.8.0=py37h39e3cac_0
91
+ - ipython_genutils=0.2.0=py37_0
92
+ - ipywidgets=7.5.1=py_0
93
+ - isort=4.3.21=py37_0
94
+ - itsdangerous=1.1.0=py37_0
95
+ - jbig=2.1=hdba287a_0
96
+ - jdcal=1.4.1=py_0
97
+ - jedi=0.15.1=py37_0
98
+ - jeepney=0.4.1=py_0
99
+ - jinja2=2.10.3=py_0
100
+ - joblib=0.13.2=py37_0
101
+ - jpeg=9b=h024ee3a_2
102
+ - json5=0.8.5=py_0
103
+ - jsonschema=3.0.2=py37_0
104
+ - jupyter=1.0.0=py37_7
105
+ - jupyter_client=5.3.3=py37_1
106
+ - jupyter_console=6.0.0=py37_0
107
+ - jupyter_core=4.5.0=py_0
108
+ - jupyterlab=1.1.4=pyhf63ae98_0
109
+ - jupyterlab_server=1.0.6=py_0
110
+ - keyring=18.0.0=py37_0
111
+ - kiwisolver=1.1.0=py37he6710b0_0
112
+ - krb5=1.16.1=h173b8e3_7
113
+ - lazy-object-proxy=1.4.2=py37h7b6447c_0
114
+ - libarchive=3.3.3=h5d8350f_5
115
+ - libcurl=7.65.3=h20c2e04_0
116
+ - libedit=3.1.20181209=hc058e9b_0
117
+ - libffi=3.2.1=hd88cf55_4
118
+ - libgcc-ng=9.1.0=hdf63c60_0
119
+ - libgfortran-ng=7.3.0=hdf63c60_0
120
+ - liblief=0.9.0=h7725739_2
121
+ - libpng=1.6.37=hbc83047_0
122
+ - libsodium=1.0.16=h1bed415_0
123
+ - libssh2=1.8.2=h1ba5d50_0
124
+ - libstdcxx-ng=9.1.0=hdf63c60_0
125
+ - libtiff=4.0.10=h2733197_2
126
+ - libtool=2.4.6=h7b6447c_5
127
+ - libuuid=1.0.3=h1bed415_2
128
+ - libxcb=1.13=h1bed415_1
129
+ - libxml2=2.9.9=hea5a465_1
130
+ - libxslt=1.1.33=h7d1a2b0_0
131
+ - llvmlite=0.29.0=py37hd408876_0
132
+ - locket=0.2.0=py37_1
133
+ - lxml=4.4.1=py37hefd8a0e_0
134
+ - lz4-c=1.8.1.2=h14c3975_0
135
+ - lzo=2.10=h49e0be7_2
136
+ - markupsafe=1.1.1=py37h7b6447c_0
137
+ - matplotlib=3.1.1=py37h5429711_0
138
+ - mccabe=0.6.1=py37_1
139
+ - mistune=0.8.4=py37h7b6447c_0
140
+ - mkl=2019.4=243
141
+ - mkl-service=2.3.0=py37he904b0f_0
142
+ - mkl_fft=1.0.14=py37ha843d7b_0
143
+ - mkl_random=1.1.0=py37hd6b4f25_0
144
+ - mock=3.0.5=py37_0
145
+ - more-itertools=7.2.0=py37_0
146
+ - mpc=1.1.0=h10f8cd9_1
147
+ - mpfr=4.0.1=hdf1c602_3
148
+ - mpmath=1.1.0=py37_0
149
+ - msgpack-python=0.6.1=py37hfd86e86_1
150
+ - multipledispatch=0.6.0=py37_0
151
+ - navigator-updater=0.2.1=py37_0
152
+ - nbconvert=5.6.0=py37_1
153
+ - nbformat=4.4.0=py37_0
154
+ - ncurses=6.1=he6710b0_1
155
+ - networkx=2.3=py_0
156
+ - nltk=3.4.5=py37_0
157
+ - nose=1.3.7=py37_2
158
+ - notebook=6.0.1=py37_0
159
+ - numba=0.45.1=py37h962f231_0
160
+ - numexpr=2.7.0=py37h9e4a6bb_0
161
+ - numpy=1.17.2=py37haad9e8e_0
162
+ - numpy-base=1.17.2=py37hde5b4d6_0
163
+ - numpydoc=0.9.1=py_0
164
+ - olefile=0.46=py37_0
165
+ - openpyxl=3.0.0=py_0
166
+ - openssl=1.1.1d=h7b6447c_2
167
+ - packaging=19.2=py_0
168
+ - pandoc=2.2.3.2=0
169
+ - pandocfilters=1.4.2=py37_1
170
+ - pango=1.42.4=h049681c_0
171
+ - parso=0.5.1=py_0
172
+ - partd=1.0.0=py_0
173
+ - patchelf=0.9=he6710b0_3
174
+ - path.py=12.0.1=py_0
175
+ - pathlib2=2.3.5=py37_0
176
+ - patsy=0.5.1=py37_0
177
+ - pcre=8.43=he6710b0_0
178
+ - pep8=1.7.1=py37_0
179
+ - pexpect=4.7.0=py37_0
180
+ - pickleshare=0.7.5=py37_0
181
+ - pip=19.2.3=py37_0
182
+ - pixman=0.38.0=h7b6447c_0
183
+ - pkginfo=1.5.0.1=py37_0
184
+ - pluggy=0.13.0=py37_0
185
+ - ply=3.11=py37_0
186
+ - prometheus_client=0.7.1=py_0
187
+ - prompt_toolkit=2.0.10=py_0
188
+ - psutil=5.6.3=py37h7b6447c_0
189
+ - ptyprocess=0.6.0=py37_0
190
+ - py=1.8.0=py37_0
191
+ - py-lief=0.9.0=py37h7725739_2
192
+ - pycodestyle=2.5.0=py37_0
193
+ - pycosat=0.6.3=py37h14c3975_0
194
+ - pycparser=2.19=py37_0
195
+ - pycrypto=2.6.1=py37h14c3975_9
196
+ - pycurl=7.43.0.3=py37h1ba5d50_0
197
+ - pyflakes=2.1.1=py37_0
198
+ - pygments=2.4.2=py_0
199
+ - pylint=2.4.2=py37_0
200
+ - pyodbc=4.0.27=py37he6710b0_0
201
+ - pyopenssl=19.0.0=py37_0
202
+ - pyparsing=2.4.2=py_0
203
+ - pyqt=5.9.2=py37h05f1152_2
204
+ - pyrsistent=0.15.4=py37h7b6447c_0
205
+ - pysocks=1.7.1=py37_0
206
+ - pytables=3.5.2=py37h71ec239_1
207
+ - pytest=5.2.1=py37_0
208
+ - pytest-arraydiff=0.3=py37h39e3cac_0
209
+ - pytest-astropy=0.5.0=py37_0
210
+ - pytest-doctestplus=0.4.0=py_0
211
+ - pytest-openfiles=0.4.0=py_0
212
+ - pytest-remotedata=0.3.2=py37_0
213
+ - python=3.7.4=h265db76_1
214
+ - python-dateutil=2.8.0=py37_0
215
+ - python-libarchive-c=2.8=py37_13
216
+ - pytz=2019.3=py_0
217
+ - pyyaml=5.1.2=py37h7b6447c_0
218
+ - pyzmq=18.1.0=py37he6710b0_0
219
+ - qt=5.9.7=h5867ecd_1
220
+ - qtawesome=0.6.0=py_0
221
+ - qtconsole=4.5.5=py_0
222
+ - qtpy=1.9.0=py_0
223
+ - readline=7.0=h7b6447c_5
224
+ - requests=2.22.0=py37_0
225
+ - ripgrep=0.10.0=hc07d326_0
226
+ - rope=0.14.0=py_0
227
+ - ruamel_yaml=0.15.46=py37h14c3975_0
228
+ - scikit-learn=0.21.3=py37hd81dba3_0
229
+ - scipy=1.3.1=py37h7c811a0_0
230
+ - seaborn=0.9.0=py37_0
231
+ - secretstorage=3.1.1=py37_0
232
+ - send2trash=1.5.0=py37_0
233
+ - setuptools=41.4.0=py37_0
234
+ - simplegeneric=0.8.1=py37_2
235
+ - singledispatch=3.4.0.3=py37_0
236
+ - sip=4.19.8=py37hf484d3e_0
237
+ - six=1.12.0=py37_0
238
+ - snappy=1.1.7=hbae5bb6_3
239
+ - snowballstemmer=2.0.0=py_0
240
+ - sortedcollections=1.1.2=py37_0
241
+ - sortedcontainers=2.1.0=py37_0
242
+ - soupsieve=1.9.3=py37_0
243
+ - sphinx=2.2.0=py_0
244
+ - sphinxcontrib=1.0=py37_1
245
+ - sphinxcontrib-applehelp=1.0.1=py_0
246
+ - sphinxcontrib-devhelp=1.0.1=py_0
247
+ - sphinxcontrib-htmlhelp=1.0.2=py_0
248
+ - sphinxcontrib-jsmath=1.0.1=py_0
249
+ - sphinxcontrib-qthelp=1.0.2=py_0
250
+ - sphinxcontrib-serializinghtml=1.1.3=py_0
251
+ - sphinxcontrib-websupport=1.1.2=py_0
252
+ - spyder=3.3.6=py37_0
253
+ - spyder-kernels=0.5.2=py37_0
254
+ - sqlalchemy=1.3.9=py37h7b6447c_0
255
+ - sqlite=3.30.0=h7b6447c_0
256
+ - statsmodels=0.10.1=py37hdd07704_0
257
+ - sympy=1.4=py37_0
258
+ - tbb=2019.4=hfd86e86_0
259
+ - tblib=1.4.0=py_0
260
+ - terminado=0.8.2=py37_0
261
+ - testpath=0.4.2=py37_0
262
+ - tk=8.6.8=hbc83047_0
263
+ - toolz=0.10.0=py_0
264
+ - tornado=6.0.3=py37h7b6447c_0
265
+ - traitlets=4.3.3=py37_0
266
+ - unicodecsv=0.14.1=py37_0
267
+ - unixodbc=2.3.7=h14c3975_0
268
+ - wcwidth=0.1.7=py37_0
269
+ - webencodings=0.5.1=py37_1
270
+ - werkzeug=0.16.0=py_0
271
+ - wheel=0.33.6=py37_0
272
+ - widgetsnbextension=3.5.1=py37_0
273
+ - wrapt=1.11.2=py37h7b6447c_0
274
+ - wurlitzer=1.0.3=py37_0
275
+ - xlrd=1.2.0=py37_0
276
+ - xlsxwriter=1.2.1=py_0
277
+ - xlwt=1.3.0=py37_0
278
+ - xz=5.2.4=h14c3975_4
279
+ - yaml=0.1.7=had09818_2
280
+ - zeromq=4.3.1=he6710b0_3
281
+ - zict=1.0.0=py_0
282
+ - zipp=0.6.0=py_0
283
+ - zlib=1.2.11=h7b6447c_3
284
+ - zstd=1.3.7=h0b5b093_0
285
+ - pip:
286
+ - absl-py==0.12.0
287
+ - aiohttp==3.7.4.post0
288
+ - albumentations==0.5.2
289
+ - alembic==1.4.1
290
+ - arrow==0.17.0
291
+ - async-timeout==3.0.1
292
+ - b2sdk==1.4.0
293
+ - boto3==1.17.36
294
+ - botocore==1.20.36
295
+ - cachetools==4.2.1
296
+ - colour-demosaicing==0.1.6
297
+ - colour-science==0.3.16
298
+ - configparser==5.0.0
299
+ - databricks-cli==0.10.0
300
+ - docker==4.2.0
301
+ - docopt==0.6.2
302
+ - efficientnet-pytorch==0.6.3
303
+ - fsspec==0.8.7
304
+ - funcsigs==1.0.2
305
+ - gitdb==4.0.4
306
+ - gitpython==3.1.1
307
+ - google-auth==1.28.0
308
+ - google-auth-oauthlib==0.4.3
309
+ - gorilla==0.3.0
310
+ - grpcio==1.36.1
311
+ - gunicorn==20.0.4
312
+ - imgaug==0.4.0
313
+ - importlib-metadata==3.7.3
314
+ - jmespath==0.10.0
315
+ - logfury==0.1.2
316
+ - mako==1.1.2
317
+ - markdown==3.3.4
318
+ - mlflow==1.14.1
319
+ - multidict==5.1.0
320
+ - munch==2.5.0
321
+ - oauthlib==3.1.0
322
+ - opencv-python==4.5.1.48
323
+ - opencv-python-headless==4.5.1.48
324
+ - pandas==1.2.3
325
+ - pillow==8.1.2
326
+ - pipreqs==0.4.10
327
+ - plotly==4.14.3
328
+ - pretrainedmodels==0.7.4
329
+ - prettytable==2.1.0
330
+ - prometheus-flask-exporter==0.13.0
331
+ - protobuf==3.11.3
332
+ - pyasn1==0.4.8
333
+ - pyasn1-modules==0.2.8
334
+ - python-editor==1.0.4
335
+ - pytorch-lightning==1.2.5
336
+ - pywavelets==1.1.1
337
+ - querystring-parser==1.2.4
338
+ - rawpy==0.16.0
339
+ - requests-oauthlib==1.3.0
340
+ - retrying==1.3.3
341
+ - rsa==4.7.2
342
+ - s3transfer==0.3.6
343
+ - scikit-image==0.18.1
344
+ - segmentation-models-pytorch==0.1.3
345
+ - shapely==1.7.1
346
+ - simplejson==3.17.0
347
+ - smmap==3.0.2
348
+ - sqlparse==0.3.1
349
+ - tabulate==0.8.7
350
+ - tensorboard==2.4.1
351
+ - tensorboard-plugin-wit==1.8.0
352
+ - tifffile==2021.3.17
353
+ - timm==0.3.2
354
+ - torch==1.8.0
355
+ - torchmetrics==0.2.0
356
+ - torchvision==0.9.0
357
+ - tqdm==4.59.0
358
+ - typing-extensions==3.7.4.3
359
+ - urllib3==1.25.11
360
+ - websocket-client==0.57.0
361
+ - yarg==0.1.9
362
+ - yarl==1.6.3
363
+ prefix: /home/nobis/anaconda3/envs/perturbed
figures/ABtesting.py ADDED
@@ -0,0 +1,831 @@
1
+ import os
2
+ import argparse
3
+ import json
4
+ from cv2 import transform
5
+
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+ from torchvision.transforms import Compose, Normalize
9
+ import torch.nn.functional as F
10
+
11
+ from dataset import get_dataset, Subset
12
+ from utils.base import get_mlflow_model_by_name, SmartFormatter
13
+ from processing.pipeline_numpy import RawProcessingPipeline
14
+
15
+ from utils.hendrycks_robustness import Distortions
16
+
17
+ import segmentation_models_pytorch as smp
18
+
19
+ import matplotlib.pyplot as plt
20
+
21
+ parser = argparse.ArgumentParser(description="AB testing, Show Results", formatter_class=SmartFormatter)
22
+
23
+ # Select experiment
24
+ parser.add_argument("--mode", type=str, default="ABShowImages", choices=('ABMakeTable', 'ABShowTable', 'ABShowImages', 'ABShowAllImages', 'CMakeTable', 'CShowTable', 'CShowImages', 'CShowAllImages'),
25
+ help='R|Choose operation to compute. \n'
26
+ 'A) Lens2Logit image generation: \n '
27
+ 'ABMakeTable: Compute cross-validation metrics results \n '
28
+ 'ABShowTable: Plot cross-validation results on a table \n '
29
+ 'ABShowImages: Choose a training and testing image to compare different pipelines \n '
30
+ 'ABShowAllImages: Plot all possible pipelines \n'
31
+ 'B) Hendrycks Perturbations, C-type dataset: \n '
32
+ 'CMakeTable: For each pipeline, it computes cross-validation metrics for different perturbations \n '
33
+ 'CShowTable: Plot metrics for different pipelines and perturbations \n '
34
+ 'CShowImages: Plot an image with a selected a pipeline and perturbation\n '
35
+ 'CShowAllImages: Plot all possible perturbations for a fixed pipeline')
36
+
37
+ parser.add_argument("--dataset_name", type=str, default='Microscopy',
38
+ choices=['Microscopy', 'Drone', 'DroneSegmentation'], help='Choose dataset')
39
+ parser.add_argument("--augmentation", type=str, default='weak',
40
+ choices=['none', 'weak', 'strong'], help='Choose augmentation')
41
+ parser.add_argument("--N_runs", type=int, default=5, help='Number of k-fold splitting used in the training')
42
+ parser.add_argument("--download_model", default=False, action='store_true', help='Download Models in cache')
43
+
44
+ # Select pipelines
45
+ parser.add_argument("--dm_train", type=str, default='bilinear', choices=('bilinear', 'malvar2004',
46
+ 'menon2007'), help='Choose demosaicing for training processing model')
47
+ parser.add_argument("--s_train", type=str, default='sharpening_filter', choices=('sharpening_filter',
48
+ 'unsharp_masking'), help='Choose sharpening for training processing model')
49
+ parser.add_argument("--dn_train", type=str, default='gaussian_denoising', choices=('gaussian_denoising',
50
+ 'median_denoising'), help='Choose denoising for training processing model')
51
+ parser.add_argument("--dm_test", type=str, default='bilinear', choices=('bilinear', 'malvar2004',
52
+ 'menon2007'), help='Choose demosaicing for testing processing model')
53
+ parser.add_argument("--s_test", type=str, default='sharpening_filter', choices=('sharpening_filter',
54
+ 'unsharp_masking'), help='Choose sharpening for testing processing model')
55
+ parser.add_argument("--dn_test", type=str, default='gaussian_denoising', choices=('gaussian_denoising',
56
+ 'median_denoising'), help='Choose denoising for testing processing model')
57
+
58
+ # Select Ctest parameters
59
+ parser.add_argument("--transform", type=str, default='identity', choices=('identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
60
+ 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform'), help='Choose transformation to show for Ctesting')
61
+ parser.add_argument("--severity", type=int, default=1, choices=(1, 2, 3, 4, 5), help='Choose severity for Ctesting')
62
+
63
+ args = parser.parse_args()
64
+
65
+
66
+ class metrics:
67
+ def __init__(self, confusion_matrix):
68
+ self.cm = confusion_matrix
69
+ self.N_classes = len(confusion_matrix)
70
+
71
+ def accuracy(self):
72
+ Tp = torch.diagonal(self.cm, 0).sum()
73
+ N_elements = torch.sum(self.cm)
74
+ return Tp / N_elements
75
+
76
+ def precision(self):
77
+ Tp_Fp = torch.sum(self.cm, 1)
78
+ Tp_Fp[Tp_Fp == 0] = 1
79
+ return torch.diagonal(self.cm, 0) / Tp_Fp
80
+
81
+ def recall(self):
82
+ Tp_Fn = torch.sum(self.cm, 0)
83
+ Tp_Fn[Tp_Fn == 0] = 1
84
+ return torch.diagonal(self.cm, 0) / Tp_Fn
85
+
86
+ def f1_score(self):
87
+ prod = (self.precision() * self.recall())
88
+ sum = (self.precision() + self.recall())
89
+ sum[sum == 0.] = 1.
90
+ return 2 * (prod / sum)
91
+
92
+ def over_N_runs(ms, N_runs):
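+ # Aggregates per-run metric tensors into their mean and sample standard deviation over N_runs.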
93
+ m, m2 = 0, 0
94
+
95
+ for i in ms:
96
+ m += i
97
+ mu = m / N_runs
98
+
99
+ for i in ms:
100
+ m2 += (i - mu)**2
101
+
102
+ sigma = torch.sqrt(m2 / (N_runs - 1))
103
+
104
+ return mu.tolist(), sigma.tolist()
105
+
106
+
107
+ class ABtesting:
108
+ def __init__(self,
109
+ dataset_name: str,
110
+ augmentation: str,
111
+ dm_train: str,
112
+ s_train: str,
113
+ dn_train: str,
114
+ dm_test: str,
115
+ s_test: str,
116
+ dn_test: str,
117
+ N_runs: int,
118
+ severity=1,
119
+ transform='identity',
120
+ download_model=False):
121
+ self.experiment_name = 'ABtesting'
122
+ self.dataset_name = dataset_name
123
+ self.augmentation = augmentation
124
+ self.dm_train = dm_train
125
+ self.s_train = s_train
126
+ self.dn_train = dn_train
127
+ self.dm_test = dm_test
128
+ self.s_test = s_test
129
+ self.dn_test = dn_test
130
+ self.N_runs = N_runs
131
+ self.severity = severity
132
+ self.transform = transform
133
+ self.download_model = download_model
134
+
135
+ def static_pip_val(self, debayer=None, sharpening=None, denoising=None, severity=None, transform=None, plot_mode=False):
136
+
137
+ if debayer == None:
138
+ debayer = self.dm_test
139
+ if sharpening == None:
140
+ sharpening = self.s_test
141
+ if denoising == None:
142
+ denoising = self.dn_test
143
+ if severity == None:
144
+ severity = self.severity
145
+ if transform == None:
146
+ transform = self.transform
147
+
148
+ dataset = get_dataset(self.dataset_name)
149
+
150
+ if self.dataset_name == "Drone" or self.dataset_name == "DroneSegmentation":
151
+ mean = torch.tensor([0.35, 0.36, 0.35])
152
+ std = torch.tensor([0.12, 0.11, 0.12])
153
+ elif self.dataset_name == "Microscopy":
154
+ mean = torch.tensor([0.91, 0.84, 0.94])
155
+ std = torch.tensor([0.08, 0.12, 0.05])
156
+
157
+ if not plot_mode:
158
+ dataset.transform = Compose([RawProcessingPipeline(
159
+ camera_parameters=dataset.camera_parameters,
160
+ debayer=debayer,
161
+ sharpening=sharpening,
162
+ denoising=denoising,
163
+ ), Distortions(severity=severity, transform=transform),
164
+ Normalize(mean, std)])
165
+ else:
166
+ dataset.transform = Compose([RawProcessingPipeline(
167
+ camera_parameters=dataset.camera_parameters,
168
+ debayer=debayer,
169
+ sharpening=sharpening,
170
+ denoising=denoising,
171
+ ), Distortions(severity=severity, transform=transform)])
172
+
173
+ return dataset
174
+
175
+ def ABclassification(self):
176
+
177
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
178
+
179
+ parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
180
+
181
+ print(
182
+ f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
183
+ print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')
184
+
185
+ accuracies, precisions, recalls, f1_scores = [], [], [], []
186
+
187
+ os.system('rm -r /tmp/py*')
188
+
189
+ for N_run in range(self.N_runs):
190
+
191
+ print(f"Evaluating Run {N_run}")
192
+
193
+ run_name = parent_run_name + '_' + str(N_run)
194
+
195
+ state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
196
+ download_model=self.download_model)
197
+
198
+ dataset = self.static_pip_val()
199
+ valid_set = Subset(dataset, indices=state_dict['valid_indices'])
200
+ valid_loader = DataLoader(valid_set, batch_size=1, num_workers=16, shuffle=False)
201
+
202
+ model.eval()
203
+
204
+ len_classes = len(dataset.classes)
205
+ confusion_matrix = torch.zeros((len_classes, len_classes))
206
+
207
+ for img, label in valid_loader:
208
+
209
+ prediction = model(img.to(DEVICE)).detach().cpu()
210
+ prediction = torch.argmax(prediction, dim=1)
211
+ confusion_matrix[label, prediction] += 1  # rows: true labels, columns: predicted labels
212
+
213
+ m = metrics(confusion_matrix)
214
+
215
+ accuracies.append(m.accuracy())
216
+ precisions.append(m.precision())
217
+ recalls.append(m.recall())
218
+ f1_scores.append(m.f1_score())
219
+
220
+ os.system('rm -r /tmp/t*')
221
+
222
+ accuracy = metrics.over_N_runs(accuracies, self.N_runs)
223
+ precision = metrics.over_N_runs(precisions, self.N_runs)
224
+ recall = metrics.over_N_runs(recalls, self.N_runs)
225
+ f1_score = metrics.over_N_runs(f1_scores, self.N_runs)
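+ # over_N_runs reduces the per-run scores to (mean, std) pairs; ABShowTable later unpacks them as mu, sigma.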
226
+ return dataset.classes, accuracy, precision, recall, f1_score
227
+
228
+ def ABsegmentation(self):
229
+
230
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
231
+
232
+ parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
233
+
234
+ print(
235
+ f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
236
+ print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')
237
+
238
+ IoUs = []
239
+
240
+ os.system('rm -r /tmp/py*')
241
+
242
+ for N_run in range(self.N_runs):
243
+
244
+ print(f"Evaluating Run {N_run}")
245
+
246
+ run_name = parent_run_name + '_' + str(N_run)
247
+
248
+ state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
249
+ download_model=self.download_model)
250
+
251
+ dataset = self.static_pip_val()
252
+
253
+ valid_set = Subset(dataset, indices=state_dict['valid_indices'])
254
+ valid_loader = DataLoader(valid_set, batch_size=1, num_workers=16, shuffle=False)
255
+
256
+ model.eval()
257
+
258
+ IoU = 0
259
+
260
+ for img, label in valid_loader:
261
+
262
+ prediction = model(img.to(DEVICE)).detach().cpu()
263
+ prediction = F.logsigmoid(prediction).exp().squeeze()
264
+ IoU += smp.utils.metrics.IoU()(prediction, label)
265
+
266
+ IoU = IoU / len(valid_loader)
267
+ IoUs.append(IoU.item())
268
+
269
+ os.system('rm -r /tmp/t*')
270
+
271
+ IoU = metrics.over_N_runs(torch.tensor(IoUs), self.N_runs)
272
+ return IoU
273
+
274
+ def ABShowImages(self):
275
+
276
+ path = 'results/ABtesting/imgs/'
277
+ if not os.path.exists(path):
278
+ os.makedirs(path)
279
+
280
+ path = os.path.join(
281
+ path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.dm_test[:2]}{self.s_test[0]}{self.dn_test[:2]}')
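+ # The folder name abbreviates the train and test pipelines, e.g. bilinear + sharpening_filter + median_denoising -> 'bisme'.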
282
+
283
+ if not os.path.exists(path):
284
+ os.makedirs(path)
285
+
286
+ run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}" + \
287
+ '_' + str(0)
288
+
289
+ state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=self.download_model)
290
+
291
+ model.augmentation = None
292
+
293
+ for t in ([self.dm_train, self.s_train, self.dn_train, 'train_img'],
294
+ [self.dm_test, self.s_test, self.dn_test, 'test_img']):
295
+
296
+ debayer, sharpening, denoising, img_type = t[0], t[1], t[2], t[3]
297
+
298
+ dataset = self.static_pip_val(debayer=debayer, sharpening=sharpening, denoising=denoising, plot_mode=True)
299
+ valid_set = Subset(dataset, indices=state_dict['valid_indices'])
300
+
301
+ img, _ = next(iter(valid_set))
302
+
303
+ plt.figure()
304
+ plt.imshow(img.permute(1, 2, 0))
305
+ if img_type == 'train_img':
306
+ plt.title('Train Image')
307
+ plt.savefig(os.path.join(path, f'img_train.png'))
308
+ imgA = img
309
+ else:
310
+ plt.title('Test Image')
311
+ plt.savefig(os.path.join(path, f'img_test.png'))
312
+
313
+ for c, color in enumerate(['Red', 'Green', 'Blue']):
314
+ diff = torch.abs(imgA - img)
315
+ plt.figure()
316
+ # plt.imshow(diff.permute(1,2,0))
317
+ plt.imshow(diff[c, 50:200, 50:200], cmap=f'{color}s')
318
+ plt.title(f'|Train Image - Test Image| - {color}')
319
+ plt.colorbar()
320
+ plt.savefig(os.path.join(path, f'diff_{color}.png'))
321
+ plt.figure()
322
+ diff[diff == 0.] = 1e-5
323
+ # plt.imshow(torch.log(diff.permute(1,2,0)))
324
+ plt.imshow(torch.log(diff)[c])
325
+ plt.title(f'log(|Train Image - Test Image|) - {color}')
326
+ plt.colorbar()
327
+ plt.savefig(os.path.join(path, f'logdiff_{color}.png'))
328
+
329
+ if self.dataset_name == 'DroneSegmentation':
330
+ plt.figure()
331
+ plt.imshow(model(img[None].cuda()).detach().cpu().squeeze())
332
+ if img_type == 'train_img':
333
+ plt.savefig(os.path.join(path, f'mask_train.png'))
334
+ else:
335
+ plt.savefig(os.path.join(path, f'mask_test.png'))
336
+
337
+ def ABShowAllImages(self):
338
+ if not os.path.exists('results/ABtesting'):
339
+ os.makedirs('results/ABtesting')
340
+
341
+ demosaicings = ['bilinear', 'malvar2004', 'menon2007']
342
+ sharpenings = ['sharpening_filter', 'unsharp_masking']
343
+ denoisings = ['median_denoising', 'gaussian_denoising']
344
+
345
+ fig = plt.figure()
346
+ columns = 4
347
+ rows = 3
348
+
349
+ i = 1
350
+
351
+ for dm in demosaicings:
352
+ for s in sharpenings:
353
+ for dn in denoisings:
354
+
355
+ dataset = self.static_pip_val(dm, s,
356
+ dn, plot_mode=True)
357
+
358
+ img, _ = dataset[0]
359
+
360
+ fig.add_subplot(rows, columns, i)
361
+ plt.imshow(img.permute(1, 2, 0))
362
+ plt.title(f'{dm}\n{s}\n{dn}', fontsize=8)
363
+ plt.xticks([])
364
+ plt.yticks([])
365
+ plt.tight_layout()
366
+
367
+ i += 1
368
+
369
+ plt.show()
370
+ plt.savefig(f'results/ABtesting/ABpipelines.png')
371
+
372
+ def CShowImages(self):
373
+
374
+ path = 'results/Ctesting/imgs/'
375
+ if not os.path.exists(path):
376
+ os.makedirs(path)
377
+
378
+ run_name = f"{self.dataset_name}_{self.dm_test}_{self.s_test}_{self.dn_test}_{self.augmentation}" + '_' + str(0)
379
+
380
+ state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=self.download_model)
381
+
382
+ model.augmentation = None
383
+
384
+ dataset = self.static_pip_val(self.dm_test, self.s_test, self.dn_test,
385
+ self.severity, self.transform, plot_mode=True)
386
+ valid_set = Subset(dataset, indices=state_dict['valid_indices'])
387
+
388
+ img, _ = next(iter(valid_set))
389
+
390
+ plt.figure()
391
+ plt.imshow(img.permute(1, 2, 0))
392
+ plt.savefig(os.path.join(
393
+ path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.transform}_sev{self.severity}.png'))
394
+
395
+ def CShowAllImages(self):
396
+ if not os.path.exists('results/Cimages'):
397
+ os.makedirs('results/Cimages')
398
+
399
+ transforms = ['identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
400
+ 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
401
+
402
+ for i, t in enumerate(transforms):
403
+
404
+ fig = plt.figure(figsize=(10, 6))
405
+ columns = 5
406
+ rows = 1
407
+
408
+ for sev in range(1, 6):
409
+
410
+ dataset = self.static_pip_val(severity=sev, transform=t, plot_mode=True)
411
+
412
+ img, _ = dataset[0]
413
+
414
+ fig.add_subplot(rows, columns, sev)
415
+ plt.imshow(img.permute(1, 2, 0))
416
+ plt.title(f'Severity: {sev}')
417
+ plt.xticks([])
418
+ plt.yticks([])
419
+ plt.tight_layout()
420
+
421
+ if '_' in t:
422
+ t = t.replace('_', ' ')
423
+ t = t[0].upper() + t[1:]
424
+
425
+ fig.suptitle(f'{t}', x=0.5, y=0.8, fontsize=24)
426
+ plt.show()
427
+ plt.savefig(f'results/Cimages/{i+1}_{t.lower()}.png')
428
+
429
+
430
+ def ABMakeTable(dataset_name: str, augmentation: str,
431
+ N_runs: int, download_model: bool):
432
+
433
+ demosaicings = ['bilinear', 'malvar2004', 'menon2007']
434
+ sharpenings = ['sharpening_filter', 'unsharp_masking']
435
+ denoisings = ['median_denoising', 'gaussian_denoising']
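+ # 3 demosaicing x 2 sharpening x 2 denoising options = 12 static pipelines; the nested loops below evaluate all 12 x 12 train/test combinations.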
436
+
437
+ path = 'results/ABtesting/tables'
438
+ if not os.path.exists(path):
439
+ os.makedirs(path)
440
+
441
+ runs = {}
442
+ i = 0
443
+
444
+ for dm_train in demosaicings:
445
+ for s_train in sharpenings:
446
+ for dn_train in denoisings:
447
+ for dm_test in demosaicings:
448
+ for s_test in sharpenings:
449
+ for dn_test in denoisings:
450
+ train_pip = [dm_train, s_train, dn_train]
451
+ test_pip = [dm_test, s_test, dn_test]
452
+ runs[f'run{i}'] = {
453
+ 'dataset': dataset_name,
454
+ 'augmentation': augmentation,
455
+ 'train_pip': train_pip,
456
+ 'test_pip': test_pip,
457
+ 'N_runs': N_runs
458
+ }
459
+ ABclass = ABtesting(
460
+ dataset_name=dataset_name,
461
+ augmentation=augmentation,
462
+ dm_train=dm_train,
463
+ s_train=s_train,
464
+ dn_train=dn_train,
465
+ dm_test=dm_test,
466
+ s_test=s_test,
467
+ dn_test=dn_test,
468
+ N_runs=N_runs,
469
+ download_model=download_model
470
+ )
471
+
472
+ if dataset_name == 'DroneSegmentation':
473
+ IoU = ABclass.ABsegmentation()
474
+ runs[f'run{i}']['IoU'] = IoU
475
+ else:
476
+ classes, accuracy, precision, recall, f1_score = ABclass.ABclassification()
477
+ runs[f'run{i}']['classes'] = classes
478
+ runs[f'run{i}']['accuracy'] = accuracy
479
+ runs[f'run{i}']['precision'] = precision
480
+ runs[f'run{i}']['recall'] = recall
481
+ runs[f'run{i}']['f1_score'] = f1_score
482
+
483
+ with open(os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt'), 'w') as outfile:
484
+ json.dump(runs, outfile)
485
+
486
+ i += 1
487
+
488
+
489
+ def ABShowTable(dataset_name: str, augmentation: str):
490
+
491
+ path = 'results/ABtesting/tables'
492
+ assert os.path.exists(path), 'No tables to plot'
493
+
494
+ json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt')
495
+
496
+ with open(json_file, 'r') as run_file:
497
+ runs = json.load(run_file)
498
+
499
+ metrics = torch.zeros((2, 12, 12))
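+ # metrics[0] holds the mean and metrics[1] the standard deviation; j indexes the train pipeline (rows), i the test pipeline (columns).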
500
+ classes = []
501
+
502
+ i, j = 0, 0
503
+
504
+ for r in range(len(runs)):
505
+
506
+ run = runs['run' + str(r)]
507
+ if dataset_name == 'DroneSegmentation':
508
+ acc = run['IoU']
509
+ else:
510
+ acc = run['accuracy']
511
+ if len(classes) < 12:
512
+ class_list = run['test_pip']
513
+ class_name = f'{class_list[0][:2]},{class_list[1][:1]},{class_list[2][:2]}'
514
+ classes.append(class_name)
515
+ mu, sigma = round(acc[0], 4), round(acc[1], 4)
516
+
517
+ metrics[0, j, i] = mu
518
+ metrics[1, j, i] = sigma
519
+
520
+ i += 1
521
+
522
+ if i == 12:
523
+ i = 0
524
+ j += 1
525
+
526
+ differences = torch.zeros_like(metrics)
527
+
528
+ diag_mu = torch.diagonal(metrics[0], 0)
529
+ diag_sigma = torch.diagonal(metrics[1], 0)
530
+
531
+ for r in range(len(metrics[0])):
532
+ differences[0, r] = diag_mu[r] - metrics[0, r]
533
+ differences[1, r] = torch.sqrt(metrics[1, r]**2 + diag_sigma[r]**2)
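+ # differences[0] is the accuracy drop relative to the matched (diagonal) train/test pipeline;
+ # differences[1] propagates the uncertainty of that drop as sqrt(sigma_row^2 + sigma_diag^2).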
534
+
535
+ # Plot with scatter
536
+
537
+ for i, img in enumerate([metrics, differences]):
538
+
539
+ x, y = torch.arange(12), torch.arange(12)
540
+ x, y = torch.meshgrid(x, y)
541
+
542
+ if i == 0:
543
+ vmin = max(0.65, round(img[0].min().item(), 2))
544
+ vmax = round(img[0].max().item(), 2)
545
+ step = 0.02
546
+ elif i == 1:
547
+ vmin = round(img[0].min().item(), 2)
548
+ if augmentation == 'none':
549
+ vmax = min(0.15, round(img[0].max().item(), 2))
550
+ if augmentation == 'weak':
551
+ vmax = min(0.08, round(img[0].max().item(), 2))
552
+ if augmentation == 'strong':
553
+ vmax = min(0.05, round(img[0].max().item(), 2))
554
+ step = 0.01
555
+
556
+ vmin = int(vmin / step) * step
557
+ vmax = int(vmax / step) * step
558
+
559
+ fig = plt.figure(figsize=(10, 6.2))
560
+ ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
561
+ marker_size = 350
562
+ plt.scatter(x, y, c=torch.rot90(img[1][x, y], -1, [0, 1]), vmin=0.,
563
+ vmax=img[1].max(), cmap='viridis', s=marker_size * 2, marker='s')
564
+ ticks = torch.arange(0., img[1].max(), 0.03).tolist()
565
+ ticks = [round(tick, 2) for tick in ticks]
566
+ cba = plt.colorbar(pad=0.06)
567
+ cba.set_ticks(ticks)
568
+ cba.ax.set_yticklabels(ticks)
569
+ # cmap = plt.cm.get_cmap('tab20c').reversed()
570
+ cmap = plt.cm.get_cmap('Reds')
571
+ plt.scatter(x, y, c=torch.rot90(img[0][x, y], -1, [0, 1]), vmin=vmin,
572
+ vmax=vmax, cmap=cmap, s=marker_size, marker='s')
573
+ ticks = torch.arange(vmin, vmax, step).tolist()
574
+ ticks = [round(tick, 2) for tick in ticks]
575
+ if ticks[-1] != vmax:
576
+ ticks.append(vmax)
577
+ cbb = plt.colorbar(pad=0.06)
578
+ cbb.set_ticks(ticks)
579
+ if i == 0:
580
+ ticks[0] = f'<{str(ticks[0])}'
581
+ elif i == 1:
582
+ ticks[-1] = f'>{str(ticks[-1])}'
583
+ cbb.ax.set_yticklabels(ticks)
584
+ for x in range(12):
585
+ for y in range(12):
586
+ txt = round(torch.rot90(img[0], -1, [0, 1])[x, y].item(), 2)
587
+ if str(txt) == '-0.0':
588
+ txt = '0.00'
589
+ elif str(txt) == '0.0':
590
+ txt = '0.00'
591
+ elif len(str(txt)) == 3:
592
+ txt = str(txt) + '0'
593
+ else:
594
+ txt = str(txt)
595
+
596
+ plt.text(x - 0.25, y - 0.1, txt, color='black', fontsize='x-small')
597
+
598
+ ax.set_xticks(torch.linspace(0, 11, 12))
599
+ ax.set_xticklabels(classes)
600
+ ax.set_yticks(torch.linspace(0, 11, 12))
601
+ classes.reverse()
602
+ ax.set_yticklabels(classes)
603
+ classes.reverse()
604
+ plt.xticks(rotation=45)
605
+ plt.yticks(rotation=45)
606
+ cba.set_label('Standard Deviation')
607
+ plt.xlabel("Test pipelines")
608
+ plt.ylabel("Train pipelines")
609
+ plt.title(f'Dataset: {dataset_name}, Augmentation: {augmentation}')
610
+ if i == 0:
611
+ if dataset_name == 'DroneSegmentation':
612
+ cbb.set_label('IoU')
613
+ plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_IoU.png"))
614
+ else:
615
+ cbb.set_label('Accuracy')
616
+ plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_accuracies.png"))
617
+ elif i == 1:
618
+ if dataset_name == 'DroneSegmentation':
619
+ cbb.set_label('IoU_d-IoU')
620
+ else:
621
+ cbb.set_label('Accuracy_d - Accuracy')
622
+ plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_differences.png"))
623
+
624
+
625
+ def CMakeTable(dataset_name: str, augmentation: str, severity: int, N_runs: int, download_model: bool):
626
+
627
+ path = 'results/Ctesting/tables'
628
+ if not os.path.exists(path):
629
+ os.makedirs(path)
630
+
631
+ demosaicings = ['bilinear', 'malvar2004', 'menon2007']
632
+ sharpenings = ['sharpening_filter', 'unsharp_masking']
633
+ denoisings = ['median_denoising', 'gaussian_denoising']
634
+
635
+ transformations = ['identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
636
+ 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
637
+
638
+ runs = {}
639
+ i = 0
640
+
641
+ for dm in demosaicings:
642
+ for s in sharpenings:
643
+ for dn in denoisings:
644
+ for t in transformations:
645
+ pip = [dm, s, dn]
646
+ runs[f'run{i}'] = {
647
+ 'dataset': dataset_name,
648
+ 'augmentation': augmentation,
649
+ 'pipeline': pip,
650
+ 'N_runs': N_runs,
651
+ 'transform': t,
652
+ 'severity': severity,
653
+ }
654
+ ABclass = ABtesting(
655
+ dataset_name=dataset_name,
656
+ augmentation=augmentation,
657
+ dm_train=dm,
658
+ s_train=s,
659
+ dn_train=dn,
660
+ dm_test=dm,
661
+ s_test=s,
662
+ dn_test=dn,
663
+ severity=severity,
664
+ transform=t,
665
+ N_runs=N_runs,
666
+ download_model=download_model
667
+ )
668
+
669
+ if dataset_name == 'DroneSegmentation':
670
+ IoU = ABclass.ABsegmentation()
671
+ runs[f'run{i}']['IoU'] = IoU
672
+ else:
673
+ classes, accuracy, precision, recall, f1_score = ABclass.ABclassification()
674
+ runs[f'run{i}']['classes'] = classes
675
+ runs[f'run{i}']['accuracy'] = accuracy
676
+ runs[f'run{i}']['precision'] = precision
677
+ runs[f'run{i}']['recall'] = recall
678
+ runs[f'run{i}']['f1_score'] = f1_score
679
+
680
+ with open(os.path.join(path, f'{dataset_name}_{augmentation}_runs.json'), 'w') as outfile:
681
+ json.dump(runs, outfile)
682
+
683
+ i += 1
684
+
685
+
686
+ def CShowTable(dataset_name, augmentation):
687
+
688
+ path = 'results/Ctesting/tables'
689
+ assert os.path.exists(path), 'No tables to plot'
690
+
691
+ json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.json')
692
+
693
+ transforms = ['identity', 'gauss_noise', 'shot', 'impulse', 'speckle',
694
+ 'gauss_blur', 'zoom', 'contrast', 'brightness', 'saturate', 'elastic']
695
+
696
+ pip = []
697
+
698
+ demosaicings = ['bilinear', 'malvar2004', 'menon2007']
699
+ sharpenings = ['sharpening_filter', 'unsharp_masking']
700
+ denoisings = ['median_denoising', 'gaussian_denoising']
701
+
702
+ for dm in demosaicings:
703
+ for s in sharpenings:
704
+ for dn in denoisings:
705
+ pip.append(f'{dm[:2]},{s[0]},{dn[:2]}')
706
+
707
+ with open(json_file, 'r') as run_file:
708
+ runs = json.load(run_file)
709
+
710
+ metrics = torch.zeros((2, len(pip), len(transforms)))
711
+
712
+ i, j = 0, 0
713
+
714
+ for r in range(len(runs)):
715
+
716
+ run = runs['run' + str(r)]
717
+ if dataset_name == 'DroneSegmentation':
718
+ acc = run['IoU']
719
+ else:
720
+ acc = run['accuracy']
721
+ mu, sigma = round(acc[0], 4), round(acc[1], 4)
722
+
723
+ metrics[0, j, i] = mu
724
+ metrics[1, j, i] = sigma
725
+
726
+ i += 1
727
+
728
+ if i == len(transforms):
729
+ i = 0
730
+ j += 1
731
+
732
+ # Plot with scatter
733
+
734
+ img = metrics
735
+
736
+ vmin = 0.
737
+ vmax = 1.
+ step = 0.05  # tick spacing for the accuracy colorbar
738
+
739
+ x, y = torch.arange(12), torch.arange(11)
740
+ x, y = torch.meshgrid(x, y)
741
+
742
+ fig = plt.figure(figsize=(10, 6.2))
743
+ ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
744
+ marker_size = 350
745
+ plt.scatter(x, y, c=torch.rot90(img[1][x, y], -1, [0, 1]), vmin=0.,
746
+ vmax=img[1].max(), cmap='viridis', s=marker_size * 2, marker='s')
747
+ ticks = torch.arange(0., img[1].max(), 0.03).tolist()
748
+ ticks = [round(tick, 2) for tick in ticks]
749
+ cba = plt.colorbar(pad=0.06)
750
+ cba.set_ticks(ticks)
751
+ cba.ax.set_yticklabels(ticks)
752
+ # cmap = plt.cm.get_cmap('tab20c').reversed()
753
+ cmap = plt.cm.get_cmap('Reds')
754
+ plt.scatter(x, y, c=torch.rot90(img[0][x, y], -1, [0, 1]), vmin=vmin,
755
+ vmax=vmax, cmap=cmap, s=marker_size, marker='s')
756
+ ticks = torch.arange(vmin, vmax, step).tolist()
757
+ ticks = [round(tick, 2) for tick in ticks]
758
+ if ticks[-1] != vmax:
759
+ ticks.append(vmax)
760
+ cbb = plt.colorbar(pad=0.06)
761
+ cbb.set_ticks(ticks)
762
+ if i == 0:
763
+ ticks[0] = f'<{str(ticks[0])}'
764
+ elif i == 1:
765
+ ticks[-1] = f'>{str(ticks[-1])}'
766
+ cbb.ax.set_yticklabels(ticks)
767
+ for x in range(12):
768
+ for y in range(12):
769
+ txt = round(torch.rot90(img[0], -1, [0, 1])[x, y].item(), 2)
770
+ if str(txt) == '-0.0':
771
+ txt = '0.00'
772
+ elif str(txt) == '0.0':
773
+ txt = '0.00'
774
+ elif len(str(txt)) == 3:
775
+ txt = str(txt) + '0'
776
+ else:
777
+ txt = str(txt)
778
+
779
+ plt.text(x - 0.25, y - 0.1, txt, color='black', fontsize='x-small')
780
+
781
+ ax.set_xticks(torch.linspace(0, 11, 12))
782
+ ax.set_xticklabels(transforms)
783
+ ax.set_yticks(torch.linspace(0, 11, 12))
784
+ pip.reverse()
785
+ ax.set_yticklabels(pip)
786
+ pip.reverse()
787
+ plt.xticks(rotation=45)
788
+ plt.yticks(rotation=45)
789
+ cba.set_label('Standard Deviation')
790
+ plt.xlabel("Pipelines")
791
+ plt.ylabel("Distortions")
792
+ if dataset_name == 'DroneSegmentation':
793
+ cbb.set_label('IoU')
794
+ plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_IoU.png"))
795
+ else:
796
+ cbb.set_label('Accuracy')
797
+ plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_accuracies.png"))
798
+
799
+
800
+ if __name__ == '__main__':
801
+
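+ # Example invocation (sketch, assuming the argparse flags defined earlier in this file):
+ # python <this script> --mode ABMakeTable --dataset_name Microscopy --augmentation none --N_runs 5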
802
+ if args.mode == 'ABMakeTable':
803
+ ABMakeTable(args.dataset_name, args.augmentation, args.N_runs, args.download_model)
804
+ elif args.mode == 'ABShowTable':
805
+ ABShowTable(args.dataset_name, args.augmentation)
806
+ elif args.mode == 'ABShowImages':
807
+ ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
808
+ args.s_train, args.dn_train, args.dm_test, args.s_test,
809
+ args.dn_test, args.N_runs, download_model=args.download_model)
810
+ ABclass.ABShowImages()
811
+ elif args.mode == 'ABShowAllImages':
812
+ ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
813
+ args.s_train, args.dn_train, args.dm_test, args.s_test,
814
+ args.dn_test, args.N_runs, download_model=args.download_model)
815
+ ABclass.ABShowAllImages()
816
+ elif args.mode == 'CMakeTable':
817
+ CMakeTable(args.dataset_name, args.augmentation, args.severity, args.N_runs, args.download_model)
818
+ elif args.mode == 'CShowTable': # TODO test it
819
+ CShowTable(args.dataset_name, args.augmentation)
820
+ elif args.mode == 'CShowImages':
821
+ ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
822
+ args.s_train, args.dn_train, args.dm_test, args.s_test,
823
+ args.dn_test, args.N_runs, args.severity, args.transform,
824
+ download_model=args.download_model)
825
+ ABclass.CShowImages()
826
+ elif args.mode == 'CShowAllImages':
827
+ ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
828
+ args.s_train, args.dn_train, args.dm_test, args.s_test,
829
+ args.dn_test, args.N_runs, args.severity, args.transform,
830
+ download_model=args.download_model)
831
+ ABclass.CShowAllImages()
figures/figure1.sh ADDED
@@ -0,0 +1,7 @@
1
+ python figures.py \
2
+ --experiment_name track-test \
3
+ --run_name track-all \
4
+ --representation gradients \
5
+ --step gamma_correct \
6
+ --gif_name gradient \
7
+ --output gif \
figures/figure2.sh ADDED
@@ -0,0 +1,4 @@
1
+ python figures.py \
2
+ --experiment_name track-test \
3
+ --run_name track-all \
4
+ --output train_vs_val_loss \
figures/figures.py ADDED
@@ -0,0 +1,92 @@
1
+ import mlflow
2
+ from mlflow.tracking import MlflowClient
3
+ from mlflow.entities import ViewType
4
+ import argparse
5
+ #gif
6
+ import os
7
+ import pathlib
8
+ import shutil
9
+ import imageio
10
+ #plot
11
+ import matplotlib.pyplot as plt
12
+ import numpy as np
13
+
14
+ # -1. parse args
15
+ parser = argparse.ArgumentParser(description="results_analysis")
16
+ parser.add_argument("--tracking_uri", type=str,
17
+ default="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com", help='URI of the mlflow server on AWS')
18
+ parser.add_argument("--experiment_name", type=str, default=None,
19
+ help='Name of the experiment on the mlflow server, e.g. "processing_comparison"')
20
+ parser.add_argument("--run_name", type=str, default=None,
21
+ help='Name of the run on the mlflow server, e.g. "proc_nn"')
22
+ parser.add_argument("--representation", type=str, default=None,
23
+ choices=["processing", "gradients"], help='The representation form you want retrieve("processing" or "gradients")')
24
+ parser.add_argument("--step", type=str, default=None,
25
+ choices=["pre_debayer", "demosaic", "color_correct", "sharpening", "gaussian", "clipped", "gamma_correct", "rgb"],
26
+ help='The processing step you want to track ("pre_debayer" or "rgb")') #TODO: include predictions and ground truths
27
+ parser.add_argument("--gif_name", type=str, default=None,
28
+ help='Name of the gif that will be saved. Note: .gif will be added later by script') #TODO: option to include filepath where result should be written
29
+ #TODO: option to write results to existing run on mlflow
30
+ parser.add_argument("--local_dir", type=str, default=None,
31
+ help='Name of the local dir to be created to store mlflow data')
32
+ parser.add_argument("--cleanup", type=bool, default=True,
33
+ help='Whether to delete the local dir again after the script was run')
34
+ parser.add_argument("--output", type=str, default=None,
35
+ choices=["gif", "train_vs_val_loss"],
36
+ help='Which output to generate') #TODO: make this cleaner, atm it is confusing because each figure may need different set of args and it is not clear how to manage that
37
+ #TODO: idea -> fix the types of args for each figure which define the figure type but parametrize those things that can reasonably vary
38
+ args = parser.parse_args()
39
+
40
+ # 0. mlflow basics
41
+ mlflow.set_tracking_uri(args.tracking_uri)
42
+
43
+ # 1. specify experiment_name, run_name, representation and step
44
+ #is done via parse_args
45
+
46
+ # 2. use get_experiment_by_name to get experiment object
47
+ experiment = mlflow.get_experiment_by_name(args.experiment_name)
48
+
49
+ # 3. extract experiment_id
50
+ #experiment.experiment_id
51
+
52
+ # 4. use search_runs with experiment_id and run_name for string search query
53
+ filter_string = "tags.mlflow.runName = '{}'".format(args.run_name) #create the filter string with using the runName tag to query mlflow
54
+ runs = mlflow.search_runs(experiment.experiment_id, filter_string=filter_string) #returns a pandas data frame where each row is a run (if several exist under that name)
55
+ client = MlflowClient() #TODO: look more into the options of client
56
+
57
+ if args.output == "gif": #TODO: outsource these options to functions which are then loaded and can be called
58
+ # 5. extract run from list
59
+ #TODO: parent run and cv option for analysis
60
+ if args.local_dir:
61
+ local_dir = args.local_dir+"/artifacts"
62
+ else: #use the current working dir and make a subdir "artifacts" to store the data from mlflow
63
+ local_dir = str(pathlib.Path().resolve())+"/artifacts"
64
+ if not os.path.isdir(local_dir):
65
+ os.makedirs(local_dir) #create the local_dir if it does not exist yet #TODO: more advanced catching of existing files etc
66
+ dir = client.download_artifacts(runs["run_id"][0], "results", local_dir) #TODO: parametrize this number [0] so the right run is selected
67
+
68
+ # 6. get filenames in chronological sequence and write them to gif
69
+ dirs = [x[0] for x in os.walk(dir)]
70
+ dirs = sorted(dirs, key=str.lower)[1:] #sort chronologically and remove parent dir from list
71
+
72
+ with imageio.get_writer(args.gif_name+'.gif', mode='I') as writer: #https://imageio.readthedocs.io/en/stable/index.html#
73
+ for epoch in dirs: #extract the right file from each epoch
74
+ for _, _, files in os.walk(epoch): #
75
+ for name in files:
76
+ if args.representation in name and args.step in name and "png" in name:
77
+ image = imageio.imread(epoch+"/"+name)
78
+ writer.append_data(image)
79
+
80
+ # 7. cleanup the downloaded artifacts from client file system
81
+ if args.cleanup:
82
+ shutil.rmtree(local_dir) #delete the files downloaded from mlflow
83
+
84
+ elif args.output == "train_vs_val_loss":
85
+ train_loss = client.get_metric_history(runs["run_id"][0], "train_loss") #returns a list of metric entities https://www.mlflow.org/docs/latest/_modules/mlflow/entities/metric.html
86
+ val_loss = client.get_metric_history(runs["run_id"][0], "val_loss") #TODO: parametrize this number [0] so the right run is selected
87
+ train_loss = sorted(train_loss, key=lambda m: m.step) #sort the metric objects in list according to step property
88
+ val_loss = sorted(val_loss, key=lambda m: m.step)
89
+ plt.figure()
90
+ for m_train, m_val in zip(train_loss, val_loss):
91
+ plt.scatter(m_train.value, m_val.value, alpha=1/(m_train.step+1), color='blue')
92
+ plt.savefig("scatter.png") #TODO: parametrize filename
figures/train.sh ADDED
@@ -0,0 +1,81 @@
1
+ #!/bin/bash
2
+
3
+ # # Parametrized Training
4
+ # 100 epochs, frozen_processor: http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com/#/experiments/49/runs/2803f44514e34a0f87d591520706e876
5
+ # model_uri="s3://mlflow-artifacts-601883093460/49/2803f44514e34a0f87d591520706e876/artifacts/model"
6
+
7
+ # used for training current model to 100% train and 80% val accuracy
8
+ # python train.py \
9
+ # --experiment_name parametrized \
10
+ # --classifier_uri "${model_uri}" \
11
+ # --run_name par_full_kurt \
12
+ # --dataset Microscopy \
13
+ # --lr 1e-5 \
14
+ # --epochs 50 \
15
+ # --freeze_classifier \
16
+
17
+ # --freeze_processor \
18
+
19
+ # # Adversarial Training
20
+
21
+ # python train.py \
22
+ # --experiment_name adversarial \
23
+ # --run_name adv_frozen_processor \
24
+ # --classifier_uri "${model_uri}" \
25
+ # --dataset Microscopy \
26
+ # --adv_training \
27
+ # --lr 1e-3 \
28
+ # --epochs 7 \
29
+ # --freeze_classifier \
30
+ # --track_processing \
31
+ # --track_every_epoch \
32
+ # --log_model=False \
33
+ # --adv_aux_weight=0.1 \
34
+ # --adv_aux_loss "l2" \
35
+
36
+ # --adv_aux_weight=2e-5 \
37
+ # --adv_aux_weight=2e-5 \
38
+ # --adv_aux_weight=1.9e-5 \
39
+
40
+ # Cross pipeline training (Segmentation/Classification)
41
+
42
+ # Static Pipeline Script
43
+
44
+ # datasets="Microscopy Drone DroneSegmentation"
45
+ datasets="DroneSegmentation"
46
+ augmentations="weak strong none"
47
+
48
+ demosaicings="bilinear malvar2004 menon2007"
49
+ sharpenings="sharpening_filter unsharp_masking"
50
+ denoisings="median_denoising gaussian_denoising"
51
+
52
+ for augment in $augmentations
53
+ do
54
+ for data in $datasets
55
+ do
56
+ for demosaicing in $demosaicings
57
+ do
58
+ for sharpening in $sharpenings
59
+ do
60
+ for denoising in $denoisings
61
+ do
62
+
63
+ python train.py \
64
+ --experiment_name ABtesting \
65
+ --run_name "$data"_"$demosaicing"_"$sharpening"_"$denoising"_"$augment" \
66
+ --dataset "$data" \
67
+ --batch_size 4 \
68
+ --lr 1e-5 \
69
+ --epochs 100 \
70
+ --sp_debayer "$demosaicing" \
71
+ --sp_sharpening "$sharpening" \
72
+ --sp_denoising "$denoising" \
73
+ --processing_mode "static" \
74
+ --augmentation "$augment" \
75
+ --n_split 5 \
76
+
77
+ done
78
+ done
79
+ done
80
+ done
81
+ done
model.py ADDED
@@ -0,0 +1,305 @@
1
+ import os
2
+ from collections import defaultdict
3
+
4
+ import torch
5
+ import torch.optim
6
+ from torchvision.models import resnet18
7
+ from torchvision.utils import make_grid, save_image
8
+ import torch.nn.functional as F
9
+
10
+ import pytorch_lightning as pl
11
+
12
+ import mlflow.pytorch
13
+
14
+
15
+ def resnet_model(model=resnet18, pretrained=True, in_channels=3, fc_out_features=2):
16
+ resnet = model(pretrained=pretrained)
17
+ # if not pretrained: # TODO: add case for in_channels=4
18
+ # resnet.conv1 = torch.nn.Conv2d(channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
19
+ resnet.fc = torch.nn.Linear(in_features=512, out_features=fc_out_features, bias=True)
20
+ return resnet
21
+
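+ # Usage sketch (assuming a classification dataset with a .classes attribute):
+ # classifier = resnet_model(pretrained=True, fc_out_features=len(dataset.classes))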
22
+
23
+ class LitModel(pl.LightningModule):
24
+
25
+ def __init__(self,
26
+ classifier,
27
+ loss,
28
+ lr=1e-3,
29
+ weight_decay=0,
30
+ loss_aux=None,
31
+ adv_training=False,
32
+ adv_parameters='all',
33
+ metrics=None,
34
+ processor=None,
35
+ augmentation=None,
36
+ is_segmentation_task=False,
37
+ augmentation_on_eval=False,
38
+ metrics_on_training=True,
39
+ freeze_classifier=False,
40
+ freeze_processor=False,
41
+ ):
42
+ super().__init__()
43
+
44
+ self.classifier = classifier
45
+ self.processor = processor
46
+
47
+ self.lr = lr
48
+ self.weight_decay = weight_decay
49
+ self.loss_fn = loss
50
+ self.loss_aux_fn = loss_aux
51
+ self.adv_training = adv_training
52
+ self.metrics = metrics
53
+ self.augmentation = augmentation
54
+ self.is_segmentation_task = is_segmentation_task
55
+ self.augmentation_on_eval = augmentation_on_eval
56
+ self.metrics_on_training = metrics_on_training
57
+
58
+ self.freeze_classifier = freeze_classifier
59
+ self.freeze_processor = freeze_processor
60
+
61
+ self.unfreeze()
62
+ if freeze_classifier:
63
+ pl.LightningModule.freeze(self.classifier)
64
+ if freeze_processor:
65
+ pl.LightningModule.freeze(self.processor)
66
+
67
+ if adv_training and adv_parameters != 'all':
68
+ pl.LightningModule.freeze(self.processor)
70
+ for name, p in self.processor.named_parameters():
71
+ if adv_parameters in name:
72
+ p.requires_grad = True
73
+
74
+ def forward(self, x):
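+ # Full forward pass: raw sensor input -> processing pipeline -> (optional) augmentation -> classifier logits.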
75
+ x = self.processor(x)
76
+ apply_augmentation_step = self.training or self.augmentation_on_eval
77
+ if self.augmentation is not None and apply_augmentation_step:
78
+ x = self.augmentation(x, retain_state=self.is_segmentation_task)
79
+ x = self.classifier(x)
80
+ return x
81
+
82
+ def update_step(self, batch, step_name):
83
+ x, y = batch
84
+ # debug(self.processor)
85
+ # debug(self.processor.parameters())
86
+ # debug.pause()
87
+ # print('type', type(self.processor).__name__)
88
+
89
+ logits = self(x)
90
+
91
+ apply_augmentation_mask = self.is_segmentation_task and (self.training or self.augmentation_on_eval)
92
+ if self.augmentation is not None and apply_augmentation_mask:
93
+ y = self.augmentation(y, mask_transform=True).contiguous()
94
+
95
+ loss = self.loss_fn(logits, y)
96
+
97
+ if self.loss_aux_fn is not None:
98
+ loss_aux = self.loss_aux_fn(x)
99
+ loss += loss_aux
100
+
101
+ self.log(f'{step_name}_loss', loss, on_step=False, on_epoch=True)
102
+ if self.loss_aux_fn is not None:
103
+ self.log(f'{step_name}_loss_aux', loss_aux, on_step=False, on_epoch=True)
104
+
105
+ if self.is_segmentation_task:
106
+ y_hat = F.logsigmoid(logits).exp().squeeze()
107
+ else:
108
+ y_hat = torch.argmax(logits, dim=1)
109
+
110
+ if self.metrics is not None:
111
+ for metric in self.metrics:
112
+ metric_name = metric.__name__ if hasattr(metric, '__name__') else type(metric).__name__
113
+ if metric_name == 'accuracy' or not self.training or self.metrics_on_training:
114
+ m = metric(y_hat.cpu().detach(), y.cpu())
115
+ self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
116
+ prog_bar=self.training or metric_name == 'accuracy')
117
+ if metric_name == 'iou_score' or not self.training or self.metrics_on_training:
118
+ m = metric(y_hat.cpu().detach(), y.cpu())
119
+ self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
120
+ prog_bar=self.training or metric_name == 'iou_score')
121
+ elif metric_name == 'accuracy' or not self.training or self.metrics_on_training:
122
+ m = metric(y_hat.cpu().detach(), y.cpu())
123
+ self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
124
+ prog_bar=self.training or metric_name == 'accuracy')
125
+
126
+ return loss
127
+
128
+ def training_step(self, batch, batch_idx):
129
+ return self.update_step(batch, 'train')
130
+
131
+ def validation_step(self, batch, batch_idx):
132
+ return self.update_step(batch, 'val')
133
+
134
+ def test_step(self, batch, batch_idx):
135
+ return self.update_step(batch, 'test')
136
+
137
+ def train(self, mode=True):
138
+ self.training = mode
139
+
140
+ # don't update batchnorm in adversarial training
141
+ self.processor.train(mode=mode and not self.freeze_processor and not self.adv_training)
142
+ self.classifier.train(mode=mode and not self.freeze_classifier)
143
+ return self
144
+
145
+ def configure_optimizers(self):
146
+ self.optimizer = torch.optim.Adam(self.parameters(), self.lr, weight_decay=self.weight_decay)
147
+ return self.optimizer
148
+
149
+ def get_progress_bar_dict(self):
150
+ items = super().get_progress_bar_dict()
151
+ items.pop('v_num')
152
+ return items
153
+
154
+
155
+ class TrackImagesCallback(pl.callbacks.base.Callback):
156
+ def __init__(self, data_loader, reference_processor=None, track_every_epoch=False, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True):
157
+ super().__init__()
158
+ self.data_loader = data_loader
159
+
160
+ self.track_every_epoch = track_every_epoch
161
+
162
+ self.track_processing = track_processing
163
+ self.track_gradients = track_gradients
164
+ self.track_predictions = track_predictions
165
+ self.save_tensors = save_tensors
166
+
167
+ self.reference_processor = reference_processor
168
+
169
+ def callback_track_images(self, model, save_loc):
170
+ track_images(model,
171
+ self.data_loader,
172
+ reference_processor=self.reference_processor,
173
+ track_processing=self.track_processing,
174
+ track_gradients=self.track_gradients,
175
+ track_predictions=self.track_predictions,
176
+ save_tensors=self.save_tensors,
177
+ save_loc=save_loc,
178
+ )
179
+
180
+ def on_fit_end(self, trainer, pl_module):
181
+ if not self.track_every_epoch:
182
+ save_loc = 'results'
183
+ self.callback_track_images(trainer.model, save_loc)
184
+
185
+ def on_train_epoch_end(self, trainer, pl_module, outputs):
186
+ if self.track_every_epoch:
187
+ save_loc = f'results/epoch_{trainer.current_epoch + 1:04d}'
188
+ self.callback_track_images(trainer.model, save_loc)
189
+
190
+
191
+ from utils.debug import debug
192
+
193
+
194
+ # @debug
195
+ def log_tensor(batch, path, save_tensors=True, nrow=8):
196
+ if save_tensors:
197
+ torch.save(batch, path)
198
+ mlflow.log_artifact(path, os.path.dirname(path))
199
+
200
+ img_path = path.replace('.pt', '.png')
201
+ split = img_path.split('/')
202
+ img_path = '/'.join(split[:-1]) + '/img_' + split[-1] # insert 'img_'; make it easier to find in mlflow
203
+
204
+ grid = make_grid(batch, nrow=nrow).squeeze()
205
+ save_image(grid, img_path)
206
+ mlflow.log_artifact(img_path, os.path.dirname(path))
207
+
208
+
209
+ def track_images(model, data_loader, reference_processor=None, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True, save_loc='results'):
210
+
211
+ device = model.device
212
+ processor = model.processor
213
+ classifier = model.classifier
214
+
215
+ if not hasattr(processor, 'stages'): # 'static' or 'none' pipeline
216
+ return
217
+
218
+ os.makedirs(save_loc, exist_ok=True)
219
+
220
+ # TODO: implement track_predictions
221
+
222
+ # inputs_full = []
223
+ labels_full = []
224
+ logits_full = []
225
+ stages_full = defaultdict(list)
226
+ grads_full = defaultdict(list)
227
+ diffs_full = defaultdict(list)
228
+
229
+ track_differences = reference_processor is not None
230
+
231
+ for inputs, labels in data_loader:
232
+
233
+ inputs, labels = inputs.to(device), labels.to(device)
234
+ inputs.requires_grad = True
235
+
236
+ processed_rgb = processor(inputs)
237
+
238
+ if track_differences:
239
+ # debug(processor)
240
+ processed_rgb_ref = reference_processor(inputs)
241
+
242
+ if track_gradients or track_predictions:
243
+ logits = classifier(processed_rgb)
244
+
245
+ # NOTE: should zero grads for good measure
246
+ loss = model.loss_fn(logits, labels)
247
+ loss.backward()
248
+
249
+ if track_predictions:
250
+ labels_full.append(labels.cpu().detach())
251
+ logits_full.append(logits.cpu().detach())
252
+ # inputs_full.append(inputs.cpu().detach())
253
+
254
+ for stage, batch in processor.stages.items():
255
+ stages_full[stage].append(batch.cpu().detach())
256
+ if track_differences:
257
+ diffs_full[stage].append((reference_processor.stages[stage] - batch).cpu().detach())
258
+ if track_gradients:
259
+ grads_full[stage].append(batch.grad.cpu().detach())
260
+
261
+ with torch.no_grad():
262
+
263
+ stages = stages_full
264
+ grads = grads_full
265
+ diffs = diffs_full
266
+
267
+ if track_processing:
268
+ for stage, batch in stages.items():
269
+ stages[stage] = torch.cat(batch)
270
+
271
+ if track_differences:
272
+ for stage, batch in diffs.items():
273
+ diffs[stage] = torch.cat(batch)
274
+
275
+ if track_gradients:
276
+ for stage, batch in grads.items():
277
+ grads[stage] = torch.cat(batch)
278
+
279
+ for stage_nr, stage_name in enumerate(stages):
280
+ if track_processing:
281
+ batch = stages[stage_name]
282
+ log_tensor(batch, os.path.join(save_loc, f'processing_{stage_nr}_{stage_name}.pt'), save_tensors)
283
+
284
+ if track_differences:
285
+ batch = diffs[stage_name]
286
+ log_tensor(batch, os.path.join(save_loc, f'diffs_{stage_nr}_{stage_name}.pt'), False)
287
+
288
+ if track_gradients:
289
+ batch_grad = grads[stage_name]
290
+ batch_grad = batch_grad.abs()
291
+ batch_grad = (batch_grad - batch_grad.min()) / (batch_grad.max() - batch_grad.min())
292
+ log_tensor(batch_grad, os.path.join(
293
+ save_loc, f'gradients_{stage_nr}_{stage_name}.pt'), save_tensors)
294
+
295
+ # inputs = torch.cat(inputs_full)
296
+
297
+ if track_predictions: # and model.is_segmentation_task:
298
+ labels = torch.cat(labels_full)
299
+ logits = torch.cat(logits_full)
300
+ masks = labels.unsqueeze(1)
301
+ predictions = logits # torch.sigmoid(logits).unsqueeze(1)
302
+ #mask_vis = torch.cat((masks, predictions, masks * predictions), dim=1)
303
+ #log_tensor(mask_vis, os.path.join(save_loc, f'masks.pt'), save_tensors)
304
+ log_tensor(masks, os.path.join(save_loc, f'targets.pt'), save_tensors)
305
+ log_tensor(predictions, os.path.join(save_loc, f'preds.pt'), save_tensors)
processing/pipeline_numpy.py ADDED
@@ -0,0 +1,329 @@
1
+ """
2
+ Raw Image Pipeline
3
+ """
4
+ __author__ = "Marco Aversa"
5
+
6
+ import numpy as np
7
+
8
+ from rawpy import * # XXX: no * imports!
9
+ from scipy import ndimage
10
+ from scipy import fftpack
11
+ from scipy.signal import convolve2d
12
+
13
+ from skimage.filters import unsharp_mask
14
+ from skimage.color import rgb2yuv, yuv2rgb, rgb2hsv, hsv2rgb
15
+ from skimage.restoration import denoise_tv_chambolle, denoise_tv_bregman, denoise_nl_means, denoise_bilateral, denoise_wavelet, estimate_sigma
16
+
17
+ import matplotlib.pyplot as plt
18
+
19
+ from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
20
+ demosaicing_CFA_Bayer_Malvar2004,
21
+ demosaicing_CFA_Bayer_Menon2007)
22
+
23
+ import torch
24
+ import numpy as np
25
+
26
+ from dataset import Subset
27
+ from torch.utils.data import DataLoader
28
+
29
+ from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
30
+ demosaicing_CFA_Bayer_Malvar2004,
31
+ demosaicing_CFA_Bayer_Menon2007)
32
+
33
+ import matplotlib.pyplot as plt
34
+
35
+
36
+ class RawProcessingPipeline(object):
37
+
38
+ """Applies the raw-processing pipeline from pipeline.py"""
39
+
40
+ def __init__(self, camera_parameters, debayer='bilinear', sharpening='unsharp_masking', denoising='gaussian_denoising'):
41
+ '''
42
+ Args:
43
+ camera_parameters (tuple): (black_level, white_balance, colour_matrix)
44
+ debayer (str): specifies the algorithm used as debayer; choose from {'bilinear','malvar2004','menon2007'}
45
+ sharpening (str): specifies the algorithm used for sharpening; choose from {'sharpening_filter','unsharp_masking'}
46
+ denoising (str): specifies the algorithm used for denoising; choose from {'gaussian_denoising','median_denoising','fft_denoising'}
47
+ '''
48
+
49
+ self.camera_parameters = camera_parameters
50
+
51
+ self.debayer = debayer
52
+ self.sharpening = sharpening
53
+ self.denoising = denoising
54
+
55
+ def __call__(self, img):
56
+ """
57
+ Args:
58
+ img (ndarray of dtype float32): image of size (H,W)
59
+ return:
60
+ img (tensor of dtype float): image of size (3,H,W)
61
+ """
62
+ black_level, white_balance, colour_matrix = self.camera_parameters
63
+ img = processing(img, black_level, white_balance, colour_matrix,
64
+ debayer=self.debayer, sharpening=self.sharpening, denoising=self.denoising)
65
+ img = img.transpose(2, 0, 1)
66
+
67
+ return torch.Tensor(img)
68
+
69
+
70
+ def processing(img, black_level, white_balance, colour_matrix, debayer="bilinear", sharpening="unsharp_masking",
71
+ sharp_radius=1.0, sharp_amount=1.0, denoising="median_denoising", median_kernel_size=3,
72
+ gaussian_sigma=0.5, fft_fraction=0.3, weight_chambolle=0.01, weight_bregman=100,
73
+ sigma_bilateral=0.6, gamma=2.2, bits=16):
74
+ """Apply pipeline on a raw image
75
+
76
+ Args:
77
+ img (ndarray): raw image
78
+ debayer (str): debayer algorithm
79
+ white_balance (None, ndarray): white balance array (if None it will take the default camera white balance array)
80
+ colour_matrix (None, ndarray): colour matrix (if None it will take the default camera colour matrix) - Size: 3x3
81
+ gamma (float): exponent for the non-linear gamma correction.
82
+
83
+ Returns:
84
+ img (ndarray): post-processed image
85
+
86
+ """
87
+
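+ # Processing order: black-level removal -> demosaicing -> white balance -> colour correction -> sharpening -> denoising -> clipping and gamma correction.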
88
+ # Remove Black Level
89
+ img = remove_blacklv(img, black_level)
90
+
91
+ # Apply demosaicing - We don't have access to these 3 functions
92
+ if debayer == "bilinear":
93
+ img = demosaicing_CFA_Bayer_bilinear(img)
94
+ if debayer == "malvar2004":
95
+ img = demosaicing_CFA_Bayer_Malvar2004(img)
96
+ if debayer == "menon2007":
97
+ img = demosaicing_CFA_Bayer_Menon2007(img)
98
+
99
+ # White Balance Correction
100
+
101
+ # Sunny images white balance array -> 2<r<2.8, g=1.0, 1.3<b<1.6
102
+ # Tungsten images white balance array -> 1.3<r<1.7, g=1.0, 2.2<b<2.8
103
+ # Shade images white balance array -> 2.4<r<3.2, g=1.0, 1.1<b<1.3
104
+
105
+ img = wb_correction(img, white_balance)
106
+
107
+ # Colour Correction
108
+ img = colour_correction(img, colour_matrix)
109
+
110
+ # Sharpening
111
+ if sharpening == "sharpening_filter": # Fixed sharpening
112
+ img = sharpening_filter(img)
113
+ if sharpening == "unsharp_masking": # Higher is radius and amount, higher is the sharpening
114
+ img = unsharp_masking(img, radius=sharp_radius, amount=sharp_amount, multichannel=True)
115
+
116
+ # Denoising
117
+ if denoising == "median_denoising":
118
+ img = median_denoising(img, size=median_kernel_size)
119
+ if denoising == "gaussian_denoising":
120
+ img = gaussian_denoising(img, sigma=gaussian_sigma)
121
+ if denoising == "fft_denoising": # fft_fraction = [0.0001,0.5]
122
+ img = fft_denoising(img, keep_fraction=fft_fraction, row_cut=False, column_cut=True)
123
+
124
+ # We don't have access to these 3 functions
125
+ if denoising == "tv_chambolle": # lower is weight, less is the denoising
126
+ img = denoise_tv_chambolle(img, weight=weight_chambolle, eps=0.0002, n_iter_max=200, multichannel=True)
127
+ if denoising == "tv_bregman": # lower is weight, more is the denoising
128
+ img = denoise_tv_bregman(img, weight=weight_bregman, max_iter=100,
129
+ eps=0.001, isotropic=True, multichannel=True)
130
+ # if denoising == "wavelet":
131
+ # img = denoise_wavelet(img.copy(), sigma=None, wavelet='db1', mode='soft', wavelet_levels=None, multichannel=True,
132
+ # convert2ycbcr=False, method='BayesShrink', rescale_sigma=True)
133
+ if denoising == "bilateral": # higher is sigma_spatial, more is the denoising
134
+ img = denoise_bilateral(img, win_size=None, sigma_color=None, sigma_spatial=sigma_bilateral,
135
+ bins=10000, mode='constant', cval=0, multichannel=True)
136
+
137
+ # Gamma Correction
138
+ img = np.clip(img, 0, 1)
139
+ img = adjust_gamma(img, gamma=gamma)
140
+
141
+ return img
142
+
143
+
144
+ def get_camera_parameters(rawpyImg):
145
+ black_level = rawpyImg.black_level_per_channel
146
+ white_balance = rawpyImg.camera_whitebalance[:3]
147
+ colour_matrix = rawpyImg.color_matrix[:, :3].flatten().tolist()
148
+
149
+ return black_level, white_balance, colour_matrix
150
+
151
+
152
+ def remove_blacklv(rawImg, black_level):
153
+ rawImg[0::2, 0::2] -= black_level[0] # R
154
+ rawImg[0::2, 1::2] -= black_level[1] # G
155
+ rawImg[1::2, 0::2] -= black_level[2] # G
156
+ rawImg[1::2, 1::2] -= black_level[3] # B
157
+
158
+ return rawImg
159
+
160
+
161
+ def wb_correction(img, white_balance):
162
+ return img * white_balance
163
+
164
+
165
+ def colour_correction(img, colour_matrix):
166
+ colour_matrix = np.array(colour_matrix).reshape(3, 3)
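+ # Apply the 3x3 colour matrix to every pixel: out[i, j, l] = sum_k img[i, j, k] * colour_matrix[l, k]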
167
+ return np.einsum('ijk,lk->ijl', img, colour_matrix)
168
+
169
+
170
+ def unsharp_masking(img, radius=1.0, amount=1.0,
171
+ multichannel=False, preserve_range=True):
172
+
173
+ img = rgb2yuv(img)
174
+ img[:, :, 0] = unsharp_mask(img[:, :, 0], radius=radius, amount=amount,
175
+ multichannel=multichannel, preserve_range=preserve_range)
176
+ img = yuv2rgb(img)
177
+ return img
178
+
179
+
180
+ def sharpening_filter(image, iterations=1, kernel=np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])):
181
+
182
+ # https://towardsdatascience.com/image-processing-with-python-blurring-and-sharpening-for-beginners-3bcebec0583a
183
+
184
+ img_yuv = rgb2yuv(image)
185
+
186
+ for i in range(iterations):
187
+ img_yuv[:, :, 0] = convolve2d(img_yuv[:, :, 0], kernel, 'same', boundary='fill', fillvalue=0)
188
+
189
+ final_image = yuv2rgb(img_yuv)
190
+
191
+ return final_image
192
+
193
+
194
+ def median_denoising(img, size=3):
195
+
196
+ img = rgb2yuv(img)
197
+ img[:, :, 0] = ndimage.median_filter(img[:, :, 0], size)
198
+ img = yuv2rgb(img)
199
+
200
+ return img
201
+
202
+
203
+ def gaussian_denoising(img, sigma=0.5):
204
+
205
+ img = rgb2yuv(img)
206
+ img[:, :, 0] = ndimage.gaussian_filter(img[:, :, 0], sigma)
207
+ img = yuv2rgb(img)
208
+
209
+ return img
210
+
211
+
212
+ def fft_denoising(img, keep_fraction=0.3, row_cut=False, column_cut=True):
213
+ """ keep_fraction = 0.5 --> same image as input
214
+ keep_fraction --> 0 --> remove all details """
215
+ # http://scipy-lectures.org/intro/scipy/auto_examples/solutions/plot_fft_image_denoise.html
216
+
217
+ im_fft = fftpack.fft2(img)
218
+
219
+ # Keep a copy of the original transform. Numpy arrays have a copy
220
+ # method for this purpose.
221
+ im_fft2 = im_fft.copy()
222
+
223
+ # Set r and c to be the number of rows and columns of the array.
224
+ r, c, _ = im_fft2.shape
225
+
226
+ # Set to zero all rows with indices between r*keep_fraction and r*(1-keep_fraction):
227
+ if row_cut == True:
228
+ im_fft2[int(r * keep_fraction):int(r * (1 - keep_fraction))] = 0
229
+
230
+ # Similarly with the columns:
231
+ if column_cut == True:
232
+ im_fft2[:, int(c * keep_fraction):int(c * (1 - keep_fraction))] = 0
233
+
234
+ # Reconstruct the denoised image from the filtered spectrum, keep only the
235
+ # real part for display.
236
+ im_new = fftpack.ifft2(im_fft2).real
237
+
238
+ return im_new
239
+
240
+
241
+ def adjust_gamma(img, gamma=1.0):
242
+ invGamma = 1.0 / gamma
243
+ img = (img ** invGamma)
244
+ return img
245
+
246
+
247
+ def show_img(img, title="no_title", size=12, histo=True, bins=300, bits=16, x_range=-1):
248
+ """Plot image and its histogram
249
+
250
+ Args:
251
+ img (ndarray): image to plot
252
+ title (str): title of the plot
253
+ histo (bool): True - plot histograms per channel of the image. False - plot the histogram as a continuous curve
254
+ bins (int): number of bins of the histogram
255
+ size (int): figure size
256
+ bits (int): number of bits per pixel in the ndarray
257
+ x_range (list): maximum x range of the histogram (if -1, all x values are used)
258
+ """
259
+ shape = img.shape
260
+
261
+ fig = plt.figure(figsize=(size, size))
262
+
263
+ # show original image
264
+ fig.add_subplot(221)
265
+ if len(shape) > 2 and img.max() > 255:
266
+ img_to_show = (img.copy() * 255. / (2**bits - 1)).astype(int)
267
+ else:
268
+ img_to_show = img.copy().astype(int)
269
+ plt.imshow(img_to_show)
270
+ if title != "no_title":
271
+ plt.title(title)
272
+
273
+ fig.add_subplot(222)
274
+
275
+ if len(shape) > 2:
276
+ if histo == True:
277
+ plt.hist(img[:, :, 0].flatten(), bins=bins, label="Channel1", color="red", alpha=0.5)
278
+ plt.hist(img[:, :, 1].flatten(), bins=bins, label="Channel2", color="green", alpha=0.5)
279
+ plt.hist(img[:, :, 2].flatten(), bins=bins, label="Channel3", color="blue", alpha=0.5)
280
+ if x_range != -1:
281
+ plt.xlim([x_range[0], x_range[1]])
282
+ else:
283
+ h1, b1 = np.histogram(img[:, :, 0].flatten(), bins=bins)
284
+ h2, b2 = np.histogram(img[:, :, 1].flatten(), bins=bins)
285
+ h3, b3 = np.histogram(img[:, :, 2].flatten(), bins=bins)
286
+ plt.plot(b1[:-1], h1, label="Channel1", color="red", alpha=0.5)
287
+ plt.plot(b2[:-1], h2, label="Channel2", color="green", alpha=0.5)
288
+ plt.plot(b3[:-1], h3, label="Channel3", color="blue", alpha=0.5)
289
+
290
+ plt.legend()
291
+ else:
292
+ if histo == True:
293
+ plt.hist(img.flatten(), bins=bins)
294
+ if x_range != -1:
295
+ plt.xlim([x_range[0], x_range[1]])
296
+ else:
297
+ h, b = np.histogram(img.flatten(), bins=bins)
298
+ plt.plot(b[:-1], h)
299
+
300
+ plt.xlabel("Intensities")
301
+ plt.ylabel("Counts")
302
+
303
+ plt.show()
304
+
305
+
306
+ def get_statistics(dataset, train_indices, transform=None):
307
+ """Calculates the mean and the standard deviation of a given sub train set of dataset
308
+
309
+ Args:
310
+ dataset (Subset of DroneDataset):
311
+ train_indices (tensor): indices corresponding to a subset of the dataset
312
+ transform (Compose): list of transformations compatible with Compose to be applied before calculations
313
+ return:
314
+ mean (tensor of dtype float): size (C,1,1)
315
+ std (tensor of dtype float): size (C,1,1)
316
+ """
317
+
318
+ trainset = Subset(dataset, indices=train_indices, transform=transform)
319
+ dataloader = DataLoader(trainset, batch_size=len(trainset), shuffle=False)
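+ # The whole training subset is loaded as a single batch, so the statistics are exact (at the cost of memory).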
320
+ dataiter = iter(dataloader)
321
+
322
+ images, labels = next(dataiter)
323
+
324
+ if len(images.shape) == 3:
325
+ mean, std = torch.mean(images, axis=(0, 1, 2)), torch.std(images, axis=(0, 1, 2))
326
+ return mean, std
327
+ else:
328
+ mean, std = torch.mean(images, axis=(0, 2, 3))[:, None, None], torch.std(images, axis=(0, 2, 3))[:, None, None]
329
+ return mean, std
processing/pipeline_torch.py ADDED
@@ -0,0 +1,314 @@
1
+ import os
2
+ from numpy.lib.function_base import interp
3
+ import torch
4
+ import torch.nn as nn
5
+ if not os.path.exists('README.md'):
6
+ os.chdir('..')
7
+
8
+ from processing.pipeline_numpy import processing as default_processing
9
+ from utils.base import np2torch, torch2np
10
+
11
+ import segmentation_models_pytorch as smp
12
+
13
+ from utils.debug import debug
14
+
15
+ K_G = torch.Tensor([[0, 1, 0],
16
+ [1, 4, 1],
17
+ [0, 1, 0]]) / 4
18
+
19
+ K_RB = torch.Tensor([[1, 2, 1],
20
+ [2, 4, 2],
21
+ [1, 2, 1]]) / 4
22
+
23
+ M_RGB_2_YUV = torch.Tensor([[0.299, 0.587, 0.114],
24
+ [-0.14714119, -0.28886916, 0.43601035],
25
+ [0.61497538, -0.51496512, -0.10001026]])
26
+ M_YUV_2_RGB = torch.Tensor([[1.0000000000e+00, -4.1827794561e-09, 1.1398830414e+00],
27
+ [1.0000000000e+00, -3.9464232326e-01, -5.8062183857e-01],
28
+ [1.0000000000e+00, 2.0320618153e+00, -1.2232658220e-09]])
29
+
30
+ K_BLUR = torch.Tensor([[6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08],
31
+ [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
32
+ [2.0755e-04, 8.3731e-02, 6.1869e-01, 8.3731e-02, 2.0755e-04],
33
+ [2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
34
+ [6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08]])
35
+ K_SHARP = torch.Tensor([[0, -1, 0],
36
+ [-1, 5, -1],
37
+ [0, -1, 0]])
38
+ DEFAULT_CAMERA_PARAMS = (
39
+ [0., 0., 0., 0.],
40
+ [1., 1., 1.],
41
+ [1., 0., 0., 0., 1., 0., 0., 0., 1.],
42
+ )
43
+
44
+
45
+ class RawToRGB(nn.Module):
46
+ def __init__(self, reduce_size=True, out_channels=3, track_stages=False, normalize_mosaic=None):
47
+ super().__init__()
48
+ self.stages = None
49
+ self.buffer = None
50
+ self.reduce_size = reduce_size
51
+ self.out_channels = out_channels
52
+ self.track_stages = track_stages
53
+ self.normalize_mosaic = normalize_mosaic
54
+
55
+ def forward(self, raw):
56
+ self.stages = {}
57
+ self.buffer = {}
58
+
59
+ rgb = raw2rgb(raw, reduce_size=self.reduce_size, out_channels=self.out_channels)
60
+ self.stages['demosaic'] = rgb
61
+ if self.normalize_mosaic:
62
+ rgb = self.normalize_mosaic(rgb)
63
+
64
+ if self.track_stages and raw.requires_grad:
65
+ for stage in self.stages.values():
66
+ stage.retain_grad()
67
+
68
+ self.buffer['processed_rgb'] = rgb
69
+
70
+ return rgb
71
+
72
+
73
+ class NNProcessing(nn.Module):
74
+ def __init__(self, track_stages=False, normalize_mosaic=None, batch_norm_output=True):
75
+ super().__init__()
76
+ self.stages = None
77
+ self.buffer = None
78
+ self.track_stages = track_stages
79
+ self.model = smp.UnetPlusPlus(
80
+ encoder_name='resnet34',
81
+ encoder_depth=3,
82
+ decoder_channels=[256, 128, 64],
83
+ in_channels=3,
84
+ classes=3,
85
+ )
86
+ self.batch_norm = None if not batch_norm_output else nn.BatchNorm2d(3, affine=False)
87
+ self.normalize_mosaic = normalize_mosaic
88
+
89
+ def forward(self, raw):
90
+ self.stages = {}
91
+ self.buffer = {}
92
+ # self.stages['raw'] = raw
93
+ rgb = raw2rgb(raw)
94
+ if self.normalize_mosaic:
95
+ rgb = self.normalize_mosaic(rgb)
96
+ self.stages['demosaic'] = rgb
97
+ rgb = self.model(rgb)
98
+ if self.batch_norm is not None:
99
+ rgb = self.batch_norm(rgb)
100
+ self.stages['rgb'] = rgb
101
+
102
+ if self.track_stages and raw.requires_grad:
103
+ for stage in self.stages.values():
104
+ stage.retain_grad()
105
+
106
+ self.buffer['processed_rgb'] = rgb
107
+
108
+ return rgb
109
+
110
+
111
+ def add_additive_layer(processor):
112
+ processor.additive_layer = nn.Parameter(torch.zeros((1, 3, 256, 256)))
113
+ # processor.additive_layer = nn.Parameter(0.001 * torch.randn((1, 3, 256, 256)))
114
+
115
+
116
+ class ParametrizedProcessing(nn.Module):
117
+ def __init__(self, camera_parameters, track_stages=False, batch_norm_output=True):
118
+ super().__init__()
119
+ self.stages = None
120
+ self.buffer = None
121
+ self.track_stages = track_stages
122
+
123
+ black_level, white_balance, colour_matrix = camera_parameters
124
+
125
+ self.black_level = nn.Parameter(torch.as_tensor(black_level))
126
+ self.white_balance = nn.Parameter(torch.as_tensor(white_balance).reshape(1, 3))
127
+ self.colour_correction = nn.Parameter(torch.as_tensor(colour_matrix).reshape(3, 3))
128
+
129
+ self.gamma_correct = nn.Parameter(torch.Tensor([2.2]))
130
+
131
+ self.debayer = Debayer()
132
+
133
+ self.sharpening_filter = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
134
+ self.sharpening_filter.weight.data[0][0] = K_SHARP.clone()
135
+
136
+ self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False)
137
+ self.gaussian_blur.weight.data[0][0] = K_BLUR.clone()
138
+
139
+ self.batch_norm = nn.BatchNorm2d(3, affine=False) if batch_norm_output else None
140
+
141
+ self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone())
142
+ self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone())
143
+
144
+ self.additive_layer = None # this can be added in later
145
+
146
+ def forward(self, raw):
147
+ assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}"
148
+
149
+ self.stages = {}
150
+ self.buffer = {}
151
+
152
+ # self.stages['raw'] = raw
153
+
154
+ rgb = raw2rgb(raw, black_level=self.black_level, reduce_size=False)
155
+ rgb = rgb.contiguous()
156
+ self.stages['demosaic'] = rgb
157
+
158
+ rgb = self.debayer(rgb)
159
+ # self.stages['debayer'] = rgb
160
+
161
+ rgb = torch.einsum('bchw,kc->bchw', rgb, self.white_balance).contiguous()
162
+ rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous()
163
+ self.stages['color_correct'] = rgb
164
+
165
+ yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()
166
+ yuv[:, [0], ...] = self.sharpening_filter(yuv[:, [0], ...])
167
+
168
+ if self.track_stages: # keep stage in computational graph for grad information
169
+ rgb = torch.einsum('bchw,kc->bkhw', yuv.clone(), self.M_YUV_2_RGB).contiguous()
170
+ self.stages['sharpening'] = rgb
171
+ yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()
172
+
173
+ yuv[:, [0], ...] = self.gaussian_blur(yuv[:, [0], ...])
174
+ rgb = torch.einsum('bchw,kc->bkhw', yuv, self.M_YUV_2_RGB).contiguous()
175
+ self.stages['gaussian'] = rgb
176
+
177
+ rgb = torch.clip(rgb, 1e-5, 1)
178
+ self.stages['clipped'] = rgb
179
+
180
+ rgb = torch.exp((1 / self.gamma_correct) * torch.log(rgb))
181
+ self.stages['gamma_correct'] = rgb
182
+
183
+ if self.additive_layer is not None:
184
+ rgb = rgb + self.additive_layer
185
+ self.stages['noise'] = rgb
186
+
187
+ if self.batch_norm is not None:
188
+ rgb = self.batch_norm(rgb)
189
+
190
+ if self.track_stages and raw.requires_grad:
191
+ for stage in self.stages.values():
192
+ stage.retain_grad()
193
+
194
+ self.buffer['processed_rgb'] = rgb
195
+
196
+ return rgb
197
+
198
+
199
+ class Debayer(nn.Conv2d):
200
+ def __init__(self):
201
+ super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False) # default_pipeline uses 'replicate'
202
+ self.weight.data.fill_(0)
203
+ self.weight.data[0, 0] = K_RB.clone()
204
+ self.weight.data[1, 1] = K_G.clone()
205
+ self.weight.data[2, 2] = K_RB.clone()
206
+
207
+
208
+ def raw2rgb(raw, black_level=None, reduce_size=True, out_channels=3):
209
+ """transform raw image with 1 channel to rgb with 3 channels
210
+ Args:
211
+ raw (Tensor): raw Tensor of shape (B, H, W)
212
+ black_level (iterable, optional): RGGB black level values to subtract
213
+ reduce_size (bool, optional): if False, the output image will have the same height and width
214
+ as the raw input, i.e. (B, C, H, W), empty values are filled with zeros.
215
+ if True, the output dimensions are reduced by half (B, C, H//2, W//2),
216
+ the two green channels are averaged.
217
+ out_channels (int, optional): number of output channels. One of {3, 4}.
218
+ """
219
+ assert out_channels in [3, 4]
220
+ if black_level is None:
221
+ black_level = [0, 0, 0, 0]
222
+ Bch, H, W = raw.shape
223
+ R = raw[:, 0::2, 0::2] - black_level[0] # R
224
+ G1 = raw[:, 0::2, 1::2] - black_level[1] # G
225
+ G2 = raw[:, 1::2, 0::2] - black_level[2] # G
226
+ B = raw[:, 1::2, 1::2] - black_level[3] # B
227
+ if reduce_size:
228
+ rgb = torch.zeros((Bch, out_channels, H // 2, W // 2), device=raw.device)
229
+ if out_channels == 3:
230
+ rgb[:, 0, :, :] = R
231
+ rgb[:, 1, :, :] = (G1 + G2) / 2
232
+ rgb[:, 2, :, :] = B
233
+ elif out_channels == 4:
234
+ rgb[:, 0, :, :] = R
235
+ rgb[:, 1, :, :] = G1
236
+ rgb[:, 2, :, :] = G2
237
+ rgb[:, 3, :, :] = B
238
+ else:
239
+ rgb = torch.zeros((Bch, out_channels, H, W), device=raw.device)
240
+ if out_channels == 3:
241
+ rgb[:, 0, 0::2, 0::2] = R
242
+ rgb[:, 1, 0::2, 1::2] = G1
243
+ rgb[:, 1, 1::2, 0::2] = G2
244
+ rgb[:, 2, 1::2, 1::2] = B
245
+ elif out_channels == 4:
246
+ rgb[:, 0, 0::2, 0::2] = R
247
+ rgb[:, 1, 0::2, 1::2] = G1
248
+ rgb[:, 2, 1::2, 0::2] = G2
249
+ rgb[:, 3, 1::2, 1::2] = B
250
+ return rgb
251
+
252
+
253
+ # pipeline validation
254
+ if __name__ == "__main__":
255
+
256
+ import torch
257
+ import numpy as np
258
+
259
+ if not os.path.exists('README.md'):
260
+ os.chdir('..')
261
+
262
+ import matplotlib.pyplot as plt
263
+ from dataset import get_dataset
264
+ from utils.base import np2torch, torch2np
265
+
266
+ from utils.debug import debug
267
+ from processing.pipeline_numpy import processing as default_processing
268
+
269
+ raw_dataset = get_dataset('DS')
270
+ loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1)
271
+ batch_raw, batch_mask = next(iter(loader))
272
+
273
+ # torch proc
274
+ camera_parameters = raw_dataset.camera_parameters
275
+ black_level = camera_parameters[0]
276
+
277
+ proc = ParametrizedProcessing(camera_parameters)
278
+
279
+ batch_rgb = proc(batch_raw)
280
+ rgb = batch_rgb[0]
281
+
282
+ # numpy proc
283
+ raw_img = batch_raw[0]
284
+ numpy_raw = torch2np(raw_img)
285
+
286
+ default_rgb = default_processing(numpy_raw, *camera_parameters,
287
+ sharpening='sharpening_filter', denoising='gaussian_denoising')
288
+
289
+ rgb_valid = np2torch(default_rgb)
290
+
291
+ print("pipeline norm difference:", (rgb - rgb_valid).norm().item())
292
+
293
+ rgb_mosaic = raw2rgb(batch_raw, reduce_size=False).squeeze()
294
+ rgb_reduced = raw2rgb(batch_raw, reduce_size=True).squeeze()
295
+
296
+ plt.figure(figsize=(16, 8))
297
+ plt.subplot(151)
298
+ plt.title('Raw')
299
+ plt.imshow(torch2np(raw_img))
300
+ plt.subplot(152)
301
+ plt.title('RGB Mosaic')
302
+ plt.imshow(torch2np(rgb_mosaic))
303
+ plt.subplot(153)
304
+ plt.title('RGB Reduced')
305
+ plt.imshow(torch2np(rgb_reduced))
306
+ plt.subplot(154)
307
+ plt.title('Torch Pipeline')
308
+ plt.imshow(torch2np(rgb))
309
+ plt.subplot(155)
310
+ plt.title('Default Pipeline')
311
+ plt.imshow(torch2np(rgb_valid))
312
+ plt.show()
313
+
314
+ # assert rgb.allclose(rgb_valid)
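
As a quick illustration of the RGGB packing that `raw2rgb` in the file above describes in its docstring, a tiny synthetic mosaic can be pushed through both output modes (the pixel values are made up; only the shapes matter):

import torch
from processing.pipeline_torch import raw2rgb

# one 4x4 raw frame with a repeated 2x2 RGGB tile (R=0.9, G=0.3, B=0.1)
raw = torch.tensor([[[0.9, 0.3, 0.9, 0.3],
                     [0.3, 0.1, 0.3, 0.1],
                     [0.9, 0.3, 0.9, 0.3],
                     [0.3, 0.1, 0.3, 0.1]]])

reduced = raw2rgb(raw, reduce_size=True)     # (1, 3, 2, 2), the two greens are averaged
mosaic = raw2rgb(raw, reduce_size=False)     # (1, 3, 4, 4), missing sites stay zero
print(reduced.shape, mosaic.shape)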
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ scikit_image
2
+ pandas
3
+ scipy
4
+ numpy
5
+ matplotlib
6
+ b2sdk
7
+ colour_demosaicing
8
+ gradio
9
+ ipython
10
+ mlflow
11
+ Pillow
12
+ pytorch_lightning
+ pytorch_toolbelt
13
+ rawpy
14
+ scikit_learn
15
+ segmentation_models_pytorch
16
+ tifffile
17
+ torch==1.9.0
18
+ torchvision==0.10.0
train.py ADDED
@@ -0,0 +1,426 @@
1
+ import os
2
+ import sys
3
+ import copy
4
+ import argparse
5
+
6
+ import torch
7
+ from torch import optim
8
+ import torch.nn as nn
9
+
10
+ import mlflow.pytorch
11
+ from torch.utils.data import DataLoader
12
+ from torchvision.models import resnet18
13
+ import torchvision.transforms as T
14
+ from pytorch_lightning.metrics.functional import accuracy
15
+ import pytorch_lightning as pl
16
+ from pytorch_lightning.callbacks import ModelCheckpoint
17
+
18
+ from utils.base import display_mlflow_run_info, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
19
+ from utils.debug import debug
20
+ from utils.dataset_utils import k_fold
21
+ from utils.augmentation import get_augmentation
22
+ from dataset import Subset, get_dataset
23
+
24
+ from processing.pipeline_numpy import RawProcessingPipeline
25
+ from processing.pipeline_torch import add_additive_layer, raw2rgb, RawToRGB, ParametrizedProcessing, NNProcessing
26
+
27
+ from model import log_tensor, resnet_model, LitModel, TrackImagesCallback
28
+
29
+ import segmentation_models_pytorch as smp
30
+
31
+ from utils.ssim import SSIM
32
+
33
+ # args to set up task
34
+ parser = argparse.ArgumentParser(description="classification_task")
35
+ parser.add_argument("--tracking_uri", type=str,
36
+ default="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com", help='URI of the mlflow server on AWS')
37
+ parser.add_argument("--processor_uri", type=str, default=None,
38
+ help='URI of the processing model (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/processing-model)')
39
+ parser.add_argument("--classifier_uri", type=str, default=None,
40
+ help='URI of the net (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/prediction-model)')
41
+ parser.add_argument("--state_dict_uri", type=str,
42
+ default=None, help='URI of the indices you want to load (e.g. s3://mlflow-artifacts-601883093460/7/4326da05aca54107be8c554de0674a14/artifacts/training)')
43
+
44
+ parser.add_argument("--experiment_name", type=str,
45
+ default='classification learnable pipeline', help='Specify the experiment you are running, e.g. end2end segmentation')
46
+ parser.add_argument("--run_name", type=str,
47
+ default='test run', help='Specify the name of your run')
48
+
49
+ parser.add_argument("--log_model", type=str2bool, default=True, help='Enables model logging')
50
+ parser.add_argument("--save_locally", action='store_true',
51
+ help='If set, the model is additionally saved locally') # TODO: bypass mlflow
52
+
53
+ parser.add_argument("--track_processing", action='store_true',
54
+ help='Save images after each trasformation of the pipeline for the test set')
55
+ parser.add_argument("--track_processing_gradients", action='store_true',
56
+ help='Save images of gradients after each trasformation of the pipeline for the test set')
57
+ parser.add_argument("--track_save_tensors", action='store_true',
58
+ help='Save the torch tensors after each trasformation of the pipeline for the test set')
59
+ parser.add_argument("--track_predictions", action='store_true',
60
+ help='Save images after each trasformation of the pipeline for the test set + input gradient')
61
+ parser.add_argument("--track_n_images", default=5,
62
+ help='Track the n first elements of dataset. Only used for args.track_processing=True')
63
+ parser.add_argument("--track_every_epoch", action='store_true', help='Track images every epoch or once after training')
64
+
65
+ # args to create dataset
66
+ parser.add_argument("--seed", type=int, default=1, help='Global seed')
67
+ parser.add_argument("--dataset", type=str, default='Microscopy',
68
+ choices=["Drone", "DroneSegmentation", "Microscopy"], help='Select dataset')
69
+
70
+ parser.add_argument("--n_splits", type=int, default=1, help='Number of splits used for training')
71
+ parser.add_argument("--train_size", type=float, default=0.8, help='Fraction of training points in dataset')
72
+
73
+ # args for training
74
+ parser.add_argument("--lr", type=float, default=1e-5, help="learning rate used for training")
75
+ parser.add_argument("--epochs", type=int, default=3, help="numper of epochs")
76
+ parser.add_argument("--batch_size", type=int, default=32, help="Training batch size")
77
+ parser.add_argument("--augmentation", type=str, default='none',
78
+ choices=["none", "weak", "strong"], help="Applies augmentation to training")
79
+ parser.add_argument("--augmentation_on_valid_epoch", action='store_true',
80
+ help='Apply augmentation also on validation data') # TODO: implement; should be disabled by default for 'val' and 'test'
81
+ parser.add_argument("--check_val_every_n_epoch", type=int, default=1)
82
+
83
+ # args to specify the processing
84
+ parser.add_argument("--processing_mode", type=str, default="parametrized",
85
+ choices=["parametrized", "static", "neural_network", "none"],
86
+ help="Which type of raw to rgb processing should be used")
87
+
88
+ # args to specify model
89
+ parser.add_argument("--classifier_network", type=str, default='ResNet18',
90
+ help='Type of pretrained network') # TODO: implement different choices
91
+ parser.add_argument("--classifier_pretrained", action='store_true',
92
+ help='Whether to use a pre-trained model or not')
93
+ parser.add_argument("--smp_encoder", type=str, default='resnet34', help='segmentation model encoder')
94
+
95
+ parser.add_argument("--freeze_processor", action='store_true', help="Freeze raw to rgb processing model weights")
96
+ parser.add_argument("--freeze_classifier", action='store_true', help="Freeze classification model weights")
97
+
98
+ # args to specify static pipeline transformations
99
+ parser.add_argument("--sp_debayer", type=str, default='bilinear',
100
+ choices=['bilinear', 'malvar2004', 'menon2007'], help="Specify algorithm used as debayer")
101
+ parser.add_argument("--sp_sharpening", type=str, default='sharpening_filter',
102
+ choices=['sharpening_filter', 'unsharp_masking'], help="Specify algorithm used for sharpening")
103
+ parser.add_argument("--sp_denoising", type=str, default='gaussian_denoising',
104
+ choices=['gaussian_denoising', 'median_denoising', 'fft_denoising'], help="Specify algorithm used for denoising")
105
+
106
+ # args to choose training mode
107
+ parser.add_argument("--adv_training", action='store_true', help="Enable adversarial training")
108
+ parser.add_argument("--adv_aux_weight", type=float, default=1, help="Weighting of the adversarial auxilliary loss")
109
+ parser.add_argument("--adv_aux_loss", type=str, default='ssim', choices=['l2', 'ssim'],
110
+ help="Type of adversarial auxilliary regularization loss")
111
+ parser.add_argument("--adv_noise_layer", action='store_true', help="Adds an additive layer to Parametrized Processing")
112
+ parser.add_argument("--adv_track_differences", action='store_true', help='Save difference to default pipeline')
113
+ parser.add_argument('--adv_parameters', choices=['all', 'black_level', 'white_balance',
114
+ 'colour_correction', 'gamma_correct', 'sharpening_filter', 'gaussian_blur', 'additive_layer'])
115
+
116
+ parser.add_argument("--cache_downloaded_models", type=str2bool, default=True)
117
+
118
+ parser.add_argument('--test_run', action='store_true')
119
+
120
+
121
+ args = parser.parse_args()
122
+
123
+ os.makedirs('results', exist_ok=True)
124
+
125
+
126
+ def run_train(args):
127
+
128
+ print(args)
129
+
130
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
131
+ training_mode = 'adversarial' if args.adv_training else 'default'
132
+
133
+ # set tracking uri, this is the address of the mlflow server where light experimental data will be stored
134
+ mlflow.set_tracking_uri(args.tracking_uri)
135
+ mlflow.set_experiment(args.experiment_name)
136
+ os.environ["AWS_ACCESS_KEY_ID"] = "#TODO: fill in your aws access key id for mlflow server here"
137
+ os.environ["AWS_SECRET_ACCESS_KEY"] = "#TODO: fill in your aws secret access key for mlflow server here"
138
+
139
+ # dataset
140
+
141
+ dataset = get_dataset(args.dataset)
142
+
143
+ print(f'dataset: {type(dataset).__name__}[{len(dataset)}]')
144
+ print(f'task: {dataset.task}')
145
+ print(f'mode: {training_mode} training')
146
+ print(f'# cross-validation subsets: {args.n_splits}')
147
+ pl.seed_everything(args.seed)
148
+ idxs_kfold = k_fold(dataset, n_splits=args.n_splits, seed=args.seed, train_size=args.train_size)
149
+
150
+ with mlflow.start_run(run_name=args.run_name) as parent_run:
151
+
152
+ for k_iter, idxs in enumerate(idxs_kfold):
153
+
154
+ print(f"K_fold subset: {k_iter+1}/{args.n_splits}")
155
+
156
+ if args.processing_mode == 'static':
157
+ if args.dataset == "Drone" or args.dataset == "DroneSegmentation":
158
+ mean = torch.tensor([0.35, 0.36, 0.35])
159
+ std = torch.tensor([0.12, 0.11, 0.12])
160
+ elif args.dataset == "Microscopy":
161
+ mean = torch.tensor([0.91, 0.84, 0.94])
162
+ std = torch.tensor([0.08, 0.12, 0.05])
163
+
164
+ dataset.transform = T.Compose([RawProcessingPipeline(
165
+ camera_parameters=dataset.camera_parameters,
166
+ debayer=args.sp_debayer,
167
+ sharpening=args.sp_sharpening,
168
+ denoising=args.sp_denoising,
169
+ ), T.Normalize(mean, std)])
170
+ # XXX: Not clean
171
+
172
+ processor = nn.Identity()
173
+
174
+ if args.processor_uri is not None and args.processing_mode != 'none':
175
+ print('Fetching processor: ', end='')
176
+ model = fetch_from_mlflow(args.processor_uri, use_cache=args.cache_downloaded_models)
177
+ processor = model.processor
178
+ for param in processor.parameters():
179
+ param.requires_grad = True
180
+ model.processor = None
181
+ del model
182
+ else:
183
+ print(f'processing_mode: {args.processing_mode}')
184
+ normalize_mosaic = None # normalize after raw has been passed to raw2rgb
185
+ if args.dataset == "Microscopy":
186
+ mosaic_mean = [0.5663, 0.1401, 0.0731]
187
+ mosaic_std = [0.097, 0.0423, 0.008]
188
+ normalize_mosaic = T.Normalize(mosaic_mean, mosaic_std)
189
+
190
+ track_stages = args.track_processing or args.track_processing_gradients
191
+ if args.processing_mode == 'parametrized':
192
+ processor = ParametrizedProcessing(
193
+ camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True,
194
+ # noise_layer=args.adv_noise_layer, # this has to be added manually afterwards for when a model is loaded that doesn't have one yet
195
+ )
196
+
197
+ elif args.processing_mode == 'neural_network':
198
+ processor = NNProcessing(track_stages=track_stages,
199
+ normalize_mosaic=normalize_mosaic, batch_norm_output=True)
200
+ elif args.processing_mode == 'none':
201
+ processor = RawToRGB(reduce_size=True, out_channels=3, track_stages=track_stages,
202
+ normalize_mosaic=normalize_mosaic)
203
+
204
+ if args.classifier_uri: # fetch classifier
205
+ print('Fetching classifier: ', end='')
206
+ model = fetch_from_mlflow(args.classifier_uri, use_cache=args.cache_downloaded_models)
207
+ classifier = model.classifier
208
+ model.classifier = None
209
+ del model
210
+ else:
211
+ if dataset.task == 'classification':
212
+ classifier = resnet_model(
213
+ model=resnet18,
214
+ pretrained=args.classifier_pretrained,
215
+ in_channels=3,
216
+ fc_out_features=len(dataset.classes)
217
+ )
218
+ else:
219
+ # XXX: add other network choices to args.smp_network (FPN) and args.network
220
+ classifier = smp.UnetPlusPlus(
221
+ encoder_name=args.smp_encoder,
222
+ encoder_depth=5,
223
+ encoder_weights='imagenet',
224
+ in_channels=3,
225
+ classes=1,
226
+ activation=None,
227
+ )
228
+
229
+ if args.freeze_processor and len(list(iter(processor.parameters()))) == 0:
230
+ print('Note: freezing processor without parameters.')
231
+ assert not (args.freeze_processor and args.freeze_classifier), 'Likely no parameters to train.'
232
+
233
+ if dataset.task == 'classification':
234
+ loss = nn.CrossEntropyLoss()
235
+ metrics = [accuracy]
236
+ else:
237
+ # loss = utils.base.smp_get_loss(args.smp_loss) # XXX: add other losses to args.smp_loss
238
+ loss = smp.losses.DiceLoss(mode='binary', from_logits=True)
239
+ metrics = [smp.utils.metrics.IoU()]
240
+
241
+ loss_aux = None
242
+
243
+ if args.adv_training:
244
+
245
+ assert args.processing_mode == 'parametrized', f"Processing mode ({args.processing_mode}) should be set to 'parametrized' for adversarial training"
246
+ assert args.freeze_classifier, "Classifier should be frozen for adversarial training"
247
+ assert not args.freeze_processor, "Processor should not be frozen for adversarial training"
248
+
249
+ processor_default = copy.deepcopy(processor)
250
+ processor_default.track_stages = args.track_processing
251
+ processor_default.eval()
252
+ processor_default.to(DEVICE)
253
+ # debug(processor_default)
254
+ for p in processor_default.parameters():
255
+ p.requires_grad = False
256
+
257
+ if args.adv_noise_layer:
258
+ add_additive_layer(processor)
259
+
260
+ def l2_regularization(x, y):
261
+ return ((x - y) ** 2).sum()
262
+ # return (x - y).norm()
263
+
264
+ if args.adv_aux_loss == 'l2':
265
+ regularization = l2_regularization
266
+ elif args.adv_aux_loss == 'ssim':
267
+ regularization = SSIM(window_size=11)
268
+ else:
269
+ raise NotImplementedError(args.adv_aux_loss)
270
+
271
+ class AuxLoss(nn.Module):
272
+ def __init__(self, loss_aux, weight=1):
273
+ super().__init__()
274
+ self.loss_aux = loss_aux
275
+ self.weight = weight
276
+
277
+ def forward(self, x):
278
+ with torch.no_grad():
279
+ x_reference = processor_default(x)
280
+ x_processed = processor.buffer['processed_rgb']
281
+ return self.weight * self.loss_aux(x_reference, x_processed)
282
+
283
+ class WeightedLoss(nn.Module):
284
+ def __init__(self, loss, weight=1):
285
+ super().__init__()
286
+ self.loss = loss
287
+ self.weight = weight
288
+
289
+ def forward(self, x, y):
290
+ return self.weight * self.loss(x, y)
291
+
292
+ def __repr__(self):
293
+ return f'{self.weight} * {get_name(self.loss)}'
294
+
295
+ loss = WeightedLoss(loss=loss, weight=-1)
296
+ # loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=0)
297
+ loss_aux = AuxLoss(
298
+ loss_aux=regularization,
299
+ weight=args.adv_aux_weight,
300
+ )
301
+
302
+ augmentation = get_augmentation(args.augmentation)
303
+
304
+ model = LitModel(
305
+ classifier=classifier,
306
+ processor=processor,
307
+ loss=loss,
308
+ lr=args.lr,
309
+ loss_aux=loss_aux,
310
+ adv_training=args.adv_training,
311
+ adv_parameters=args.adv_parameters,
312
+ metrics=metrics,
313
+ augmentation=augmentation,
314
+ is_segmentation_task=dataset.task == 'segmentation',
315
+ freeze_classifier=args.freeze_classifier,
316
+ freeze_processor=args.freeze_processor,
317
+ )
318
+
319
+ # get train_set_dict
320
+ if args.state_dict_uri:
321
+ state_dict = mlflow.pytorch.load_state_dict(args.state_dict_uri)
322
+ train_indices = state_dict['train_indices']
323
+ valid_indices = state_dict['valid_indices']
324
+ else:
325
+ train_indices = idxs[0]
326
+ valid_indices = idxs[1]
327
+ state_dict = vars(args).copy()
328
+
329
+ track_indices = list(range(args.track_n_images))
330
+
331
+ if dataset.task == 'classification':
332
+ state_dict['classes'] = dataset.classes
333
+ state_dict['device'] = DEVICE
334
+ state_dict['train_indices'] = train_indices
335
+ state_dict['valid_indices'] = valid_indices
336
+ state_dict['elements in train set'] = len(train_indices)
337
+ state_dict['elements in test set'] = len(valid_indices)
338
+
339
+ if args.test_run:
340
+ train_indices = train_indices[:args.batch_size]
341
+ valid_indices = valid_indices[:args.batch_size]
342
+
343
+ train_set = Subset(dataset, indices=train_indices)
344
+ valid_set = Subset(dataset, indices=valid_indices)
345
+ track_set = Subset(dataset, indices=track_indices)
346
+
347
+ train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=16, shuffle=True)
348
+ valid_loader = DataLoader(valid_set, batch_size=args.batch_size, num_workers=16, shuffle=False)
349
+ track_loader = DataLoader(track_set, batch_size=args.batch_size, num_workers=16, shuffle=False)
350
+
351
+ with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:
352
+
353
+ # mlflow.pytorch.autolog(silent=True)
354
+
355
+ if k_iter == 0:
356
+ display_mlflow_run_info(child_run)
357
+
358
+ mlflow.pytorch.log_state_dict(state_dict, artifact_path=None)
359
+
360
+ hparams = {
361
+ 'dataset': args.dataset,
362
+ 'processing_mode': args.processing_mode,
363
+ 'training_mode': training_mode,
364
+ }
365
+ if training_mode == 'adversarial':
366
+ hparams['adv_aux_weight'] = args.adv_aux_weight
367
+ hparams['adv_aux_loss'] = args.adv_aux_loss
368
+
369
+ mlflow.log_params(hparams)
370
+
371
+ with open('results/state_dict.txt', 'w') as f:
372
+ f.write('python ' + ' '.join(sys.argv) + '\n')
373
+ f.write('\n'.join([f'{k}={v}' for k, v in state_dict.items()]))
374
+ mlflow.log_artifact('results/state_dict.txt', artifact_path=None)
375
+
376
+ mlf_logger = pl.loggers.MLFlowLogger(experiment_name=args.experiment_name,
377
+ tracking_uri=args.tracking_uri,)
378
+ mlf_logger._run_id = child_run.info.run_id
379
+
380
+ reference_processor = processor_default if args.adv_training and args.adv_track_differences else None
381
+
382
+ callbacks = []
383
+ if args.track_processing:
384
+ callbacks += [TrackImagesCallback(track_loader,
385
+ reference_processor,
386
+ track_every_epoch=args.track_every_epoch,
387
+ track_processing=args.track_processing,
388
+ track_gradients=args.track_processing_gradients,
389
+ track_predictions=args.track_predictions,
390
+ save_tensors=args.track_save_tensors)]
391
+
392
+ # if True: #args.save_best:
393
+ # if dataset.task == 'classification':
394
+ #checkpoint_callback = ModelCheckpoint(pathmonitor="val_accuracy", mode='max')
395
+ # checkpoint_callback = ModelCheckpoint(dirpath=args.tracking_uri, save_top_k=1, verbose=True, monitor="val_accuracy", mode="max") #dirpath=args.tracking_uri,
396
+ # else:
397
+ # checkpoint_callback = ModelCheckpoint(monitor="val_iou_score")
398
+ #callbacks += [checkpoint_callback]
399
+
400
+ trainer = pl.Trainer(
401
+ gpus=1 if DEVICE == 'cuda' else 0,
402
+ min_epochs=args.epochs,
403
+ max_epochs=args.epochs,
404
+ logger=mlf_logger,
405
+ callbacks=callbacks,
406
+ check_val_every_n_epoch=args.check_val_every_n_epoch,
407
+ # checkpoint_callback=True,
408
+ )
409
+
410
+ if args.log_model:
411
+ mlflow.pytorch.autolog(log_every_n_epoch=10)
412
+ print(f'model_uri="{mlflow.get_artifact_uri()}/model"')
413
+
414
+ t = trainer.fit(
415
+ model,
416
+ train_dataloader=train_loader,
417
+ val_dataloaders=valid_loader,
418
+ )
419
+
420
+ globals().update(locals()) # for convenient access
421
+
422
+ return model
423
+
424
+
425
+ if __name__ == '__main__':
426
+ model = run_train(args)
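
For orientation, the adversarial branch of `run_train` above flips the sign of the task loss (`WeightedLoss` with `weight=-1`) while `AuxLoss` penalises deviation from a frozen copy of the processing pipeline. A stripped-down sketch of that objective on dummy tensors, with tiny stand-in networks and a plain L2 penalty in place of SSIM (all names and shapes here are illustrative, not the repository's models):

import torch
import torch.nn as nn

processor = nn.Conv2d(3, 3, kernel_size=1)            # learnable processing stand-in
processor_default = nn.Conv2d(3, 3, kernel_size=1)    # frozen reference pipeline stand-in
classifier = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 2))
for p in list(processor_default.parameters()) + list(classifier.parameters()):
    p.requires_grad = False                            # only the processor is trained

task_loss = nn.CrossEntropyLoss()
adv_aux_weight = 1.0

x = torch.rand(4, 3, 8, 8)                             # pretend batch (already demosaiced)
y = torch.randint(0, 2, (4,))

x_processed = processor(x)
with torch.no_grad():
    x_reference = processor_default(x)

# maximise the classification loss (weight -1) while staying close to the reference output
loss = -1 * task_loss(classifier(x_processed), y) \
       + adv_aux_weight * ((x_reference - x_processed) ** 2).sum()
loss.backward()                                        # gradients reach only the processor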
utils/augmentation.py ADDED
@@ -0,0 +1,132 @@
1
+ import random
2
+ import numpy as np
3
+
4
+ import torch
5
+ import torchvision.transforms as T
6
+
7
+
8
+ class RandomRotate90(): # Note: not the same as T.RandomRotation(90)
9
+ def __call__(self, x):
10
+ x = x.rot90(random.randint(0, 3), dims=(-1, -2))
11
+ return x
12
+
13
+ def __repr__(self):
14
+ return self.__class__.__name__
15
+
16
+
17
+ class AddGaussianNoise():
18
+ def __init__(self, std=0.01):
19
+ self.std = std
20
+
21
+ def __call__(self, x):
22
+ # noise = torch.randn_like(x) * self.std
23
+ # out = x + noise
24
+ # debug(x)
25
+ # debug(noise)
26
+ # debug(out)
27
+ return x + torch.randn_like(x) * self.std
28
+
29
+ def __repr__(self):
30
+ return self.__class__.__name__ + f'(std={self.std})'
31
+
32
+
33
+ def set_global_seed(seed):
34
+ torch.random.manual_seed(seed)
35
+ np.random.seed(seed % (2**32 - 1))
36
+ random.seed(seed)
37
+
38
+
39
+ class ComposeState(T.Compose):
40
+ def __init__(self, transforms):
41
+ self.transforms = []
42
+ self.mask_transforms = []
43
+
44
+ for t in transforms:
45
+ apply_for_mask = True
46
+ if isinstance(t, tuple):
47
+ t, apply_for_mask = t
48
+ self.transforms.append(t)
49
+ if apply_for_mask:
50
+ self.mask_transforms.append(t)
51
+
52
+ self.seed = None
53
+
54
+ # @debug
55
+ def __call__(self, x, retain_state=False, mask_transform=False):
56
+ if self.seed is not None: # retain previous state
57
+ set_global_seed(self.seed)
58
+ if retain_state: # save state for next call
59
+ self.seed = self.seed or torch.seed()
60
+ set_global_seed(self.seed)
61
+ else:
62
+ self.seed = None # reset / ignore state
63
+
64
+ transforms = self.transforms if not mask_transform else self.mask_transforms
65
+ for t in transforms:
66
+ x = t(x)
67
+ return x
68
+
69
+
70
+ augmentation_weak = ComposeState([
71
+ T.RandomHorizontalFlip(),
72
+ T.RandomVerticalFlip(),
73
+ RandomRotate90(),
74
+ ])
75
+
76
+
77
+ augmentation_strong = ComposeState([
78
+ T.RandomHorizontalFlip(p=0.5),
79
+ T.RandomVerticalFlip(p=0.5),
80
+ T.RandomApply([T.RandomRotation(90)], p=0.5),
81
+ # (transform, apply_to_mask=True)
82
+ (T.RandomApply([AddGaussianNoise(std=0.0005)], p=0.5), False),
83
+ (T.RandomAdjustSharpness(0.5, p=0.5), False),
84
+ ])
85
+
86
+
87
+ def get_augmentation(type):
88
+ if type == 'none':
89
+ return None
90
+ if type == 'weak':
91
+ return augmentation_weak
92
+ if type == 'strong':
93
+ return augmentation_strong
94
+
95
+
96
+ if __name__ == '__main__':
97
+ import os
98
+ if not os.path.exists('README.md'):
99
+ os.chdir('..')
100
+
101
+ # from utils.debug import debug
102
+ from dataset import get_dataset
103
+ import matplotlib.pyplot as plt
104
+
105
+ dataset = get_dataset('DS') # drone segmentation
106
+ img, mask = dataset[10]
107
+ mask = (mask + 0.2) / 1.2
108
+
109
+ plt.figure(figsize=(14, 8))
110
+ plt.subplot(121)
111
+ plt.imshow(img)
112
+ plt.subplot(122)
113
+ plt.imshow(mask)
114
+ plt.suptitle('no augmentation')
115
+ plt.show()
116
+
117
+ from utils.base import np2torch, torch2np
118
+ img, mask = np2torch(img), np2torch(mask)
119
+
120
+ # from utils.augmentation import get_augmentation
121
+ augmentation = get_augmentation('strong')
122
+
123
+ set_global_seed(1)
124
+
125
+ for i in range(1, 4):
126
+ plt.figure(figsize=(14, 8))
127
+ plt.subplot(121)
128
+ plt.imshow(torch2np(augmentation(img.unsqueeze(0), retain_state=True)).squeeze())
129
+ plt.subplot(122)
130
+ plt.imshow(torch2np(augmentation(mask.unsqueeze(0), mask_transform=True)).squeeze())
131
+ plt.suptitle(f'augmentation test {i}')
132
+ plt.show()
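
A minimal sketch of how `ComposeState` keeps image and mask augmentations in sync inside a training step (the tensor shapes are illustrative; the pattern mirrors the `__main__` demo above):

import torch
from utils.augmentation import get_augmentation

augmentation = get_augmentation('weak')

img = torch.rand(1, 3, 64, 64)    # (B, C, H, W)
mask = torch.rand(1, 1, 64, 64)

# retain_state=True stores the RNG seed, so the mask call replays the same flips/rotation
img_aug = augmentation(img, retain_state=True)
mask_aug = augmentation(mask, mask_transform=True)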
utils/base.py ADDED
@@ -0,0 +1,335 @@
1
+ """
2
+ Utilities for other scripts
3
+ """
4
+
5
+ import os
6
+ import shutil
7
+
8
+ import random
9
+
10
+ import torch
11
+ import mlflow
12
+ from mlflow.tracking import MlflowClient
13
+ import numpy as np
14
+
15
+ from IPython.display import display, Markdown
16
+
17
+ from b2sdk.v1 import *
18
+
19
+ import argparse
20
+
21
+
22
+ class SmartFormatter(argparse.HelpFormatter):
23
+
24
+ def _split_lines(self, text, width):
25
+ if text.startswith('R|'):
26
+ return text[2:].splitlines()
27
+ # this is the RawTextHelpFormatter._split_lines
28
+ return argparse.HelpFormatter._split_lines(self, text, width)
29
+
30
+
31
+ def str2bool(string):
32
+ return string == 'True'
33
+
34
+
35
+ def np2torch(nparray):
36
+ """Convert numpy array to torch tensor
37
+ For array with more than 3 channels, it is better to use an input array in the format BxHxWxC
38
+
39
+ Args:
40
+ numpy array (ndarray) BxHxWxC
41
+ Returns:
42
+ torch tensor (tensor) BxCxHxW"""
43
+
44
+ tensor = torch.Tensor(nparray)
45
+
46
+ if tensor.ndim == 2:
47
+ return tensor
48
+ if tensor.ndim == 3:
49
+ height, width, channels = tensor.shape
50
+ if channels <= 3: # Single image with more channels (HxWxC)
51
+ return tensor.permute(2, 0, 1)
52
+
53
+ if tensor.ndim == 4: # More images with more channels (BxHxWxC)
54
+ return tensor.permute(0, 3, 1, 2)
55
+
56
+ return tensor
57
+
58
+
59
+ def torch2np(torchtensor):
60
+ """Convert torch tensor to numpy array
61
+ For tensor with more than 3 channels or batch, it is better to use an input tensor in the format BxCxHxW
62
+
63
+ Args:
64
+ torch tensor (tensor) BxCxHxW
65
+ Returns:
66
+ numpy array (ndarray) BxHxWxC"""
67
+
68
+ ndarray = torchtensor.detach().cpu().numpy().astype(np.float32)
69
+
70
+ if ndarray.ndim == 3: # Single image with more channels (CxHxW)
71
+ channels, height, width = ndarray.shape
72
+ if channels <= 3:
73
+ return ndarray.transpose(1, 2, 0)
74
+
75
+ if ndarray.ndim == 4: # More images with more channels (BxCxHxW)
76
+ return ndarray.transpose(0, 2, 3, 1)
77
+
78
+ return ndarray
79
+
80
+
81
+ def set_random_seed(seed):
82
+ np.random.seed(seed) # cpu vars
83
+ torch.manual_seed(seed) # cpu vars
84
+ random.seed(seed) # Python
85
+ if torch.cuda.is_available():
86
+ torch.cuda.manual_seed(seed)
87
+ torch.cuda.manual_seed_all(seed) # gpu vars
88
+ torch.backends.cudnn.deterministic = True # needed
89
+ torch.backends.cudnn.benchmark = False
90
+
91
+
92
+ def normalize(img):
93
+ """Normalize images
94
+
95
+ Args:
96
+ imgs (ndarray): image to normalize --> size: (Height,Width,Channels)
97
+ Returns:
98
+ normalized (ndarray): normalized image
99
+ mu (ndarray): mean
100
+ sigma (ndarray): standard deviation
101
+ """
102
+
103
+ img = img.astype(float)
104
+
105
+ if len(img.shape) == 2:
106
+ img = img[:, :, np.newaxis]
107
+
108
+ height, width, channels = img.shape
109
+
110
+ mu, sigma = np.empty(channels), np.empty(channels)
111
+
112
+ for ch in range(channels):
113
+ temp_mu = img[:, :, ch].mean()
114
+ temp_sigma = img[:, :, ch].std()
115
+
116
+ img[:, :, ch] = (img[:, :, ch] - temp_mu) / (temp_sigma + 1e-4)
117
+
118
+ mu[ch] = temp_mu
119
+ sigma[ch] = temp_sigma
120
+
121
+ return img, mu, sigma
122
+
123
+
124
+ def b2_list_files(folder=''):
125
+ bucket = get_b2_bucket()
126
+ for file_info, _ in bucket.ls(folder, show_versions=False):
127
+ print(file_info.file_name)
128
+
129
+
130
+ def get_b2_bucket():
131
+ bucket_name = 'perturbed-minds'
132
+ application_key_id = '003d6b042de536a0000000008'
133
+ application_key = 'K003HMNxnoa91Dy9c0V8JVCKNUnwR9U'
134
+ info = InMemoryAccountInfo()
135
+ b2_api = B2Api(info)
136
+ b2_api.authorize_account('production', application_key_id, application_key)
137
+ bucket = b2_api.get_bucket_by_name(bucket_name)
138
+ return bucket
139
+
140
+
141
+ def b2_download_folder(b2_dir, local_dir, force_download=False, mirror_folder=True):
142
+ """Downloads a folder from the b2 bucket and optionally cleans
143
+ up files that are no longer on the server
144
+
145
+ Args:
146
+ b2_dir (str): path to folder on the b2 server
147
+ local_dir (str): path to folder on the local machine
148
+ force_download (bool, optional): force the download, if set to `False`,
149
+ files with matching names on the local machine will be skipped
150
+ mirror_folder (bool, optional): if set to `True`, files that are found in
151
+ the local directory, but are not on the server will be deleted
152
+ """
153
+ bucket = get_b2_bucket()
154
+
155
+ if not os.path.exists(local_dir):
156
+ os.makedirs(local_dir)
157
+ elif not force_download:
158
+ return
159
+
160
+ download_files = [file_info.file_name.split(b2_dir + '/')[-1]
161
+ for file_info, _ in bucket.ls(b2_dir, show_versions=False)]
162
+
163
+ for file_name in download_files:
164
+ if file_name.endswith('/.bzEmpty'): # subdirectory, download recursively
165
+ subdir = file_name.replace('/.bzEmpty', '')
166
+ if len(subdir) > 0:
167
+ b2_subdir = os.path.join(b2_dir, subdir)
168
+ local_subdir = os.path.join(local_dir, subdir)
169
+ if b2_subdir != b2_dir:
170
+ b2_download_folder(b2_subdir, local_subdir, force_download=force_download,
171
+ mirror_folder=mirror_folder)
172
+ else: # file
173
+ b2_file = os.path.join(b2_dir, file_name)
174
+ local_file = os.path.join(local_dir, file_name)
175
+ if not os.path.exists(local_file) or force_download:
176
+ print(f"downloading b2://{b2_file} -> {local_file}")
177
+ bucket.download_file_by_name(b2_file, DownloadDestLocalFile(local_file))
178
+
179
+ if mirror_folder: # remove all files that are not on the b2 server anymore
180
+ for i, file in enumerate(download_files):
181
+ if file.endswith('/.bzEmpty'): # subdirectory, download recursively
182
+ download_files[i] = file.replace('/.bzEmpty', '')
183
+ for file_name in os.listdir(local_dir):
184
+ if file_name not in download_files:
185
+ local_file = os.path.join(local_dir, file_name)
186
+ print(f"deleting {local_file}")
187
+ if os.path.isdir(local_file):
188
+ shutil.rmtree(local_file)
189
+ else:
190
+ os.remove(local_file)
191
+
192
+
193
+ def get_name(obj):
194
+ return obj.__name__ if hasattr(obj, '__name__') else type(obj).__name__
195
+
196
+
197
+ def get_mlflow_model_by_name(experiment_name, run_name,
198
+ tracking_uri="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
199
+ download_model=True):
200
+
201
+ # 0. mlflow basics
202
+ mlflow.set_tracking_uri(tracking_uri)
203
+ os.environ["AWS_ACCESS_KEY_ID"] = "#TODO: add your AWS access key if you want to write your results to our collaborative lab server"
204
+ os.environ["AWS_SECRET_ACCESS_KEY"] = "#TODO: add your AWS seceret access key if you want to write your results to our collaborative lab server"
205
+
206
+ # # 1. use get_experiment_by_name to get experiment objec
207
+ experiment = mlflow.get_experiment_by_name(experiment_name)
208
+
209
+ # # 2. use search_runs with experiment_id for string search query
210
+ if os.path.isfile('cache/runs_names.pkl'):
211
+ runs = pd.read_pickle('cache/runs_names.pkl')
212
+ if runs['tags.mlflow.runName'][runs['tags.mlflow.runName'] == run_name].empty:
213
+ # returns a pandas data frame where each row is a run (if several exist under that name)
214
+ runs = fetch_runs_list_mlflow(experiment)
215
+ else:
216
+ # returns a pandas data frame where each row is a run (if several exist under that name)
217
+ runs = fetch_runs_list_mlflow(experiment)
218
+
219
+ # 3. get the selected run between all runs inside the selected experiment
220
+ run = runs.loc[runs['tags.mlflow.runName'] == run_name]
221
+
222
+ # 4. check if there is only a run with that name
223
+ assert len(run) == 1, "More runs with this name"
224
+ index_run = run.index[0]
225
+ artifact_uri = run.loc[index_run, 'artifact_uri']
226
+
227
+ # 5. load state_dict of your run
228
+ state_dict = mlflow.pytorch.load_state_dict(artifact_uri)
229
+
230
+ # 6. load model of your run
231
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
232
+ # model = mlflow.pytorch.load_model(os.path.join(
233
+ # artifact_uri, "model"), map_location=torch.device(DEVICE))
234
+ model = fetch_from_mlflow(os.path.join(
235
+ artifact_uri, "model"), use_cache=True, download_model=download_model)
236
+
237
+ return state_dict, model
238
+
239
+
240
+ def data_loader_mean_and_std(data_loader, transform=None):
241
+ means = []
242
+ stds = []
243
+ for x, y in data_loader:
244
+ if transform is not None:
245
+ x = transform(x)
246
+ means.append(x.mean(dim=(0, 2, 3)).unsqueeze(0))
247
+ stds.append(x.std(dim=(0, 2, 3)).unsqueeze(0))
248
+ return torch.cat(means).mean(dim=0), torch.cat(stds).mean(dim=0)
249
+
250
+
251
+ def fetch_runs_list_mlflow(experiment):
252
+ runs = mlflow.search_runs(experiment.experiment_id)
253
+ runs.to_pickle('cache/runs_names.pkl') # where to save it, usually as a .pkl
254
+ return runs
255
+
256
+
257
+ def fetch_from_mlflow(uri, use_cache=True, download_model=True):
258
+ cache_loc = os.path.join('cache', uri.split('//')[1]) + '.pt'
259
+ if use_cache and os.path.exists(cache_loc):
260
+ print(f'loading cached model from {cache_loc} ...')
261
+ return torch.load(cache_loc)
262
+ else:
263
+ print(f'fetching model from {uri} ...')
264
+ model = mlflow.pytorch.load_model(uri)
265
+ os.makedirs(os.path.dirname(cache_loc), exist_ok=True)
266
+ if download_model:
267
+ torch.save(model, cache_loc, pickle_module=mlflow.pytorch.pickle_module)
268
+ return model
269
+
270
+
271
+ def display_mlflow_run_info(run):
272
+ uri = mlflow.get_tracking_uri()
273
+ experiment_id = run.info.experiment_id
274
+ experiment_name = mlflow.get_experiment(experiment_id).name
275
+ run_id = run.info.run_id
276
+ run_name = run.data.tags['mlflow.runName']
277
+ experiment_url = f'{uri}/#/experiments/{experiment_id}'
278
+ run_url = f'{experiment_url}/runs/{run_id}'
279
+
280
+ print(f'view results at {run_url}')
281
+ display(Markdown(
282
+ f"[<a href='{experiment_url}'>experiment {experiment_id} '{experiment_name}'</a>]"
283
+ f" > "
284
+ f"[<a href='{run_url}'>run '{run_name}' {run_id}</a>]"
285
+ ))
286
+ print('')
287
+
288
+
289
+ def get_train_test_indices_drone(df, frac, seed=None):
290
+ """ Split indices of a DataFrame with binary and balanced labels into balanced subindices
291
+
292
+ Args:
293
+ df (pd.DataFrame): {0,1}-labeled data
294
+ frac (float): fraction of indicies in first subset
295
+ random_seed (int): random seed used as random state in np.random and as argument for random.seed()
296
+ Returns:
297
+ train_indices (torch.tensor): balanced subset of indices corresponding to rows in the DataFrame
298
+ test_indices (torch.tensor): balanced subset of indices corresponding to rows in the DataFrame
299
+ """
300
+
301
+ split_idx = int(len(df) * frac / 2)
302
+ df_with = df[df['label'] == 1]
303
+ df_without = df[df['label'] == 0]
304
+
305
+ np.random.seed(seed)
306
+ df_with_train = df_with.sample(n=split_idx, random_state=seed)
307
+ df_with_test = df_with.drop(df_with_train.index)
308
+
309
+ df_without_train = df_without.sample(n=split_idx, random_state=seed)
310
+ df_without_test = df_without.drop(df_without_train.index)
311
+
312
+ train_indices = list(df_without_train.index) + list(df_with_train.index)
313
+ test_indices = list(df_without_test.index) + list(df_with_test.index)
314
+
315
+ """"
316
+ print('fraction of 1-label in train set: {}'.format(len(df_with_train)/(len(df_with_train) + len(df_without_train))))
317
+ print('fraction of 1-label in test set: {}'.format(len(df_with_test)/(len(df_with_test) + len(df_with_test))))
318
+ """
319
+
320
+ return train_indices, test_indices
321
+
322
+
323
+ def smp_get_loss(loss, dice_weight=1.0, bce_weight=1.0):
+ import torch.nn as nn
+ import segmentation_models_pytorch as smp
324
+ if loss == "Dice":
325
+ return smp.losses.DiceLoss(mode='binary', from_logits=True)
326
+ if loss == "BCE":
327
+ return nn.BCELoss()
328
+ elif loss == "BCEWithLogits":
329
+ return smp.losses.BCEWithLogitsLoss()
330
+ elif loss == "DicyBCE":
331
+ from pytorch_toolbelt import losses as ptbl
332
+ return ptbl.JointLoss(ptbl.DiceLoss(mode='binary', from_logits=False),
333
+ nn.BCELoss(),
334
+ first_weight=dice_weight,
335
+ second_weight=bce_weight)
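
As a small sanity check of the layout conventions documented in `np2torch`/`torch2np` above (HxWxC on the numpy side, CxHxW on the torch side), a round trip might look like this (the array contents are arbitrary):

import numpy as np
from utils.base import np2torch, torch2np

img_np = np.random.rand(32, 32, 3).astype(np.float32)  # HxWxC
img_t = np2torch(img_np)                                # -> CxHxW torch.Tensor
assert tuple(img_t.shape) == (3, 32, 32)

img_back = torch2np(img_t)                              # -> HxWxC float32 ndarray
assert img_back.shape == img_np.shape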
utils/dataset_utils.py ADDED
@@ -0,0 +1,190 @@
1
+
2
+ import os
+ import random
3
+ import numpy as np
4
+
5
+ import torch
6
+
7
+ from skimage.util.shape import view_as_windows
+ from sklearn.model_selection import StratifiedShuffleSplit
+
+ import rawpy
+ import tifffile as tiff
+ from PIL import Image
+ # NOTE: IMAGE_FILE_TYPES (used by list_images_in_dir) must be supplied by the dataset module
8
+
9
+
10
+ def load_image(path):
11
+ file_type = path.split('.')[-1].lower()
12
+ if file_type == 'dng':
13
+ img = rawpy.imread(path).raw_image_visible
14
+ elif file_type == 'tiff' or file_type == 'tif':
15
+ img = np.array(tiff.imread(path), dtype=np.float32)
16
+ else:
17
+ img = np.array(Image.open(path), dtype=np.float32)
18
+ return img
19
+
20
+
21
+ def list_images_in_dir(path):
22
+ image_list = [os.path.join(path, img_name)
23
+ for img_name in sorted(os.listdir(path))
24
+ if img_name.split('.')[-1].lower() in IMAGE_FILE_TYPES]
25
+ return image_list
26
+
27
+
28
+ def k_fold(dataset, n_splits: int, seed: int, train_size: float):
29
+ """Split dataset in subsets for cross-validation
30
+
31
+ Args:
32
+ dataset (class): dataset to split
33
+ n_split (int): Number of re-shuffling & splitting iterations.
34
+ seed (int): seed for k_fold splitting
35
+ train_size (float): should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the train split.
36
+ Returns:
37
+ idxs (list): indeces for splitting the dataset. The list contain n_split pair of train/test indeces.
38
+ """
39
+ if hasattr(dataset, 'labels'):
40
+ x = dataset.images
41
+ y = dataset.labels
42
+ elif hasattr(dataset, 'masks'):
43
+ x = dataset.images
44
+ y = dataset.masks
45
+
46
+ idxs = []
47
+
48
+ if dataset.task == 'classification':
49
+ sss = StratifiedShuffleSplit(n_splits=n_splits, train_size=train_size, random_state=seed)
50
+
51
+ for idxs_train, idxs_test in sss.split(x, y):
52
+ idxs.append((idxs_train.tolist(), idxs_test.tolist()))
53
+
54
+ elif dataset.task == 'segmentation':
55
+ for n in range(n_splits):
56
+ split_idx = int(len(dataset) * train_size)
57
+ indices = np.random.permutation(len(dataset))
58
+ idxs.append((indices[:split_idx].tolist(), indices[split_idx:].tolist()))
59
+
60
+ return idxs
61
+
62
+
63
+ def split_img(imgs, ROIs=(3, 3), step=(1, 1)):
64
+ """Split the imgs in regions of size ROIs.
65
+
66
+ Args:
67
+ imgs (ndarray): images which you want to split
68
+ ROIs (tuple): size of sub-regions splitted (ROIs=region of interests)
69
+ step (tuple): step path from one sub-region to the next one (in the x,y axis)
70
+
71
+ Returns:
72
+ ndarray: splitted subimages.
73
+ The size is (x_num_subROIs*y_num_subROIs, **) where:
74
+ x_num_subROIs = ( imgs.shape[1]-int(ROIs[1]/2)*2 )/step[1]
75
+ y_num_subROIs = ( imgs.shape[0]-int(ROIs[0]/2)*2 )/step[0]
76
+
77
+ Example:
78
+ >>> from dataset_generator import split
79
+ >>> imgs_splitted = split(imgs, ROI_size = (5,5), step=(2,3))
80
+ """
81
+
82
+ if len(ROIs) > 2:
83
+ return print("ROIs is a 2 element list")
84
+
85
+ if len(step) > 2:
86
+ return print("step is a 2 element list")
87
+
88
+ if type(imgs) != type(np.array(1)):
89
+ return print("imgs should be a ndarray")
90
+
91
+ if len(imgs.shape) == 2: # Single image with one channel (HxW)
92
+ splitted = view_as_windows(imgs, (ROIs[0], ROIs[1]), (step[0], step[1]))
93
+ return splitted.reshape((-1, ROIs[0], ROIs[1]))
94
+
95
+ if len(imgs.shape) == 3:
96
+ _, _, channels = imgs.shape
97
+ if channels <= 3: # Single image more channels (HxWxC)
98
+ splitted = view_as_windows(imgs, (ROIs[0], ROIs[1], channels), (step[0], step[1], channels))
99
+ return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
100
+ else: # More images with 1 channel
101
+ splitted = view_as_windows(imgs, (1, ROIs[0], ROIs[1]), (1, step[0], step[1]))
102
+ return splitted.reshape((-1, ROIs[0], ROIs[1]))
103
+
104
+ if len(imgs.shape) == 4: # More images with more channels(BxHxWxC)
105
+ _, _, _, channels = imgs.shape
106
+ splitted = view_as_windows(imgs, (1, ROIs[0], ROIs[1], channels), (1, step[0], step[1], channels))
107
+ return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
108
+
109
+
110
+ def join_blocks(splitted, final_shape):
111
+ """Join blocks to reobtain a splitted image
112
+
113
+ Attribute:
114
+ splitted (tensor) = image splitted in blocks, size = (N_blocks, Channels, Height, Width)
115
+ final_shape (tuple) = size of the final image reconstructed (Height, Width)
116
+ Return:
117
+ tensor: image restored from blocks. size = (Channels, Height, Width)
118
+
119
+ """
120
+ n_blocks, channels, ROI_height, ROI_width = splitted.shape
121
+
122
+ rows = final_shape[0] // ROI_height
123
+ columns = final_shape[1] // ROI_width
124
+
125
+ final_img = torch.empty(rows, channels, ROI_height, ROI_width * columns)
126
+ for r in range(rows):
127
+ stackblocks = splitted[r * columns]
128
+ for c in range(1, columns):
129
+ stackblocks = torch.cat((stackblocks, splitted[r * columns + c]), axis=2)
130
+ final_img[r] = stackblocks
131
+
132
+ joined_img = final_img[0]
133
+
134
+ for i in np.arange(1, len(final_img)):
135
+ joined_img = torch.cat((joined_img, final_img[i]), axis=1)
136
+
137
+ return joined_img
138
+
139
+
140
+ def random_ROI(X, Y, ROIs=(512, 512)):
141
+ """ Return a random region for each input/target pair images of the dataset
142
+ Args:
143
+ Y (ndarray): target of your dataset --> size: (BxHxWxC)
144
+ X (ndarray): input of your dataset --> size: (BxHxWxC)
145
+ ROIs (tuple): size of random region (ROIs=region of interests)
146
+
147
+ Returns:
148
+ For each pair images (input/target) of the dataset, return respectively random ROIs
149
+ Y_cut (ndarray): target of your dataset --> size: (Batch,Channels,ROIs[0],ROIs[1])
150
+ X_cut (ndarray): input of your dataset --> size: (Batch,Channels,ROIs[0],ROIs[1])
151
+
152
+ Example:
153
+ >>> from dataset_generator import random_ROI
154
+ >>> X,Y = random_ROI(X,Y, ROIs = (10,10))
155
+ """
156
+
157
+ batch, height, width, channels = X.shape # inputs are BxHxWxC (see docstring)
158
+
159
+ X_cut = np.empty((batch, ROIs[0], ROIs[1], channels))
160
+ Y_cut = np.empty((batch, ROIs[0], ROIs[1], channels))
161
+
162
+ for i in np.arange(len(X)):
163
+ x_size = int(random.random() * (height - (ROIs[0] + 1)))
164
+ y_size = int(random.random() * (width - (ROIs[1] + 1)))
165
+ X_cut[i] = X[i, x_size:x_size + ROIs[0], y_size:y_size + ROIs[1], :]
166
+ Y_cut[i] = Y[i, x_size:x_size + ROIs[0], y_size:y_size + ROIs[1], :]
167
+ return X_cut, Y_cut
168
+
169
+
170
+ def one2many_random_ROI(X, Y, datasize=1000, ROIs=(512, 512)):
171
+ """ Return a dataset of N subimages obtained from random regions of the same image
172
+ Args:
173
+ Y (ndarray): target of your dataset --> size: (1,H,W,C)
174
+ X (ndarray): input of your dataset --> size: (1,H,W,C)
175
+ datasize = number of random ROIs to generate
176
+ ROIs (tuple): size of random region (ROIs=region of interests)
177
+
178
+ Returns:
179
+ Y_cut (ndarray): target of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
180
+ X_cut (ndarray): input of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
181
+ """
182
+
183
+ batch, height, width, channels = X.shape # inputs are BxHxWxC (see docstring)
184
+
185
+ X_cut = np.empty((datasize, ROIs[0], ROIs[1], channels))
186
+ Y_cut = np.empty((datasize, ROIs[0], ROIs[1], channels))
187
+
188
+ for i in np.arange(datasize):
189
+ X_cut[i], Y_cut[i] = random_ROI(X, Y, ROIs)
190
+ return X_cut, Y_cut
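
A minimal sketch of `split_img` above on a single-channel array, following the ROI/step convention from its docstring (sizes chosen only for illustration):

import numpy as np
from utils.dataset_utils import split_img

img = np.arange(6 * 6, dtype=np.float32).reshape(6, 6)  # single HxW image
patches = split_img(img, ROIs=(3, 3), step=(1, 1))       # sliding 3x3 windows
print(patches.shape)                                      # (16, 3, 3): 4x4 window positions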
utils/hendrycks_robustness.py ADDED
@@ -0,0 +1,475 @@
1
+ '''
2
+ Code extracted from the paper:
3
+
4
+ @article{hendrycks2019robustness,
5
+ title={Benchmarking Neural Network Robustness to Common Corruptions and Perturbations},
6
+ author={Dan Hendrycks and Thomas Dietterich},
7
+ journal={Proceedings of the International Conference on Learning Representations},
8
+ year={2019}
9
+ }
10
+
11
+ The code is modified to fit with our model
12
+ '''
13
+
14
+ import os
15
+ from PIL import Image
16
+ import os.path
17
+ import time
18
+ import torch
19
+ import torchvision.datasets as dset
20
+ import torchvision.transforms as trn
21
+ import torch.utils.data as data
22
+ import numpy as np
23
+
24
+ from PIL import Image
25
+
26
+
27
+ # /////////////// Distortion Helpers ///////////////
28
+
29
+ import skimage as sk
30
+ from skimage.filters import gaussian
31
+ from io import BytesIO
32
+ from wand.image import Image as WandImage
33
+ from wand.api import library as wandlibrary
34
+ import wand.color as WandColor
35
+ import ctypes
36
+ from PIL import Image as PILImage
37
+ import cv2
38
+ from scipy.ndimage import zoom as scizoom
39
+ from scipy.ndimage.interpolation import map_coordinates
40
+ import warnings
41
+
42
+ warnings.simplefilter("ignore", UserWarning)
43
+
44
+
45
+ def disk(radius, alias_blur=0.1, dtype=np.float32):
46
+ if radius <= 8:
47
+ L = np.arange(-8, 8 + 1)
48
+ ksize = (3, 3)
49
+ else:
50
+ L = np.arange(-radius, radius + 1)
51
+ ksize = (5, 5)
52
+ X, Y = np.meshgrid(L, L)
53
+ aliased_disk = np.array((X ** 2 + Y ** 2) <= radius ** 2, dtype=dtype)
54
+ aliased_disk /= np.sum(aliased_disk)
55
+
56
+ # supersample disk to antialias
57
+ return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur)
58
+
59
+
60
+ # Tell Python about the C method
61
+ wandlibrary.MagickMotionBlurImage.argtypes = (ctypes.c_void_p, # wand
62
+ ctypes.c_double, # radius
63
+ ctypes.c_double, # sigma
64
+ ctypes.c_double) # angle
65
+
66
+
67
+ # Extend wand.image.Image class to include method signature
68
+ class MotionImage(WandImage):
69
+ def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
70
+ wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)
71
+
72
+
73
+ # modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py
74
+ def plasma_fractal(mapsize=32, wibbledecay=3):
75
+ """
76
+ Generate a heightmap using diamond-square algorithm.
77
+ Return square 2d array, side length 'mapsize', of floats in range 0-255.
78
+ 'mapsize' must be a power of two.
79
+ """
80
+ assert (mapsize & (mapsize - 1) == 0)
81
+ maparray = np.empty((mapsize, mapsize), dtype=np.float_)
82
+ maparray[0, 0] = 0
83
+ stepsize = mapsize
84
+ wibble = 100
85
+
86
+ def wibbledmean(array):
87
+ return array / 4 + wibble * np.random.uniform(-wibble, wibble, array.shape)
88
+
89
+ def fillsquares():
90
+ """For each square of points stepsize apart,
91
+ calculate middle value as mean of points + wibble"""
92
+ cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
93
+ squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
94
+ squareaccum += np.roll(squareaccum, shift=-1, axis=1)
95
+ maparray[stepsize // 2:mapsize:stepsize,
96
+ stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum)
97
+
98
+ def filldiamonds():
99
+ """For each diamond of points stepsize apart,
100
+ calculate middle value as mean of points + wibble"""
101
+ mapsize = maparray.shape[0]
102
+ drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize]
103
+ ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
104
+ ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
105
+ lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
106
+ ltsum = ldrsum + lulsum
107
+ maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum)
108
+ tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
109
+ tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
110
+ ttsum = tdrsum + tulsum
111
+ maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum)
112
+
113
+ while stepsize >= 2:
114
+ fillsquares()
115
+ filldiamonds()
116
+ stepsize //= 2
117
+ wibble /= wibbledecay
118
+
119
+ maparray -= maparray.min()
120
+ return maparray / maparray.max()
121
+
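+ # Illustrative note (added, not in the original source): plasma_fractal() returns a square
+ # array normalised to [0, 1]; fog() below scales a 32x32 crop of it and adds it to the image, e.g.
+ #   heightmap = plasma_fractal(mapsize=32, wibbledecay=3)   # (32, 32) float array in [0, 1]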
122
+
123
+ def clipped_zoom(img, zoom_factor):
124
+ h = img.shape[0]
125
+ # ceil crop height(= crop width)
126
+ ch = int(np.ceil(h / zoom_factor))
127
+
128
+ top = (h - ch) // 2
129
+ img = scizoom(img[top:top + ch, top:top + ch], (zoom_factor, zoom_factor, 1), order=1)
130
+ # trim off any extra pixels
131
+ trim_top = (img.shape[0] - h) // 2
132
+
133
+ return img[trim_top:trim_top + h, trim_top:trim_top + h]
134
+
135
+
136
+ # /////////////// End Distortion Helpers ///////////////
137
+
138
+
139
+ # /////////////// Distortions ///////////////
140
+
141
+ class Distortions:
142
+ def __init__(self, severity=1, transform='identity'):
143
+ self.severity = severity
144
+ self.transform = transform
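+ # Note (added for clarity): severity in 1..5 indexes the hard-coded parameter list of the
+ # chosen distortion (1 = mildest, 5 = strongest); transform is the name of one of the
+ # methods below, looked up via getattr in __call__.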
145
+
146
+ def __call__(self, img):
147
+ assert torch.is_tensor(img), 'Input data needs to be a torch.Tensor'
148
+ assert len(img.shape) == 3, 'Input image should have shape (C, H, W)'
149
+ img = self.torch2np(img)
150
+ t = getattr(self, self.transform)
151
+ img = t(img, self.severity)
152
+ return self.np2torch(img).float()
153
+
154
+ def np2torch(self,x):
155
+ return torch.tensor(x).permute(2,0,1)
156
+
157
+ def torch2np(self,x):
158
+ return np.array(x.permute(1,2,0))
159
+
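+ # np2torch/torch2np above convert between torch's (C, H, W) tensor layout and the
+ # (H, W, C) numpy layout that the distortion methods below operate on; values are
+ # expected to be floats in [0, 1].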
160
+ def identity(self,x, severity=1):
161
+ return x
162
+
163
+ def gaussian_noise(self, x, severity=1):
164
+ c = [0.04, 0.06, .08, .09, .10][severity - 1]
165
+ return np.clip(x + np.random.normal(size=x.shape, scale=c), 0, 1)
166
+
167
+
168
+ def shot_noise(self, x, severity=1):
169
+ c = [500, 250, 100, 75, 50][severity - 1]
170
+ return np.clip(np.random.poisson(x * c) / c, 0, 1)
171
+
172
+
173
+ def impulse_noise(self, x, severity=1):
174
+ c = [.01, .02, .03, .05, .07][severity - 1]
175
+
176
+ x = sk.util.random_noise(x, mode='s&p', amount=c)
177
+ return np.clip(x, 0, 1)
178
+
179
+
180
+ def speckle_noise(self, x, severity=1):
181
+ c = [.06, .1, .12, .16, .2][severity - 1]
182
+ return np.clip(x + x * np.random.normal(size=x.shape, scale=c), 0, 1)
183
+
184
+
185
+ def gaussian_blur(self, x, severity=1):
186
+ c = [.4, .6, 0.7, .8, 1][severity - 1]
187
+
188
+ x = gaussian(x, sigma=c, multichannel=True)
189
+ return np.clip(x, 0, 1)
190
+
191
+
192
+ def glass_blur(self, x, severity=1):
193
+ # sigma, max_delta, iterations
194
+ c = [(0.05,1,1), (0.25,1,1), (0.4,1,1), (0.25,1,2), (0.4,1,2)][severity - 1]
195
+
196
+ x = gaussian(x, sigma=c[0], multichannel=True)
197
+
198
+ # locally shuffle pixels
199
+ for i in range(c[2]):
200
+ for h in range(32 - c[1], c[1], -1):
201
+ for w in range(32 - c[1], c[1], -1):
202
+ dx, dy = np.random.randint(-c[1], c[1], size=(2,))
203
+ h_prime, w_prime = h + dy, w + dx
204
+ # swap
205
+ x[h, w], x[h_prime, w_prime] = x[h_prime, w_prime], x[h, w]
206
+
207
+ return np.clip(gaussian(x, sigma=c[0], multichannel=True), 0, 1)
208
+
209
+
210
+ def defocus_blur(self, x, severity=1):
211
+ c = [(0.3, 0.4), (0.4, 0.5), (0.5, 0.6), (1, 0.2), (1.5, 0.1)][severity - 1]
212
+ kernel = disk(radius=c[0], alias_blur=c[1])
213
+
214
+ channels = []
215
+ for d in range(3):
216
+ channels.append(cv2.filter2D(x[:, :, d], -1, kernel))
217
+ channels = np.array(channels).transpose((1, 2, 0)) # 3x32x32 -> 32x32x3
218
+
219
+ return np.clip(channels, 0, 1)
220
+
221
+
222
+ def motion_blur(self, x, severity=1):
223
+ c = [(6,1), (6,1.5), (6,2), (8,2), (9,2.5)][severity - 1]
224
+
225
+ output = BytesIO()
226
+ x.save(output, format='PNG')
227
+ x = MotionImage(blob=output.getvalue())
228
+
229
+ x.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45))
230
+
231
+ x = cv2.imdecode(np.frombuffer(x.make_blob(), np.uint8),  # np.fromstring is deprecated for binary data
232
+ cv2.IMREAD_UNCHANGED)
233
+
234
+ if x.shape != (32, 32):
235
+ return np.clip(x[..., [2, 1, 0]], 0, 1) # BGR to RGB
236
+ else: # greyscale to RGB
237
+ return np.clip(np.array([x, x, x]).transpose((1, 2, 0)), 0, 1)
238
+
239
+
240
+ def zoom_blur(self, x, severity=1):
241
+ c = [np.arange(1, 1.06, 0.01), np.arange(1, 1.11, 0.01), np.arange(1, 1.16, 0.01),
242
+ np.arange(1, 1.21, 0.01), np.arange(1, 1.26, 0.01)][severity - 1]
243
+ out = np.zeros_like(x)
244
+ for zoom_factor in c:
245
+ out += clipped_zoom(x, zoom_factor)
246
+
247
+ x = (x + out) / (len(c) + 1)
248
+ return np.clip(x, 0, 1)
249
+
250
+
251
+ def fog(self, x, severity=1):
252
+ c = [(.2,3), (.5,3), (0.75,2.5), (1,2), (1.5,1.75)][severity - 1]
253
+ max_val = x.max()
254
+ x += c[0] * plasma_fractal(wibbledecay=c[1])[:32, :32][..., np.newaxis]
255
+ return np.clip(x * max_val / (max_val + c[0]), 0, 1)
256
+
257
+
258
+ def frost(self, x, severity=1):
259
+ c = [(1, 0.2), (1, 0.3), (0.9, 0.4), (0.85, 0.4), (0.75, 0.45)][severity - 1]
260
+ idx = np.random.randint(5)
261
+ filename = ['./frost1.png', './frost2.png', './frost3.png', './frost4.jpg', './frost5.jpg', './frost6.jpg'][idx]
262
+ frost = cv2.imread(filename)
263
+ frost = cv2.resize(frost, (0, 0), fx=0.2, fy=0.2)
264
+ # randomly crop and convert to rgb
265
+ x_start, y_start = np.random.randint(0, frost.shape[0] - 32), np.random.randint(0, frost.shape[1] - 32)
266
+ frost = frost[x_start:x_start + 32, y_start:y_start + 32][..., [2, 1, 0]]
267
+
268
+ return np.clip(c[0] * np.array(x) + c[1] * frost, 0, 1)
269
+
270
+
271
+ def snow(self, x, severity=1):
272
+ c = [(0.1,0.2,1,0.6,8,3,0.95),
273
+ (0.1,0.2,1,0.5,10,4,0.9),
274
+ (0.15,0.3,1.75,0.55,10,4,0.9),
275
+ (0.25,0.3,2.25,0.6,12,6,0.85),
276
+ (0.3,0.3,1.25,0.65,14,12,0.8)][severity - 1]
277
+
278
+ snow_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1]) # [:2] for monochrome
279
+
280
+ snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2])
281
+ snow_layer[snow_layer < c[3]] = 0
282
+
283
+ snow_layer = PILImage.fromarray((np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L')
284
+ output = BytesIO()
285
+ snow_layer.save(output, format='PNG')
286
+ snow_layer = MotionImage(blob=output.getvalue())
287
+
288
+ snow_layer.motion_blur(radius=c[4], sigma=c[5], angle=np.random.uniform(-135, -45))
289
+
290
+ snow_layer = cv2.imdecode(np.frombuffer(snow_layer.make_blob(), np.uint8),
291
+ cv2.IMREAD_UNCHANGED) / (2**16-1)
292
+ snow_layer = snow_layer[..., np.newaxis]
293
+
294
+ x = c[6] * x + (1 - c[6]) * np.maximum(x, cv2.cvtColor(x, cv2.COLOR_RGB2GRAY).reshape(32, 32, 1) * 1.5 + 0.5)
295
+ return np.clip(x + snow_layer + np.rot90(snow_layer, k=2), 0, 1)
296
+
297
+
298
+ def spatter(self, x, severity=1):
299
+ c = [(0.62,0.1,0.7,0.7,0.5,0),
300
+ (0.65,0.1,0.8,0.7,0.5,0),
301
+ (0.65,0.3,1,0.69,0.5,0),
302
+ (0.65,0.1,0.7,0.69,0.6,1),
303
+ (0.65,0.1,0.5,0.68,0.6,1)][severity - 1]
304
+
305
+ liquid_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1])
306
+
307
+ liquid_layer = gaussian(liquid_layer, sigma=c[2])
308
+ liquid_layer[liquid_layer < c[3]] = 0
309
+ if c[5] == 0:
310
+ liquid_layer = (liquid_layer * (2**16-1)).astype(np.uint8)
311
+ dist = (2**16-1) - cv2.Canny(liquid_layer, 50, 150)
312
+ dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
313
+ _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
314
+ dist = cv2.blur(dist, (3, 3)).astype(np.uint8)
315
+ dist = cv2.equalizeHist(dist)
316
+ # ker = np.array([[-1,-2,-3],[-2,0,0],[-3,0,1]], dtype=np.float32)
317
+ # ker -= np.mean(ker)
318
+ ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
319
+ dist = cv2.filter2D(dist, cv2.CV_8U, ker)
320
+ dist = cv2.blur(dist, (3, 3)).astype(np.float32)
321
+
322
+ m = cv2.cvtColor(liquid_layer * dist, cv2.COLOR_GRAY2BGRA)
323
+ m /= np.max(m, axis=(0, 1))
324
+ m *= c[4]
325
+
326
+ # water is pale turquoise
327
+ color = np.concatenate((175 / 255. * np.ones_like(m[..., :1]),
328
+ 238 / 255. * np.ones_like(m[..., :1]),
329
+ 238 / 255. * np.ones_like(m[..., :1])), axis=2)
330
+
331
+ color = cv2.cvtColor(color, cv2.COLOR_BGR2BGRA)
332
+ x = cv2.cvtColor(x, cv2.COLOR_BGR2BGRA)
333
+
334
+ return cv2.cvtColor(np.clip(x + m * color, 0, 1), cv2.COLOR_BGRA2BGR) * (2**16-1)
335
+ else:
336
+ m = np.where(liquid_layer > c[3], 1, 0)
337
+ m = gaussian(m.astype(np.float32), sigma=c[4])
338
+ m[m < 0.8] = 0
339
+ # m = np.abs(m) ** (1/c[4])
340
+
341
+ # mud brown
342
+ color = np.concatenate((63 / 255. * np.ones_like(x[..., :1]),
343
+ 42 / 255. * np.ones_like(x[..., :1]),
344
+ 20 / 255. * np.ones_like(x[..., :1])), axis=2)
345
+
346
+ color *= m[..., np.newaxis]
347
+ x *= (1 - m[..., np.newaxis])
348
+
349
+ return np.clip(x + color, 0, 1)
350
+
351
+
352
+ def contrast(self, x, severity=1):
353
+ c = [.75, .5, .4, .3, 0.15][severity - 1]
354
+ means = np.mean(x, axis=(0, 1), keepdims=True)
355
+ return np.clip((x - means) * c + means, 0, 1)
356
+
357
+
358
+ def brightness(self, x, severity=1):
359
+ c = [.05, .1, .15, .2, .3][severity - 1]
360
+
361
+ x = sk.color.rgb2hsv(x)
362
+ x[:, :, 2] = np.clip(x[:, :, 2] + c, 0, 1)
363
+ x = sk.color.hsv2rgb(x)
364
+
365
+ return np.clip(x, 0, 1)
366
+
367
+
368
+ def saturate(self, x, severity=1):
369
+ c = [(0.3, 0), (0.1, 0), (1.5, 0), (2, 0.1), (2.5, 0.2)][severity - 1]
370
+
371
+ x = sk.color.rgb2hsv(x)
372
+ x[:, :, 1] = np.clip(x[:, :, 1] * c[0] + c[1], 0, 1)
373
+ x = sk.color.hsv2rgb(x)
374
+
375
+ return np.clip(x, 0, 1)
376
+
377
+
378
+ def jpeg_compression(self, x, severity=1):
379
+ c = [80, 65, 58, 50, 40][severity - 1]
380
+
381
+ output = BytesIO()
382
+ x.save(output, 'JPEG', quality=c)
383
+ x = PILImage.open(output)
384
+
385
+ return x
386
+
387
+
388
+ def pixelate(self, x, severity=1):
389
+ c = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1]
390
+
391
+ x = x.resize((int(32 * c), int(32 * c)), PILImage.BOX)
392
+ x = x.resize((32, 32), PILImage.BOX)
393
+
394
+ return x
395
+
396
+
397
+ # mod of https://gist.github.com/erniejunior/601cdf56d2b424757de5
398
+ def elastic_transform(self, image, severity=1):
399
+ IMSIZE = 32
400
+ c = [(IMSIZE*0, IMSIZE*0, IMSIZE*0.08),
401
+ (IMSIZE*0.05, IMSIZE*0.2, IMSIZE*0.07),
402
+ (IMSIZE*0.08, IMSIZE*0.06, IMSIZE*0.06),
403
+ (IMSIZE*0.1, IMSIZE*0.04, IMSIZE*0.05),
404
+ (IMSIZE*0.1, IMSIZE*0.03, IMSIZE*0.03)][severity - 1]
405
+
406
+ shape = image.shape
407
+ shape_size = shape[:2]
408
+
409
+ # random affine
410
+ center_square = np.float32(shape_size) // 2
411
+ square_size = min(shape_size) // 3
412
+ pts1 = np.float32([center_square + square_size,
413
+ [center_square[0] + square_size, center_square[1] - square_size],
414
+ center_square - square_size])
415
+ pts2 = pts1 + np.random.uniform(-c[2], c[2], size=pts1.shape).astype(np.float32)
416
+ M = cv2.getAffineTransform(pts1, pts2)
417
+ image = cv2.warpAffine(image, M, shape_size[::-1], borderMode=cv2.BORDER_REFLECT_101)
418
+
419
+ dx = (gaussian(np.random.uniform(-1, 1, size=shape[:2]),
420
+ c[1], mode='reflect', truncate=3) * c[0]).astype(np.float32)
421
+ dy = (gaussian(np.random.uniform(-1, 1, size=shape[:2]),
422
+ c[1], mode='reflect', truncate=3) * c[0]).astype(np.float32)
423
+ dx, dy = dx[..., np.newaxis], dy[..., np.newaxis]
424
+
425
+ x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2]))
426
+ indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1))
427
+ return np.clip(map_coordinates(image, indices, order=1, mode='reflect').reshape(shape), 0, 1)
428
+
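+ # Minimal usage sketch (illustrative, not part of the original file); assumes a float RGB
+ # tensor `img` of shape (3, H, W) with values in [0, 1], as produced in the demo below:
+ #
+ #   dist = Distortions(severity=3, transform='gaussian_noise')
+ #   corrupted = dist(img)   # float torch tensor, same shape as img, clipped to [0, 1]
+ #
+ # Since __call__ takes and returns tensors, it can also be composed with torchvision
+ # transforms, e.g.:
+ #
+ #   pipeline = trn.Compose([Distortions(severity=2, transform='contrast')])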
429
+ if __name__=='__main__':
430
+ import os
431
+
432
+ import numpy as np
433
+ import matplotlib.pyplot as plt
434
+ import tifffile as tiff
435
+ import torch
436
+
437
+ os.chdir('..')  # os.system('cd ..') runs in a subshell and does not change this process's working directory
438
+
439
+ img = tiff.imread('/home/marco/perturbed-minds/perturbed-minds/data/microscopy/images/rgb_scale100/Ma190c_lame1_zone1_composite_Mcropped_1.tiff')
440
+ img = np.array(img)/(2**16-1)
441
+ img = torch.tensor(img).permute(2,0,1)
442
+
443
+ def identity(x, sev):
444
+ return x
445
+
446
+ if not os.path.exists('results/Cimages'):
447
+ os.makedirs('results/Cimages')
448
+
449
+ transformations = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
450
+ 'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
451
+
452
+ # glass_blur, defocus_blur, motion_blur, fog, frost, snow, spatter, jpeg_compression, pixelate,
453
+
454
+ plt.figure()
455
+ plt.imshow(img.permute(1,2,0))
456
+ plt.title('identity')
457
+ plt.savefig('results/Cimages/1_identity.png')  # save before show(); some backends close the figure on show()
458
+ plt.show()
459
+
460
+
461
+ for i,t in enumerate(transformations):
462
+
463
+ fig = plt.figure(figsize=(25,5))
464
+ columns = 5
465
+ rows = 1
466
+
467
+ for sev in range(1,6):
468
+ dist = Distortions(severity=sev, transform=t)
469
+ fig.add_subplot(rows, columns, sev)
470
+ plt.imshow(dist(img).permute(1,2,0))
471
+ plt.title(f'{t} {sev}')
472
+ plt.xticks([], [])
473
+ plt.yticks([], [])
474
+ plt.savefig(f'results/Cimages/{i+2}_{t}.png')  # save before show()
475
+ plt.show()
utils/ssim.py ADDED
@@ -0,0 +1,75 @@
1
+ """https://github.com/Po-Hsun-Su/pytorch-ssim"""
2
+
3
+ import torch
4
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from math import exp
8
+
9
+ def gaussian(window_size, sigma):
10
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
11
+ return gauss/gauss.sum()
12
+
13
+ def create_window(window_size, channel):
14
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
15
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
16
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()  # Variable is a no-op wrapper since PyTorch 0.4
17
+ return window
18
+
19
+ def _ssim(img1, img2, window, window_size, channel, size_average = True):
20
+ mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
21
+ mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
22
+
23
+ mu1_sq = mu1.pow(2)
24
+ mu2_sq = mu2.pow(2)
25
+ mu1_mu2 = mu1*mu2
26
+
27
+ sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
28
+ sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
29
+ sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
30
+
31
+ C1 = 0.01**2
32
+ C2 = 0.03**2
33
+
34
+ ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
35
+
36
+ if size_average:
37
+ return ssim_map.mean()
38
+ else:
39
+ return ssim_map.mean(1).mean(1).mean(1)
40
+
41
+ class SSIM(torch.nn.Module):
42
+ def __init__(self, window_size = 11, size_average = True):
43
+ super(SSIM, self).__init__()
44
+ self.window_size = window_size
45
+ self.size_average = size_average
46
+ self.channel = 1
47
+ self.window = create_window(window_size, self.channel)
48
+
49
+ def forward(self, img1, img2):
50
+ (_, channel, _, _) = img1.size()
51
+
52
+ if channel == self.channel and self.window.data.type() == img1.data.type():
53
+ window = self.window
54
+ else:
55
+ window = create_window(self.window_size, channel)
56
+
57
+ if img1.is_cuda:
58
+ window = window.cuda(img1.get_device())
59
+ window = window.type_as(img1)
60
+
61
+ self.window = window
62
+ self.channel = channel
63
+
64
+
65
+ return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
66
+
67
+ def ssim(img1, img2, window_size = 11, size_average = True):
68
+ (_, channel, _, _) = img1.size()
69
+ window = create_window(window_size, channel)
70
+
71
+ if img1.is_cuda:
72
+ window = window.cuda(img1.get_device())
73
+ window = window.type_as(img1)
74
+
75
+ return _ssim(img1, img2, window, window_size, channel, size_average)
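A minimal usage sketch for utils/ssim.py (illustrative, not part of the commit); it assumes two image batches as float tensors of shape (N, C, H, W):

    import torch
    from utils.ssim import SSIM, ssim

    img1 = torch.rand(1, 3, 64, 64)
    img2 = (img1 + 0.05 * torch.randn_like(img1)).clamp(0, 1)

    criterion = SSIM(window_size=11, size_average=True)  # module form; 1 - criterion(x, y) is a common loss
    print(criterion(img1, img2).item())

    print(ssim(img1, img2).item())  # functional form, same computation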