willis
commited on
Commit
·
290ca27
0
Parent(s):
Initial commit
Browse files- .gitattributes +1 -0
- LICENSE +21 -0
- README.md +129 -0
- dataset.py +573 -0
- demo-files/car.png +0 -0
- demo-files/micro.png +0 -0
- demo.py +54 -0
- environment.yml +363 -0
- figures/ABtesting.py +831 -0
- figures/figure1.sh +7 -0
- figures/figure2.sh +4 -0
- figures/figures.py +92 -0
- figures/train.sh +81 -0
- model.py +305 -0
- processing/pipeline_numpy.py +329 -0
- processing/pipeline_torch.py +314 -0
- requirements.txt +18 -0
- train.py +426 -0
- utils/augmentation.py +132 -0
- utils/base.py +335 -0
- utils/dataset_utils.py +190 -0
- utils/hendrycks_robustness.py +475 -0
- utils/ssim.py +75 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
*.ipynb filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2021 aiaudit.org
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[](https://github.com/tterb/atomic-design-ui/blob/master/LICENSEs) [](https://doi.org/10.5281/zenodo.5235536)
|
2 |
+
|
3 |
+
# From Lens to Logit - Addressing Camera Hardware-Drift Using Raw Sensor Data
|
4 |
+
|
5 |
+
*This repository hosts the code for the project ["From Lens to Logit: Addressing Camera Hardware-Drift Using Raw Sensor Data"](https://openreview.net/forum?id=DRAywM1BhU), submitted to the NeurIPS 2021 Datasets and Benchmarks Track.*
|
6 |
+
|
7 |
+
In order to address camera hardware-drift we require two ingredients: raw sensor data and an image processing model. This code repository contains the materials for the second ingredient, the image processing model, as well as scripts to load lada and run experiments. For a conceptual overview of the project we reocommend the [project site](https://aiaudit.org/lens2logit/) or the [full paper](https://openreview.net/forum?id=DRAywM1BhU).
|
8 |
+
|
9 |
+
## A short introduction
|
10 |
+

|
11 |
+
|
12 |
+
|
13 |
+
To create an image, raw sensor data traverses complex image signal processing pipelines. These pipelines are used by cameras and scientific instruments to produce the images fed into machine learning systems. The processing pipelines vary by device, influencing the resulting image statistics and ultimately contributing to what is known as hardware-drift. However, this processing is rarely considered in machine learning modelling, because available benchmark data sets are generally not in raw format. Here we show that pairing qualified raw sensor data with an explicit, differentiable model of the image processing pipeline allows to tackle camera hardware-drift.
|
14 |
+
|
15 |
+
Specifically, we demonstrate
|
16 |
+
1. the **controlled synthesis of hardware-drift test cases**
|
17 |
+
2. modular **hardware-drift forensics**, as well as
|
18 |
+
3. **image processing customization**.
|
19 |
+
|
20 |
+
We make available two data sets.
|
21 |
+
1. **Raw-Microscopy**, contains
|
22 |
+
* **940 raw bright-field microscopy images** of human blood smear slides for leukocyte classification alongside
|
23 |
+
* **5,640 variations measured at six different intensities** and twelve additional sets totalling
|
24 |
+
* **11,280 images of the raw sensor data processed through different pipelines**.
|
25 |
+
3. **Raw-Drone**, contains
|
26 |
+
* **548 raw drone camera images** for car segmentation, alongside
|
27 |
+
* **3,288 variations measured at six different intensities** and also twelve additional sets totalling
|
28 |
+
* **6,576 images of the raw sensor data processed through different pipelines**.
|
29 |
+
## Data access
|
30 |
+
If you use our code you can use the convenient cloud storage integration. Data will be loaded automatically from a cloud storage bucket and stored to your working machine. You can find the code snippet doing that [here](https://github.com/aiaudit-org/lens2logit/blob/f8a165a0c094456f68086167f0bef14c3b311a4e/utils/base.py#L130)
|
31 |
+
|
32 |
+
```python
|
33 |
+
def get_b2_bucket():
|
34 |
+
bucket_name = 'perturbed-minds'
|
35 |
+
application_key_id = '003d6b042de536a0000000008'
|
36 |
+
application_key = 'K003HMNxnoa91Dy9c0V8JVCKNUnwR9U'
|
37 |
+
info = InMemoryAccountInfo()
|
38 |
+
b2_api = B2Api(info)
|
39 |
+
b2_api.authorize_account('production', application_key_id, application_key)
|
40 |
+
bucket = b2_api.get_bucket_by_name(bucket_name)
|
41 |
+
return bucket
|
42 |
+
```
|
43 |
+
We also maintain a copy of the entire dataset with a permanent identifier at Zenodo which you can find here [](https://doi.org/10.5281/zenodo.5235536).
|
44 |
+
## Code
|
45 |
+
### Dependencies
|
46 |
+
#### Conda environment and dependencies
|
47 |
+
To run this code out-of-the-box you can install the latest project conda environment stored in `environment.yml`
|
48 |
+
```console
|
49 |
+
$ conda env create -f environment.yml
|
50 |
+
```
|
51 |
+
#### segmentation_models_pytorch newest version
|
52 |
+
We noticed that PyPi package for `segmentation_models_pytorch` is sometimes behind the project's github repository. If you encounter `smp` related problems we reccomend installing directly from the `smp` reposiroty via
|
53 |
+
```console
|
54 |
+
$ python -m pip install git+https://github.com/qubvel/segmentation_models.pytorch
|
55 |
+
```
|
56 |
+
### Recreate experiments
|
57 |
+
The central file for using the **Lens2Logit** framework for experiments as in the paper is `train.py` which provides a rich set of arguments to experiment with raw image data, different image processing models and task models for regression or classification. Below we provide three example prompts for the type of experiments reported in the [paper](https://openreview.net/forum?id=DRAywM1BhU)
|
58 |
+
#### Controlled synthesis of hardware-drift test cases
|
59 |
+
```console
|
60 |
+
$ python train.py \
|
61 |
+
--experiment_name YOUR-EXPERIMENT-NAME \
|
62 |
+
--run_name YOUR-RUN-NAME \
|
63 |
+
--dataset Microscopy \
|
64 |
+
--lr 1e-5 \
|
65 |
+
--n_splits 5 \
|
66 |
+
--epochs 5 \
|
67 |
+
--classifier_pretrained \
|
68 |
+
--processing_mode static \
|
69 |
+
--augmentation weak \
|
70 |
+
--log_model True \
|
71 |
+
--iso 0.01 \
|
72 |
+
--freeze_processor \
|
73 |
+
--processor_uri "$processor_uri" \
|
74 |
+
--track_processing \
|
75 |
+
--track_every_epoch \
|
76 |
+
--track_predictions \
|
77 |
+
--track_processing_gradients \
|
78 |
+
--track_save_tensors \
|
79 |
+
```
|
80 |
+
#### Modular hardware-drift forensics
|
81 |
+
```console
|
82 |
+
$ python train.py \
|
83 |
+
--experiment_name YOUR-EXPERIMENT-NAME \
|
84 |
+
--run_name YOUR-RUN-NAME \
|
85 |
+
--dataset Microscopy \
|
86 |
+
--adv_training
|
87 |
+
--lr 1e-5 \
|
88 |
+
--n_splits 5 \
|
89 |
+
--epochs 5 \
|
90 |
+
--classifier_pretrained \
|
91 |
+
--processing_mode parametrized \
|
92 |
+
--augmentation weak \
|
93 |
+
--log_model True \
|
94 |
+
--iso 0.01 \
|
95 |
+
--track_processing \
|
96 |
+
--track_every_epoch \
|
97 |
+
--track_predictions \
|
98 |
+
--track_processing_gradients \
|
99 |
+
--track_save_tensors \
|
100 |
+
```
|
101 |
+
#### Image processing customization
|
102 |
+
```console
|
103 |
+
$ python train.py \
|
104 |
+
--experiment_name YOUR-EXPERIMENT-NAME \
|
105 |
+
--run_name YOUR-RUN-NAME \
|
106 |
+
--dataset Microscopy \
|
107 |
+
--lr 1e-5 \
|
108 |
+
--n_splits 5 \
|
109 |
+
--epochs 5 \
|
110 |
+
--classifier_pretrained \
|
111 |
+
--processing_mode parametrized \
|
112 |
+
--augmentation weak \
|
113 |
+
--log_model True \
|
114 |
+
--iso 0.01 \
|
115 |
+
--track_processing \
|
116 |
+
--track_every_epoch \
|
117 |
+
--track_predictions \
|
118 |
+
--track_processing_gradients \
|
119 |
+
--track_save_tensors \
|
120 |
+
```
|
121 |
+
## Virtual lab log
|
122 |
+
We maintain a collaborative virtual lab log at [this address](http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com/#/). There you can browse experiment runs, analyze results through SQL queries and download trained processing and task models.
|
123 |
+

|
124 |
+
|
125 |
+
|
126 |
+
### Review our experiments
|
127 |
+
Experiments are listed in the left column. You can select individual runs or compare metrics and parameters across different runs. For runs where we tracked images of intermediate processing steps and images of the gradients at these processing steps you can find at the bottom of a run page in the *results* folder for each epoch.
|
128 |
+
### Use our trained models
|
129 |
+
When selecting a run and a model was saved you can find the model files, state dictionary and instructions to load at the bottom of a run page under *models*. In the menu bar at the top of the virtual lab log you can also access models via the *Model Registry*. Our code is well integrated with the *mlflow* autologging and -loading package for PyTorch. So when using our code you can just specify the *model uri* as an argument and models will be fetched from the model registry automatically.
|
dataset.py
ADDED
@@ -0,0 +1,573 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import rawpy
|
4 |
+
import random
|
5 |
+
from PIL import Image
|
6 |
+
import tifffile as tiff
|
7 |
+
import zipfile
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
from torch.utils.data import Dataset, DataLoader, TensorDataset
|
13 |
+
from sklearn.model_selection import StratifiedShuffleSplit
|
14 |
+
|
15 |
+
if not os.path.exists('README.md'): # set pwd to root
|
16 |
+
os.chdir('..')
|
17 |
+
|
18 |
+
from utils.dataset_utils import split_img, list_images_in_dir, load_image
|
19 |
+
from utils.base import np2torch, torch2np, b2_download_folder
|
20 |
+
|
21 |
+
IMAGE_FILE_TYPES = ['dng', 'png', 'tif', 'tiff']
|
22 |
+
|
23 |
+
|
24 |
+
def get_dataset(name, I_ratio=1.0):
|
25 |
+
# DroneDataset
|
26 |
+
if name in ('DC', 'Drone', 'DroneClassification', 'DroneDatasetClassificationTiled'):
|
27 |
+
return DroneDatasetClassificationTiled(I_ratio=I_ratio)
|
28 |
+
if name in ('DS', 'DroneSegmentation', 'DroneDatasetSegmentationTiled'):
|
29 |
+
return DroneDatasetSegmentationTiled(I_ratio=I_ratio)
|
30 |
+
|
31 |
+
# MicroscopyDataset
|
32 |
+
if name in ('M', 'Microscopy', 'MicroscopyDataset'):
|
33 |
+
return MicroscopyDataset(I_ratio=I_ratio)
|
34 |
+
|
35 |
+
# for testing
|
36 |
+
if name in ('DSF', 'DroneDatasetSegmentationFull'):
|
37 |
+
return DroneDatasetSegmentationFull(I_ratio=I_ratio)
|
38 |
+
if name in ('MRGB', 'MicroscopyRGB', 'MicroscopyDatasetRGB'):
|
39 |
+
return MicroscopyDatasetRGB(I_ratio=I_ratio)
|
40 |
+
|
41 |
+
raise ValueError(name)
|
42 |
+
|
43 |
+
|
44 |
+
class ImageFolderDataset(Dataset):
|
45 |
+
"""Creates a dataset of images in img_dir and corresponding masks in mask_dir.
|
46 |
+
Corresponding mask files need to contain the filename of the image.
|
47 |
+
Files are expected to be of the same filetype.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
img_dir (str): path to image folder
|
51 |
+
mask_dir (str): path to mask folder
|
52 |
+
transform (callable, optional): transformation to apply to image and mask
|
53 |
+
bits (int, optional): normalize image by dividing by 2^bits - 1
|
54 |
+
"""
|
55 |
+
|
56 |
+
task = 'classification'
|
57 |
+
|
58 |
+
def __init__(self, img_dir, labels, transform=None, bits=1):
|
59 |
+
|
60 |
+
self.img_dir = img_dir
|
61 |
+
self.labels = labels
|
62 |
+
|
63 |
+
self.images = list_images_in_dir(img_dir)
|
64 |
+
|
65 |
+
assert len(self.images) == len(self.labels)
|
66 |
+
|
67 |
+
self.transform = transform
|
68 |
+
self.bits = bits
|
69 |
+
|
70 |
+
def __repr__(self):
|
71 |
+
rep = f"{type(self).__name__}: ImageFolderDataset[{len(self.images)}]"
|
72 |
+
for n, (img, label) in enumerate(zip(self.images, self.labels)):
|
73 |
+
rep += f'\nimage: {img}\tlabel: {label}'
|
74 |
+
if n > 10:
|
75 |
+
rep += '\n...'
|
76 |
+
break
|
77 |
+
return rep
|
78 |
+
|
79 |
+
def __len__(self):
|
80 |
+
return len(self.images)
|
81 |
+
|
82 |
+
def __getitem__(self, idx):
|
83 |
+
|
84 |
+
label = self.labels[idx]
|
85 |
+
|
86 |
+
img = load_image(self.images[idx])
|
87 |
+
img = img / (2**self.bits - 1)
|
88 |
+
if self.transform is not None:
|
89 |
+
img = self.transform(img)
|
90 |
+
|
91 |
+
if len(img.shape) == 2:
|
92 |
+
assert img.shape == (256, 256), f"Invalid size for {self.images[idx]}"
|
93 |
+
else:
|
94 |
+
assert img.shape == (3, 256, 256), f"Invalid size for {self.images[idx]}"
|
95 |
+
|
96 |
+
return img, label
|
97 |
+
|
98 |
+
|
99 |
+
class ImageFolderDatasetSegmentation(Dataset):
|
100 |
+
"""Creates a dataset of images in `img_dir` and corresponding masks in `mask_dir`.
|
101 |
+
Corresponding mask files need to contain the filename of the image.
|
102 |
+
Files are expected to be of the same filetype.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
img_dir (str): path to image folder
|
106 |
+
mask_dir (str): path to mask folder
|
107 |
+
transform (callable, optional): transformation to apply to image and mask
|
108 |
+
bits (int, optional): normalize image by dividing by 2^bits - 1
|
109 |
+
"""
|
110 |
+
|
111 |
+
task = 'segmentation'
|
112 |
+
|
113 |
+
def __init__(self, img_dir, mask_dir, transform=None, bits=1):
|
114 |
+
|
115 |
+
self.img_dir = img_dir
|
116 |
+
self.mask_dir = mask_dir
|
117 |
+
|
118 |
+
self.images = list_images_in_dir(img_dir)
|
119 |
+
self.masks = list_images_in_dir(mask_dir)
|
120 |
+
|
121 |
+
check_image_folder_consistency(self.images, self.masks)
|
122 |
+
|
123 |
+
self.transform = transform
|
124 |
+
self.bits = bits
|
125 |
+
|
126 |
+
def __repr__(self):
|
127 |
+
rep = f"{type(self).__name__}: ImageFolderDatasetSegmentation[{len(self.images)}]"
|
128 |
+
for n, (img, mask) in enumerate(zip(self.images, self.masks)):
|
129 |
+
rep += f'\nimage: {img}\tmask: {mask}'
|
130 |
+
if n > 10:
|
131 |
+
rep += '\n...'
|
132 |
+
break
|
133 |
+
return rep
|
134 |
+
|
135 |
+
def __len__(self):
|
136 |
+
return len(self.images)
|
137 |
+
|
138 |
+
def __getitem__(self, idx):
|
139 |
+
|
140 |
+
img = load_image(self.images[idx])
|
141 |
+
mask = load_image(self.masks[idx])
|
142 |
+
|
143 |
+
img = img / (2**self.bits - 1)
|
144 |
+
mask = (mask > 0).astype(np.float32)
|
145 |
+
|
146 |
+
if self.transform is not None:
|
147 |
+
img = self.transform(img)
|
148 |
+
|
149 |
+
return img, mask
|
150 |
+
|
151 |
+
|
152 |
+
class MultiIntensity(Dataset):
|
153 |
+
"""Wrap datasets with different intesities
|
154 |
+
|
155 |
+
Args:
|
156 |
+
datasets (list): list of datasets to wrap
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self, datasets):
|
160 |
+
self.dataset = datasets[0]
|
161 |
+
|
162 |
+
for d in range(1, len(datasets)):
|
163 |
+
self.dataset.images = self.dataset.images + datasets[d].images
|
164 |
+
self.dataset.labels = self.dataset.labels + datasets[d].labels
|
165 |
+
|
166 |
+
def __len__(self):
|
167 |
+
return len(self.dataset)
|
168 |
+
|
169 |
+
def __repr__(self):
|
170 |
+
return f"Subset [{len(self.dataset)}] of " + repr(self.dataset)
|
171 |
+
|
172 |
+
def __getitem__(self, idx):
|
173 |
+
x, y = self.dataset[idx]
|
174 |
+
if self.transform is not None:
|
175 |
+
x = self.transform(x)
|
176 |
+
return x, y
|
177 |
+
|
178 |
+
|
179 |
+
class Subset(Dataset):
|
180 |
+
"""Define a subset of a dataset by only selecting given indices.
|
181 |
+
|
182 |
+
Args:
|
183 |
+
dataset (Dataset): full dataset
|
184 |
+
indices (list): subset indices
|
185 |
+
"""
|
186 |
+
|
187 |
+
def __init__(self, dataset, indices=None, transform=None):
|
188 |
+
self.dataset = dataset
|
189 |
+
self.indices = indices if indices is not None else range(len(dataset))
|
190 |
+
self.transform = transform
|
191 |
+
|
192 |
+
def __len__(self):
|
193 |
+
return len(self.indices)
|
194 |
+
|
195 |
+
def __repr__(self):
|
196 |
+
return f"Subset [{len(self)}] of " + repr(self.dataset)
|
197 |
+
|
198 |
+
def __getitem__(self, idx):
|
199 |
+
x, y = self.dataset[self.indices[idx]]
|
200 |
+
if self.transform is not None:
|
201 |
+
x = self.transform(x)
|
202 |
+
return x, y
|
203 |
+
|
204 |
+
|
205 |
+
class DroneDatasetSegmentationFull(ImageFolderDatasetSegmentation):
|
206 |
+
"""Dataset consisting of full-sized numpy images and masks. Images are normalized to range [0, 1].
|
207 |
+
"""
|
208 |
+
|
209 |
+
black_level = [0.0625, 0.0626, 0.0625, 0.0626]
|
210 |
+
white_balance = [2.86653646, 1., 1.73079425]
|
211 |
+
colour_matrix = [1.50768983, -0.33571374, -0.17197604, -0.23048614,
|
212 |
+
1.70698738, -0.47650126, -0.03119153, -0.32803956, 1.35923111]
|
213 |
+
camera_parameters = black_level, white_balance, colour_matrix
|
214 |
+
|
215 |
+
def __init__(self, I_ratio=1.0, transform=None, force_download=False, bits=16):
|
216 |
+
|
217 |
+
assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
|
218 |
+
|
219 |
+
img_dir = f'data/drone/images_full/raw_scale{int(I_ratio*100):03d}'
|
220 |
+
mask_dir = 'data/drone/masks_full'
|
221 |
+
|
222 |
+
download_drone_dataset(force_download) # XXX: zip files and add checksum? date?
|
223 |
+
|
224 |
+
super().__init__(img_dir=img_dir, mask_dir=mask_dir, transform=transform, bits=bits)
|
225 |
+
|
226 |
+
|
227 |
+
class DroneDatasetSegmentationTiled(ImageFolderDatasetSegmentation):
|
228 |
+
"""Dataset consisting of tiled numpy images and masks. Images are in range [0, 1]
|
229 |
+
Args:
|
230 |
+
tile_size (int, optional): size of the tiled images. Defaults to 256.
|
231 |
+
"""
|
232 |
+
|
233 |
+
camera_parameters = DroneDatasetSegmentationFull.camera_parameters
|
234 |
+
|
235 |
+
def __init__(self, I_ratio=1.0, transform=None):
|
236 |
+
|
237 |
+
tile_size = 256
|
238 |
+
|
239 |
+
img_dir = f'data/drone/images_tiles_{tile_size}/raw_scale{int(I_ratio*100):03d}'
|
240 |
+
mask_dir = f'data/drone/masks_tiles_{tile_size}'
|
241 |
+
|
242 |
+
if not os.path.exists(img_dir) or not os.path.exists(mask_dir):
|
243 |
+
dataset_full = DroneDatasetSegmentationFull(I_ratio=I_ratio, bits=1)
|
244 |
+
print("tiling dataset..")
|
245 |
+
create_tiles_dataset(dataset_full, img_dir, mask_dir, tile_size=tile_size)
|
246 |
+
|
247 |
+
super().__init__(img_dir=img_dir, mask_dir=mask_dir, transform=transform, bits=16)
|
248 |
+
|
249 |
+
|
250 |
+
class DroneDatasetClassificationTiled(ImageFolderDataset):
|
251 |
+
|
252 |
+
camera_parameters = DroneDatasetSegmentationFull.camera_parameters
|
253 |
+
|
254 |
+
def __init__(self, I_ratio=1.0, transform=None):
|
255 |
+
|
256 |
+
random_state = 72
|
257 |
+
tile_size = 256
|
258 |
+
thr = 0.01
|
259 |
+
|
260 |
+
img_dir = f'data/drone/classification/images_tiles_{tile_size}/raw_scale{int(I_ratio*100):03d}_thr_{thr}'
|
261 |
+
mask_dir = f'data/drone/classification/masks_tiles_{tile_size}_thr_{thr}'
|
262 |
+
df_path = f'data/drone/classification/dataset_tiles_{tile_size}_{random_state}_{thr}.csv'
|
263 |
+
|
264 |
+
if not os.path.exists(img_dir) or not os.path.exists(mask_dir):
|
265 |
+
dataset_full = DroneDatasetSegmentationFull(I_ratio=I_ratio, bits=1)
|
266 |
+
print("tiling dataset..")
|
267 |
+
create_tiles_dataset_binary(dataset_full, img_dir, mask_dir, random_state, thr, tile_size=tile_size)
|
268 |
+
|
269 |
+
self.classes = ['car', 'no car']
|
270 |
+
self.df = pd.read_csv(df_path)
|
271 |
+
labels = self.df['label'].to_list()
|
272 |
+
|
273 |
+
super().__init__(img_dir=img_dir, labels=labels, transform=transform, bits=16)
|
274 |
+
|
275 |
+
images, class_labels = read_label_csv(self.df)
|
276 |
+
self.images = [os.path.join(self.img_dir, image) for image in images]
|
277 |
+
self.labels = class_labels
|
278 |
+
|
279 |
+
|
280 |
+
class MicroscopyDataset(ImageFolderDataset):
|
281 |
+
"""MicroscopyDataset raw images
|
282 |
+
|
283 |
+
Args:
|
284 |
+
I_ratio (float): Original image rescaled by this factor, possible values [0.01,0.05,0.1,0.25,0.5,0.75,1.0]
|
285 |
+
raw (bool): Select rgb dataset or raw dataset
|
286 |
+
transform (callable, optional): transformation to apply to image and mask
|
287 |
+
bits (int, optional): normalize image by dividing by 2^bits - 1
|
288 |
+
"""
|
289 |
+
|
290 |
+
black_level = [9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06]
|
291 |
+
white_balance = [-0.6567, 1.9673, 3.5304]
|
292 |
+
colour_matrix = [-2.0338, 0.0933, 0.4157, -0.0286, 2.6464, -0.0574, -0.5516, -0.0947, 2.9308]
|
293 |
+
|
294 |
+
camera_parameters = black_level, white_balance, colour_matrix
|
295 |
+
|
296 |
+
dataset_mean = [0.91, 0.84, 0.94]
|
297 |
+
dataset_std = [0.08, 0.12, 0.05]
|
298 |
+
|
299 |
+
def __init__(self, I_ratio=1.0, transform=None, bits=16, force_download=False):
|
300 |
+
|
301 |
+
assert I_ratio in [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
|
302 |
+
|
303 |
+
download_microscopy_dataset(force_download=force_download)
|
304 |
+
|
305 |
+
self.img_dir = f'data/microscopy/images/raw_scale{int(I_ratio*100):03d}'
|
306 |
+
self.transform = transform
|
307 |
+
self.bits = bits
|
308 |
+
|
309 |
+
self.label_file = 'data/microscopy/labels/Ma190c_annotations.dat'
|
310 |
+
|
311 |
+
self.valid_classes = ['BAS', 'EBO', 'EOS', 'KSC', 'LYA', 'LYT', 'MMZ', 'MOB',
|
312 |
+
'MON', 'MYB', 'MYO', 'NGB', 'NGS', 'PMB', 'PMO', 'UNC']
|
313 |
+
|
314 |
+
self.invalid_files = ['Ma190c_lame3_zone13_composite_Mcropped_2.tiff', ]
|
315 |
+
|
316 |
+
images, class_labels = read_label_file(self.label_file)
|
317 |
+
|
318 |
+
# filter classes with low appearance
|
319 |
+
self.valid_classes = [class_label for class_label in self.valid_classes
|
320 |
+
if class_labels.count(class_label) > 4]
|
321 |
+
|
322 |
+
# remove invalid classes and invalid files from (images, class_labels)
|
323 |
+
images, class_labels = list(zip(*[
|
324 |
+
(image, class_label)
|
325 |
+
for image, class_label in zip(images, class_labels)
|
326 |
+
if class_label in self.valid_classes and image not in self.invalid_files
|
327 |
+
]))
|
328 |
+
|
329 |
+
self.classes = list(sorted({*class_labels}))
|
330 |
+
|
331 |
+
# store full path
|
332 |
+
self.images = [os.path.join(self.img_dir, image) for image in images]
|
333 |
+
|
334 |
+
# reindex labels
|
335 |
+
self.labels = [self.classes.index(class_label) for class_label in class_labels]
|
336 |
+
|
337 |
+
|
338 |
+
class MicroscopyDatasetRGB(MicroscopyDataset):
|
339 |
+
"""MicroscopyDataset RGB images
|
340 |
+
|
341 |
+
Args:
|
342 |
+
I_ratio (float): Original image rescaled by this factor, possible values [0.01,0.05,0.1,0.25,0.5,0.75,1.0]
|
343 |
+
raw (bool): Select rgb dataset or raw dataset
|
344 |
+
transform (callable, optional): transformation to apply to image and mask
|
345 |
+
bits (int, optional): normalize image by dividing by 2^bits - 1
|
346 |
+
"""
|
347 |
+
camera_parameters = None
|
348 |
+
|
349 |
+
dataset_mean = None
|
350 |
+
dataset_std = None
|
351 |
+
|
352 |
+
def __init__(self, I_ratio=1.0, transform=None, bits=16, force_download=False):
|
353 |
+
super().__init__(I_ratio=I_ratio, transform=transform, bits=bits, force_download=force_download)
|
354 |
+
self.images = [image.replace('raw', 'rgb') for image in self.images] # XXX: hack
|
355 |
+
|
356 |
+
|
357 |
+
def read_label_file(label_file_path):
|
358 |
+
|
359 |
+
images = []
|
360 |
+
class_labels = []
|
361 |
+
|
362 |
+
with open(label_file_path, "rb") as data:
|
363 |
+
for line in data:
|
364 |
+
file_name, class_label = line.decode("utf-8").split()
|
365 |
+
image = file_name + '.tiff'
|
366 |
+
images.append(image)
|
367 |
+
class_labels.append(class_label)
|
368 |
+
|
369 |
+
return images, class_labels
|
370 |
+
|
371 |
+
|
372 |
+
def read_label_csv(df):
|
373 |
+
|
374 |
+
images = []
|
375 |
+
class_labels = []
|
376 |
+
|
377 |
+
for file_name, label in zip(df['file name'], df['label']):
|
378 |
+
image = file_name + '.tif'
|
379 |
+
images.append(image)
|
380 |
+
class_labels.append(int(label))
|
381 |
+
return images, class_labels
|
382 |
+
|
383 |
+
|
384 |
+
def download_drone_dataset(force_download):
|
385 |
+
b2_download_folder('drone/images', 'data/drone/images_full', force_download=force_download)
|
386 |
+
b2_download_folder('drone/masks', 'data/drone/masks_full', force_download=force_download)
|
387 |
+
unzip_drone_images()
|
388 |
+
|
389 |
+
|
390 |
+
def download_microscopy_dataset(force_download):
|
391 |
+
b2_download_folder('Data histopathology/WhiteCellsImages',
|
392 |
+
'data/microscopy/images', force_download=force_download)
|
393 |
+
b2_download_folder('Data histopathology/WhiteCellsLabels',
|
394 |
+
'data/microscopy/labels', force_download=force_download)
|
395 |
+
unzip_microscopy_images()
|
396 |
+
|
397 |
+
|
398 |
+
def unzip_microscopy_images():
|
399 |
+
|
400 |
+
if os.path.isfile('data/microscopy/labels/.bzEmpty'):
|
401 |
+
os.remove('data/microscopy/labels/.bzEmpty')
|
402 |
+
|
403 |
+
for file in os.listdir('data/microscopy/images'):
|
404 |
+
if file.endswith(".zip"):
|
405 |
+
zip = zipfile.ZipFile(os.path.join('data/microscopy/images', file))
|
406 |
+
zip.extractall('data/microscopy/images')
|
407 |
+
os.remove(os.path.join('data/microscopy/images', file))
|
408 |
+
|
409 |
+
|
410 |
+
def unzip_drone_images():
|
411 |
+
|
412 |
+
if os.path.isfile('data/drone/masks_full/.bzEmpty'):
|
413 |
+
os.remove('data/drone/masks_full/.bzEmpty')
|
414 |
+
|
415 |
+
for file in os.listdir('data/drone/images_full'):
|
416 |
+
if file.endswith(".zip"):
|
417 |
+
zip = zipfile.ZipFile(os.path.join('data/drone/images_full', file))
|
418 |
+
zip.extractall('data/drone/images_full')
|
419 |
+
os.remove(os.path.join('data/drone/images_full', file))
|
420 |
+
|
421 |
+
|
422 |
+
def create_tiles_dataset(dataset, img_dir, mask_dir, tile_size=256):
|
423 |
+
for folder in [img_dir, mask_dir]:
|
424 |
+
if not os.path.exists(folder):
|
425 |
+
os.makedirs(folder)
|
426 |
+
for n, (img, mask) in enumerate(dataset):
|
427 |
+
tiled_img = split_img(img, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
|
428 |
+
tiled_mask = split_img(mask, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
|
429 |
+
tiled_img, tiled_mask = class_detection(tiled_img, tiled_mask) # Remove images without cars in it
|
430 |
+
for i, (sub_img, sub_mask) in enumerate(zip(tiled_img, tiled_mask)):
|
431 |
+
tile_id = f"{n:02d}_{i:05d}"
|
432 |
+
Image.fromarray(sub_img).save(os.path.join(img_dir, tile_id + '.tif'))
|
433 |
+
Image.fromarray(sub_mask > 0).save(os.path.join(mask_dir, tile_id + '.png'))
|
434 |
+
|
435 |
+
|
436 |
+
def create_tiles_dataset_binary(dataset, img_dir, mask_dir, random_state, thr, tile_size=256):
|
437 |
+
|
438 |
+
for folder in [img_dir, mask_dir]:
|
439 |
+
if not os.path.exists(folder):
|
440 |
+
os.makedirs(folder)
|
441 |
+
|
442 |
+
ids = []
|
443 |
+
labels = []
|
444 |
+
|
445 |
+
for n, (img, mask) in enumerate(dataset):
|
446 |
+
tiled_img = split_img(img, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
|
447 |
+
tiled_mask = split_img(mask, ROIs=(tile_size, tile_size), step=(tile_size, tile_size))
|
448 |
+
|
449 |
+
X_with, X_without, Y_with, Y_without = binary_class_detection(
|
450 |
+
tiled_img, tiled_mask, random_state, thr) # creates balanced arrays with class and without class
|
451 |
+
|
452 |
+
for i, (sub_X_with, sub_Y_with) in enumerate(zip(X_with, Y_with)):
|
453 |
+
tile_id = f"{n:02d}_{i:05d}"
|
454 |
+
ids.append(tile_id)
|
455 |
+
labels.append(0)
|
456 |
+
Image.fromarray(sub_X_with).save(os.path.join(img_dir, tile_id + '.tif'))
|
457 |
+
Image.fromarray(sub_Y_with > 0).save(os.path.join(mask_dir, tile_id + '.png'))
|
458 |
+
for j, (sub_X_without, sub_Y_without) in enumerate(zip(X_without, Y_without)):
|
459 |
+
tile_id = f"{n:02d}_{i+1+j:05d}"
|
460 |
+
ids.append(tile_id)
|
461 |
+
labels.append(1)
|
462 |
+
Image.fromarray(sub_X_without).save(os.path.join(img_dir, tile_id + '.tif'))
|
463 |
+
Image.fromarray(sub_Y_without > 0).save(os.path.join(mask_dir, tile_id + '.png'))
|
464 |
+
# Image.fromarray(sub_mask).save(os.path.join(mask_dir, tile_id + '.png'))
|
465 |
+
|
466 |
+
df = pd.DataFrame({'file name': ids, 'label': labels})
|
467 |
+
|
468 |
+
df_loc = f'data/drone/classification/dataset_tiles_{tile_size}_{random_state}_{thr}.csv'
|
469 |
+
df.to_csv(df_loc)
|
470 |
+
|
471 |
+
return
|
472 |
+
|
473 |
+
|
474 |
+
def class_detection(X, Y):
|
475 |
+
"""Split dataset in images which has the class in the target
|
476 |
+
|
477 |
+
Args:
|
478 |
+
X (ndarray): input image
|
479 |
+
Y (ndarray): target with segmentation map (images with {0,1} values where it is 1 when there is the class)
|
480 |
+
Returns:
|
481 |
+
X_with_class (ndarray): input regions with the selected class
|
482 |
+
Y_with_class (ndarray): target regions with the selected class
|
483 |
+
X_without_class (ndarray): input regions without the selected class
|
484 |
+
Y_without_class (ndarray): target regions without the selected class
|
485 |
+
"""
|
486 |
+
|
487 |
+
with_class = []
|
488 |
+
without_class = []
|
489 |
+
for i, img in enumerate(Y):
|
490 |
+
if img.mean() == 0:
|
491 |
+
without_class.append(i)
|
492 |
+
else:
|
493 |
+
with_class.append(i)
|
494 |
+
|
495 |
+
X_with_class = np.delete(X, without_class, 0)
|
496 |
+
Y_with_class = np.delete(Y, without_class, 0)
|
497 |
+
|
498 |
+
return X_with_class, Y_with_class
|
499 |
+
|
500 |
+
|
501 |
+
def binary_class_detection(X, Y, random_seed, thr):
|
502 |
+
"""Splits subimages in subimages with the selected class and without the selected class by calculating the mean of the submasks; subimages with 0 < submask.mean()<=thr are disregared
|
503 |
+
|
504 |
+
|
505 |
+
|
506 |
+
Args:
|
507 |
+
X (ndarray): input image
|
508 |
+
Y (ndarray): target with segmentation map (images with {0,1} values where it is 1 when there is the class)
|
509 |
+
thr (flaot): sub images are not considered if 0 < sub_target.mean() <= thr
|
510 |
+
balanced (bool): number of returned sub images is equal for both classes if true
|
511 |
+
random_seed (None or int): selection of sub images in class with more elements according to random_seed if balanced
|
512 |
+
Returns:
|
513 |
+
X_with_class (ndarray): input regions with the selected class
|
514 |
+
Y_with_class (ndarray): target regions with the selected class
|
515 |
+
X_without_class (ndarray): input regions without the selected class
|
516 |
+
Y_without_class (ndarray): target regions without the selected class
|
517 |
+
"""
|
518 |
+
|
519 |
+
with_class = []
|
520 |
+
without_class = []
|
521 |
+
no_class = []
|
522 |
+
|
523 |
+
for i, img in enumerate(Y):
|
524 |
+
m = img.mean()
|
525 |
+
if m == 0:
|
526 |
+
without_class.append(i)
|
527 |
+
else:
|
528 |
+
if m > thr:
|
529 |
+
with_class.append(i)
|
530 |
+
else:
|
531 |
+
no_class.append(i)
|
532 |
+
|
533 |
+
N = len(with_class)
|
534 |
+
M = len(without_class)
|
535 |
+
random.seed(random_seed)
|
536 |
+
if N <= M:
|
537 |
+
random.shuffle(without_class)
|
538 |
+
with_class.extend(without_class[:M - N])
|
539 |
+
else:
|
540 |
+
random.shuffle(with_class)
|
541 |
+
without_class.extend(with_class[:N - M])
|
542 |
+
|
543 |
+
X_with_class = np.delete(X, without_class + no_class, 0)
|
544 |
+
X_without_class = np.delete(X, with_class + no_class, 0)
|
545 |
+
Y_with_class = np.delete(Y, without_class + no_class, 0)
|
546 |
+
Y_without_class = np.delete(Y, with_class + no_class, 0)
|
547 |
+
|
548 |
+
return X_with_class, X_without_class, Y_with_class, Y_without_class
|
549 |
+
|
550 |
+
|
551 |
+
def make_dataloader(dataset, batch_size, shuffle=True):
|
552 |
+
|
553 |
+
X, Y = dataset
|
554 |
+
|
555 |
+
X, Y = np2torch(X), np2torch(Y)
|
556 |
+
|
557 |
+
dataset = TensorDataset(X, Y)
|
558 |
+
dataset = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
559 |
+
|
560 |
+
return dataset
|
561 |
+
|
562 |
+
|
563 |
+
def check_image_folder_consistency(images, masks):
|
564 |
+
file_type_images = images[0].split('.')[-1].lower()
|
565 |
+
file_type_masks = masks[0].split('.')[-1].lower()
|
566 |
+
assert len(images) == len(masks), "images / masks length mismatch"
|
567 |
+
for img_file, mask_file in zip(images, masks):
|
568 |
+
img_name = img_file.split('/')[-1].split('.')[0]
|
569 |
+
assert img_name in mask_file, f"image {img_file} corresponds to {mask_file}?"
|
570 |
+
assert img_file.split('.')[-1].lower() == file_type_images, \
|
571 |
+
f"image file {img_file} file type mismatch. Shoule be: {file_type_images}"
|
572 |
+
assert mask_file.split('.')[-1].lower() == file_type_masks, \
|
573 |
+
f"image file {mask_file} file type mismatch. Should be: {file_type_masks}"
|
demo-files/car.png
ADDED
![]() |
demo-files/micro.png
ADDED
![]() |
demo.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
#import tensorflow as tf
|
3 |
+
import numpy as np
|
4 |
+
import json
|
5 |
+
from os.path import dirname, realpath, join
|
6 |
+
import processing.pipeline_numpy as ppn
|
7 |
+
|
8 |
+
|
9 |
+
# Load human-readable labels for ImageNet.
|
10 |
+
current_dir = dirname(realpath(__file__))
|
11 |
+
|
12 |
+
|
13 |
+
def process(RawImage, CameraParameters, Debayer, Sharpening, Denoising):
|
14 |
+
raw_img = RawImage
|
15 |
+
if CameraParameters == "Microscope":
|
16 |
+
black_level = [9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06, 9.834368023181512e-06]
|
17 |
+
white_balance = [-0.6567, 1.9673, 3.5304]
|
18 |
+
colour_matrix = [-2.0338, 0.0933, 0.4157, -0.0286, 2.6464, -0.0574, -0.5516, -0.0947, 2.9308]
|
19 |
+
elif CameraParameters == "Drone":
|
20 |
+
#drone
|
21 |
+
black_level = [0.0625, 0.0626, 0.0625, 0.0626]
|
22 |
+
white_balance = [2.86653646, 1., 1.73079425]
|
23 |
+
colour_matrix = [1.50768983, -0.33571374, -0.17197604, -0.23048614,
|
24 |
+
1.70698738, -0.47650126, -0.03119153, -0.32803956, 1.35923111]
|
25 |
+
else:
|
26 |
+
print("No valid camera parameter")
|
27 |
+
debayer = Debayer
|
28 |
+
sharpening = Sharpening
|
29 |
+
denoising = Denoising
|
30 |
+
print(np.max(raw_img))
|
31 |
+
raw_img = (raw_img[:,:,0].astype(np.float64)/255.)
|
32 |
+
img = ppn.processing(raw_img, black_level, white_balance, colour_matrix,
|
33 |
+
debayer=debayer, sharpening=sharpening, denoising=denoising)
|
34 |
+
print(np.max(img))
|
35 |
+
return img
|
36 |
+
|
37 |
+
|
38 |
+
iface = gr.Interface(
|
39 |
+
process,
|
40 |
+
[gr.inputs.Image(),gr.inputs.Radio(["Microscope", "Drone"]),gr.inputs.Dropdown(["bilinear", "malvar2004", "menon2007"]),
|
41 |
+
gr.inputs.Dropdown(["sharpening_filter", "unsharp_masking"]),
|
42 |
+
gr.inputs.Dropdown(["gaussian_denoising", "median_denoising"])],
|
43 |
+
"image",
|
44 |
+
capture_session=True,
|
45 |
+
examples=[
|
46 |
+
["demo-files/car.png"],
|
47 |
+
["demo-files/micro.png"]
|
48 |
+
],
|
49 |
+
title="Lens2Logit - Static processing demo",
|
50 |
+
description="You can select a sample raw image, the camera parameters and the pipeline configuration to process the raw image.",)
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
iface.launch(share=True)
|
54 |
+
|
environment.yml
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: perturbed
|
2 |
+
channels:
|
3 |
+
- defaults
|
4 |
+
dependencies:
|
5 |
+
- _ipyw_jlab_nb_ext_conf=0.1.0=py37_0
|
6 |
+
- _libgcc_mutex=0.1=main
|
7 |
+
- alabaster=0.7.12=py37_0
|
8 |
+
- anaconda=2019.10=py37_0
|
9 |
+
- anaconda-client=1.7.2=py37_0
|
10 |
+
- anaconda-navigator=1.9.7=py37_0
|
11 |
+
- anaconda-project=0.8.3=py_0
|
12 |
+
- asn1crypto=1.0.1=py37_0
|
13 |
+
- astroid=2.3.1=py37_0
|
14 |
+
- astropy=3.2.2=py37h7b6447c_0
|
15 |
+
- atomicwrites=1.3.0=py37_1
|
16 |
+
- attrs=19.2.0=py_0
|
17 |
+
- babel=2.7.0=py_0
|
18 |
+
- backcall=0.1.0=py37_0
|
19 |
+
- backports=1.0=py_2
|
20 |
+
- backports.functools_lru_cache=1.5=py_2
|
21 |
+
- backports.os=0.1.1=py37_0
|
22 |
+
- backports.shutil_get_terminal_size=1.0.0=py37_2
|
23 |
+
- backports.tempfile=1.0=py_1
|
24 |
+
- backports.weakref=1.0.post1=py_1
|
25 |
+
- beautifulsoup4=4.8.0=py37_0
|
26 |
+
- bitarray=1.0.1=py37h7b6447c_0
|
27 |
+
- bkcharts=0.2=py37_0
|
28 |
+
- blas=1.0=mkl
|
29 |
+
- bleach=3.1.0=py37_0
|
30 |
+
- blosc=1.16.3=hd408876_0
|
31 |
+
- bokeh=1.3.4=py37_0
|
32 |
+
- boto=2.49.0=py37_0
|
33 |
+
- bottleneck=1.2.1=py37h035aef0_1
|
34 |
+
- bzip2=1.0.8=h7b6447c_0
|
35 |
+
- ca-certificates=2019.8.28=0
|
36 |
+
- cairo=1.14.12=h8948797_3
|
37 |
+
- certifi=2019.9.11=py37_0
|
38 |
+
- cffi=1.12.3=py37h2e261b9_0
|
39 |
+
- chardet=3.0.4=py37_1003
|
40 |
+
- click=7.0=py37_0
|
41 |
+
- cloudpickle=1.2.2=py_0
|
42 |
+
- clyent=1.2.2=py37_1
|
43 |
+
- colorama=0.4.1=py37_0
|
44 |
+
- conda-package-handling=1.6.0=py37h7b6447c_0
|
45 |
+
- conda-verify=3.4.2=py_1
|
46 |
+
- contextlib2=0.6.0=py_0
|
47 |
+
- cryptography=2.7=py37h1ba5d50_0
|
48 |
+
- curl=7.65.3=hbc83047_0
|
49 |
+
- cycler=0.10.0=py37_0
|
50 |
+
- cython=0.29.13=py37he6710b0_0
|
51 |
+
- cytoolz=0.10.0=py37h7b6447c_0
|
52 |
+
- dask=2.5.2=py_0
|
53 |
+
- dask-core=2.5.2=py_0
|
54 |
+
- dbus=1.13.6=h746ee38_0
|
55 |
+
- decorator=4.4.0=py37_1
|
56 |
+
- defusedxml=0.6.0=py_0
|
57 |
+
- distributed=2.5.2=py_0
|
58 |
+
- docutils=0.15.2=py37_0
|
59 |
+
- entrypoints=0.3=py37_0
|
60 |
+
- et_xmlfile=1.0.1=py37_0
|
61 |
+
- expat=2.2.6=he6710b0_0
|
62 |
+
- fastcache=1.1.0=py37h7b6447c_0
|
63 |
+
- filelock=3.0.12=py_0
|
64 |
+
- flask=1.1.1=py_0
|
65 |
+
- fontconfig=2.13.0=h9420a91_0
|
66 |
+
- freetype=2.9.1=h8a8886c_1
|
67 |
+
- fribidi=1.0.5=h7b6447c_0
|
68 |
+
- future=0.17.1=py37_0
|
69 |
+
- get_terminal_size=1.0.0=haa9412d_0
|
70 |
+
- gevent=1.4.0=py37h7b6447c_0
|
71 |
+
- glib=2.56.2=hd408876_0
|
72 |
+
- glob2=0.7=py_0
|
73 |
+
- gmp=6.1.2=h6c8ec71_1
|
74 |
+
- gmpy2=2.0.8=py37h10f8cd9_2
|
75 |
+
- graphite2=1.3.13=h23475e2_0
|
76 |
+
- greenlet=0.4.15=py37h7b6447c_0
|
77 |
+
- gst-plugins-base=1.14.0=hbbd80ab_1
|
78 |
+
- gstreamer=1.14.0=hb453b48_1
|
79 |
+
- h5py=2.9.0=py37h7918eee_0
|
80 |
+
- harfbuzz=1.8.8=hffaf4a1_0
|
81 |
+
- hdf5=1.10.4=hb1b8bf9_0
|
82 |
+
- heapdict=1.0.1=py_0
|
83 |
+
- html5lib=1.0.1=py37_0
|
84 |
+
- icu=58.2=h9c2bf20_1
|
85 |
+
- idna=2.8=py37_0
|
86 |
+
- imageio=2.6.0=py37_0
|
87 |
+
- imagesize=1.1.0=py37_0
|
88 |
+
- intel-openmp=2019.4=243
|
89 |
+
- ipykernel=5.1.2=py37h39e3cac_0
|
90 |
+
- ipython=7.8.0=py37h39e3cac_0
|
91 |
+
- ipython_genutils=0.2.0=py37_0
|
92 |
+
- ipywidgets=7.5.1=py_0
|
93 |
+
- isort=4.3.21=py37_0
|
94 |
+
- itsdangerous=1.1.0=py37_0
|
95 |
+
- jbig=2.1=hdba287a_0
|
96 |
+
- jdcal=1.4.1=py_0
|
97 |
+
- jedi=0.15.1=py37_0
|
98 |
+
- jeepney=0.4.1=py_0
|
99 |
+
- jinja2=2.10.3=py_0
|
100 |
+
- joblib=0.13.2=py37_0
|
101 |
+
- jpeg=9b=h024ee3a_2
|
102 |
+
- json5=0.8.5=py_0
|
103 |
+
- jsonschema=3.0.2=py37_0
|
104 |
+
- jupyter=1.0.0=py37_7
|
105 |
+
- jupyter_client=5.3.3=py37_1
|
106 |
+
- jupyter_console=6.0.0=py37_0
|
107 |
+
- jupyter_core=4.5.0=py_0
|
108 |
+
- jupyterlab=1.1.4=pyhf63ae98_0
|
109 |
+
- jupyterlab_server=1.0.6=py_0
|
110 |
+
- keyring=18.0.0=py37_0
|
111 |
+
- kiwisolver=1.1.0=py37he6710b0_0
|
112 |
+
- krb5=1.16.1=h173b8e3_7
|
113 |
+
- lazy-object-proxy=1.4.2=py37h7b6447c_0
|
114 |
+
- libarchive=3.3.3=h5d8350f_5
|
115 |
+
- libcurl=7.65.3=h20c2e04_0
|
116 |
+
- libedit=3.1.20181209=hc058e9b_0
|
117 |
+
- libffi=3.2.1=hd88cf55_4
|
118 |
+
- libgcc-ng=9.1.0=hdf63c60_0
|
119 |
+
- libgfortran-ng=7.3.0=hdf63c60_0
|
120 |
+
- liblief=0.9.0=h7725739_2
|
121 |
+
- libpng=1.6.37=hbc83047_0
|
122 |
+
- libsodium=1.0.16=h1bed415_0
|
123 |
+
- libssh2=1.8.2=h1ba5d50_0
|
124 |
+
- libstdcxx-ng=9.1.0=hdf63c60_0
|
125 |
+
- libtiff=4.0.10=h2733197_2
|
126 |
+
- libtool=2.4.6=h7b6447c_5
|
127 |
+
- libuuid=1.0.3=h1bed415_2
|
128 |
+
- libxcb=1.13=h1bed415_1
|
129 |
+
- libxml2=2.9.9=hea5a465_1
|
130 |
+
- libxslt=1.1.33=h7d1a2b0_0
|
131 |
+
- llvmlite=0.29.0=py37hd408876_0
|
132 |
+
- locket=0.2.0=py37_1
|
133 |
+
- lxml=4.4.1=py37hefd8a0e_0
|
134 |
+
- lz4-c=1.8.1.2=h14c3975_0
|
135 |
+
- lzo=2.10=h49e0be7_2
|
136 |
+
- markupsafe=1.1.1=py37h7b6447c_0
|
137 |
+
- matplotlib=3.1.1=py37h5429711_0
|
138 |
+
- mccabe=0.6.1=py37_1
|
139 |
+
- mistune=0.8.4=py37h7b6447c_0
|
140 |
+
- mkl=2019.4=243
|
141 |
+
- mkl-service=2.3.0=py37he904b0f_0
|
142 |
+
- mkl_fft=1.0.14=py37ha843d7b_0
|
143 |
+
- mkl_random=1.1.0=py37hd6b4f25_0
|
144 |
+
- mock=3.0.5=py37_0
|
145 |
+
- more-itertools=7.2.0=py37_0
|
146 |
+
- mpc=1.1.0=h10f8cd9_1
|
147 |
+
- mpfr=4.0.1=hdf1c602_3
|
148 |
+
- mpmath=1.1.0=py37_0
|
149 |
+
- msgpack-python=0.6.1=py37hfd86e86_1
|
150 |
+
- multipledispatch=0.6.0=py37_0
|
151 |
+
- navigator-updater=0.2.1=py37_0
|
152 |
+
- nbconvert=5.6.0=py37_1
|
153 |
+
- nbformat=4.4.0=py37_0
|
154 |
+
- ncurses=6.1=he6710b0_1
|
155 |
+
- networkx=2.3=py_0
|
156 |
+
- nltk=3.4.5=py37_0
|
157 |
+
- nose=1.3.7=py37_2
|
158 |
+
- notebook=6.0.1=py37_0
|
159 |
+
- numba=0.45.1=py37h962f231_0
|
160 |
+
- numexpr=2.7.0=py37h9e4a6bb_0
|
161 |
+
- numpy=1.17.2=py37haad9e8e_0
|
162 |
+
- numpy-base=1.17.2=py37hde5b4d6_0
|
163 |
+
- numpydoc=0.9.1=py_0
|
164 |
+
- olefile=0.46=py37_0
|
165 |
+
- openpyxl=3.0.0=py_0
|
166 |
+
- openssl=1.1.1d=h7b6447c_2
|
167 |
+
- packaging=19.2=py_0
|
168 |
+
- pandoc=2.2.3.2=0
|
169 |
+
- pandocfilters=1.4.2=py37_1
|
170 |
+
- pango=1.42.4=h049681c_0
|
171 |
+
- parso=0.5.1=py_0
|
172 |
+
- partd=1.0.0=py_0
|
173 |
+
- patchelf=0.9=he6710b0_3
|
174 |
+
- path.py=12.0.1=py_0
|
175 |
+
- pathlib2=2.3.5=py37_0
|
176 |
+
- patsy=0.5.1=py37_0
|
177 |
+
- pcre=8.43=he6710b0_0
|
178 |
+
- pep8=1.7.1=py37_0
|
179 |
+
- pexpect=4.7.0=py37_0
|
180 |
+
- pickleshare=0.7.5=py37_0
|
181 |
+
- pip=19.2.3=py37_0
|
182 |
+
- pixman=0.38.0=h7b6447c_0
|
183 |
+
- pkginfo=1.5.0.1=py37_0
|
184 |
+
- pluggy=0.13.0=py37_0
|
185 |
+
- ply=3.11=py37_0
|
186 |
+
- prometheus_client=0.7.1=py_0
|
187 |
+
- prompt_toolkit=2.0.10=py_0
|
188 |
+
- psutil=5.6.3=py37h7b6447c_0
|
189 |
+
- ptyprocess=0.6.0=py37_0
|
190 |
+
- py=1.8.0=py37_0
|
191 |
+
- py-lief=0.9.0=py37h7725739_2
|
192 |
+
- pycodestyle=2.5.0=py37_0
|
193 |
+
- pycosat=0.6.3=py37h14c3975_0
|
194 |
+
- pycparser=2.19=py37_0
|
195 |
+
- pycrypto=2.6.1=py37h14c3975_9
|
196 |
+
- pycurl=7.43.0.3=py37h1ba5d50_0
|
197 |
+
- pyflakes=2.1.1=py37_0
|
198 |
+
- pygments=2.4.2=py_0
|
199 |
+
- pylint=2.4.2=py37_0
|
200 |
+
- pyodbc=4.0.27=py37he6710b0_0
|
201 |
+
- pyopenssl=19.0.0=py37_0
|
202 |
+
- pyparsing=2.4.2=py_0
|
203 |
+
- pyqt=5.9.2=py37h05f1152_2
|
204 |
+
- pyrsistent=0.15.4=py37h7b6447c_0
|
205 |
+
- pysocks=1.7.1=py37_0
|
206 |
+
- pytables=3.5.2=py37h71ec239_1
|
207 |
+
- pytest=5.2.1=py37_0
|
208 |
+
- pytest-arraydiff=0.3=py37h39e3cac_0
|
209 |
+
- pytest-astropy=0.5.0=py37_0
|
210 |
+
- pytest-doctestplus=0.4.0=py_0
|
211 |
+
- pytest-openfiles=0.4.0=py_0
|
212 |
+
- pytest-remotedata=0.3.2=py37_0
|
213 |
+
- python=3.7.4=h265db76_1
|
214 |
+
- python-dateutil=2.8.0=py37_0
|
215 |
+
- python-libarchive-c=2.8=py37_13
|
216 |
+
- pytz=2019.3=py_0
|
217 |
+
- pyyaml=5.1.2=py37h7b6447c_0
|
218 |
+
- pyzmq=18.1.0=py37he6710b0_0
|
219 |
+
- qt=5.9.7=h5867ecd_1
|
220 |
+
- qtawesome=0.6.0=py_0
|
221 |
+
- qtconsole=4.5.5=py_0
|
222 |
+
- qtpy=1.9.0=py_0
|
223 |
+
- readline=7.0=h7b6447c_5
|
224 |
+
- requests=2.22.0=py37_0
|
225 |
+
- ripgrep=0.10.0=hc07d326_0
|
226 |
+
- rope=0.14.0=py_0
|
227 |
+
- ruamel_yaml=0.15.46=py37h14c3975_0
|
228 |
+
- scikit-learn=0.21.3=py37hd81dba3_0
|
229 |
+
- scipy=1.3.1=py37h7c811a0_0
|
230 |
+
- seaborn=0.9.0=py37_0
|
231 |
+
- secretstorage=3.1.1=py37_0
|
232 |
+
- send2trash=1.5.0=py37_0
|
233 |
+
- setuptools=41.4.0=py37_0
|
234 |
+
- simplegeneric=0.8.1=py37_2
|
235 |
+
- singledispatch=3.4.0.3=py37_0
|
236 |
+
- sip=4.19.8=py37hf484d3e_0
|
237 |
+
- six=1.12.0=py37_0
|
238 |
+
- snappy=1.1.7=hbae5bb6_3
|
239 |
+
- snowballstemmer=2.0.0=py_0
|
240 |
+
- sortedcollections=1.1.2=py37_0
|
241 |
+
- sortedcontainers=2.1.0=py37_0
|
242 |
+
- soupsieve=1.9.3=py37_0
|
243 |
+
- sphinx=2.2.0=py_0
|
244 |
+
- sphinxcontrib=1.0=py37_1
|
245 |
+
- sphinxcontrib-applehelp=1.0.1=py_0
|
246 |
+
- sphinxcontrib-devhelp=1.0.1=py_0
|
247 |
+
- sphinxcontrib-htmlhelp=1.0.2=py_0
|
248 |
+
- sphinxcontrib-jsmath=1.0.1=py_0
|
249 |
+
- sphinxcontrib-qthelp=1.0.2=py_0
|
250 |
+
- sphinxcontrib-serializinghtml=1.1.3=py_0
|
251 |
+
- sphinxcontrib-websupport=1.1.2=py_0
|
252 |
+
- spyder=3.3.6=py37_0
|
253 |
+
- spyder-kernels=0.5.2=py37_0
|
254 |
+
- sqlalchemy=1.3.9=py37h7b6447c_0
|
255 |
+
- sqlite=3.30.0=h7b6447c_0
|
256 |
+
- statsmodels=0.10.1=py37hdd07704_0
|
257 |
+
- sympy=1.4=py37_0
|
258 |
+
- tbb=2019.4=hfd86e86_0
|
259 |
+
- tblib=1.4.0=py_0
|
260 |
+
- terminado=0.8.2=py37_0
|
261 |
+
- testpath=0.4.2=py37_0
|
262 |
+
- tk=8.6.8=hbc83047_0
|
263 |
+
- toolz=0.10.0=py_0
|
264 |
+
- tornado=6.0.3=py37h7b6447c_0
|
265 |
+
- traitlets=4.3.3=py37_0
|
266 |
+
- unicodecsv=0.14.1=py37_0
|
267 |
+
- unixodbc=2.3.7=h14c3975_0
|
268 |
+
- wcwidth=0.1.7=py37_0
|
269 |
+
- webencodings=0.5.1=py37_1
|
270 |
+
- werkzeug=0.16.0=py_0
|
271 |
+
- wheel=0.33.6=py37_0
|
272 |
+
- widgetsnbextension=3.5.1=py37_0
|
273 |
+
- wrapt=1.11.2=py37h7b6447c_0
|
274 |
+
- wurlitzer=1.0.3=py37_0
|
275 |
+
- xlrd=1.2.0=py37_0
|
276 |
+
- xlsxwriter=1.2.1=py_0
|
277 |
+
- xlwt=1.3.0=py37_0
|
278 |
+
- xz=5.2.4=h14c3975_4
|
279 |
+
- yaml=0.1.7=had09818_2
|
280 |
+
- zeromq=4.3.1=he6710b0_3
|
281 |
+
- zict=1.0.0=py_0
|
282 |
+
- zipp=0.6.0=py_0
|
283 |
+
- zlib=1.2.11=h7b6447c_3
|
284 |
+
- zstd=1.3.7=h0b5b093_0
|
285 |
+
- pip:
|
286 |
+
- absl-py==0.12.0
|
287 |
+
- aiohttp==3.7.4.post0
|
288 |
+
- albumentations==0.5.2
|
289 |
+
- alembic==1.4.1
|
290 |
+
- arrow==0.17.0
|
291 |
+
- async-timeout==3.0.1
|
292 |
+
- b2sdk==1.4.0
|
293 |
+
- boto3==1.17.36
|
294 |
+
- botocore==1.20.36
|
295 |
+
- cachetools==4.2.1
|
296 |
+
- colour-demosaicing==0.1.6
|
297 |
+
- colour-science==0.3.16
|
298 |
+
- configparser==5.0.0
|
299 |
+
- databricks-cli==0.10.0
|
300 |
+
- docker==4.2.0
|
301 |
+
- docopt==0.6.2
|
302 |
+
- efficientnet-pytorch==0.6.3
|
303 |
+
- fsspec==0.8.7
|
304 |
+
- funcsigs==1.0.2
|
305 |
+
- gitdb==4.0.4
|
306 |
+
- gitpython==3.1.1
|
307 |
+
- google-auth==1.28.0
|
308 |
+
- google-auth-oauthlib==0.4.3
|
309 |
+
- gorilla==0.3.0
|
310 |
+
- grpcio==1.36.1
|
311 |
+
- gunicorn==20.0.4
|
312 |
+
- imgaug==0.4.0
|
313 |
+
- importlib-metadata==3.7.3
|
314 |
+
- jmespath==0.10.0
|
315 |
+
- logfury==0.1.2
|
316 |
+
- mako==1.1.2
|
317 |
+
- markdown==3.3.4
|
318 |
+
- mlflow==1.14.1
|
319 |
+
- multidict==5.1.0
|
320 |
+
- munch==2.5.0
|
321 |
+
- oauthlib==3.1.0
|
322 |
+
- opencv-python==4.5.1.48
|
323 |
+
- opencv-python-headless==4.5.1.48
|
324 |
+
- pandas==1.2.3
|
325 |
+
- pillow==8.1.2
|
326 |
+
- pipreqs==0.4.10
|
327 |
+
- plotly==4.14.3
|
328 |
+
- pretrainedmodels==0.7.4
|
329 |
+
- prettytable==2.1.0
|
330 |
+
- prometheus-flask-exporter==0.13.0
|
331 |
+
- protobuf==3.11.3
|
332 |
+
- pyasn1==0.4.8
|
333 |
+
- pyasn1-modules==0.2.8
|
334 |
+
- python-editor==1.0.4
|
335 |
+
- pytorch-lightning==1.2.5
|
336 |
+
- pywavelets==1.1.1
|
337 |
+
- querystring-parser==1.2.4
|
338 |
+
- rawpy==0.16.0
|
339 |
+
- requests-oauthlib==1.3.0
|
340 |
+
- retrying==1.3.3
|
341 |
+
- rsa==4.7.2
|
342 |
+
- s3transfer==0.3.6
|
343 |
+
- scikit-image==0.18.1
|
344 |
+
- segmentation-models-pytorch==0.1.3
|
345 |
+
- shapely==1.7.1
|
346 |
+
- simplejson==3.17.0
|
347 |
+
- smmap==3.0.2
|
348 |
+
- sqlparse==0.3.1
|
349 |
+
- tabulate==0.8.7
|
350 |
+
- tensorboard==2.4.1
|
351 |
+
- tensorboard-plugin-wit==1.8.0
|
352 |
+
- tifffile==2021.3.17
|
353 |
+
- timm==0.3.2
|
354 |
+
- torch==1.8.0
|
355 |
+
- torchmetrics==0.2.0
|
356 |
+
- torchvision==0.9.0
|
357 |
+
- tqdm==4.59.0
|
358 |
+
- typing-extensions==3.7.4.3
|
359 |
+
- urllib3==1.25.11
|
360 |
+
- websocket-client==0.57.0
|
361 |
+
- yarg==0.1.9
|
362 |
+
- yarl==1.6.3
|
363 |
+
prefix: /home/nobis/anaconda3/envs/perturbed
|
figures/ABtesting.py
ADDED
@@ -0,0 +1,831 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
from cv2 import transform
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from torchvision.transforms import Compose, Normalize
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from dataset import get_dataset, Subset
|
12 |
+
from utils.base import get_mlflow_model_by_name, SmartFormatter
|
13 |
+
from processing.pipeline_numpy import RawProcessingPipeline
|
14 |
+
|
15 |
+
from utils.hendrycks_robustness import Distortions
|
16 |
+
|
17 |
+
import segmentation_models_pytorch as smp
|
18 |
+
|
19 |
+
import matplotlib.pyplot as plt
|
20 |
+
|
21 |
+
parser = argparse.ArgumentParser(description="AB testing, Show Results", formatter_class=SmartFormatter)
|
22 |
+
|
23 |
+
# Select experiment
|
24 |
+
parser.add_argument("--mode", type=str, default="ABShowImages", choices=('ABMakeTable', 'ABShowTable', 'ABShowImages', 'ABShowAllImages', 'CMakeTable', 'CShowTable', 'CShowImages', 'CShowAllImages'),
|
25 |
+
help='R|Choose operation to compute. \n'
|
26 |
+
'A) Lens2Logit image generation: \n '
|
27 |
+
'ABMakeTable: Compute cross-validation metrics results \n '
|
28 |
+
'ABShowTable: Plot cross-validation results on a table \n '
|
29 |
+
'ABShowImages: Choose a training and testing image to compare different pipelines \n '
|
30 |
+
'ABShowAllImages: Plot all possible pipelines \n'
|
31 |
+
'B) Hendrycks Perturbations, C-type dataset: \n '
|
32 |
+
'CMakeTable: For each pipeline, it computes cross-validation metrics for different perturbations \n '
|
33 |
+
'CShowTable: Plot metrics for different pipelines and perturbations \n '
|
34 |
+
'CShowImages: Plot an image with a selected a pipeline and perturbation\n '
|
35 |
+
'CShowAllImages: Plot all possible perturbations for a fixed pipeline')
|
36 |
+
|
37 |
+
parser.add_argument("--dataset_name", type=str, default='Microscopy',
|
38 |
+
choices=['Microscopy', 'Drone', 'DroneSegmentation'], help='Choose dataset')
|
39 |
+
parser.add_argument("--augmentation", type=str, default='weak',
|
40 |
+
choices=['none', 'weak', 'strong'], help='Choose augmentation')
|
41 |
+
parser.add_argument("--N_runs", type=int, default=5, help='Number of k-fold splitting used in the training')
|
42 |
+
parser.add_argument("--download_model", default=False, action='store_true', help='Download Models in cache')
|
43 |
+
|
44 |
+
# Select pipelines
|
45 |
+
parser.add_argument("--dm_train", type=str, default='bilinear', choices=('bilinear', 'malvar2004',
|
46 |
+
'menon2007'), help='Choose demosaicing for training processing model')
|
47 |
+
parser.add_argument("--s_train", type=str, default='sharpening_filter', choices=('sharpening_filter',
|
48 |
+
'unsharp_masking'), help='Choose sharpening for training processing model')
|
49 |
+
parser.add_argument("--dn_train", type=str, default='gaussian_denoising', choices=('gaussian_denoising',
|
50 |
+
'median_denoising'), help='Choose denoising for training processing model')
|
51 |
+
parser.add_argument("--dm_test", type=str, default='bilinear', choices=('bilinear', 'malvar2004',
|
52 |
+
'menon2007'), help='Choose demosaicing for testing processing model')
|
53 |
+
parser.add_argument("--s_test", type=str, default='sharpening_filter', choices=('sharpening_filter',
|
54 |
+
'unsharp_masking'), help='Choose sharpening for testing processing model')
|
55 |
+
parser.add_argument("--dn_test", type=str, default='gaussian_denoising', choices=('gaussian_denoising',
|
56 |
+
'median_denoising'), help='Choose denoising for testing processing model')
|
57 |
+
|
58 |
+
# Select Ctest parameters
|
59 |
+
parser.add_argument("--transform", type=str, default='identity', choices=('identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
|
60 |
+
'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform'), help='Choose transformation to show for Ctesting')
|
61 |
+
parser.add_argument("--severity", type=int, default=1, choices=(1, 2, 3, 4, 5), help='Choose severity for Ctesting')
|
62 |
+
|
63 |
+
args = parser.parse_args()
|
64 |
+
|
65 |
+
|
66 |
+
class metrics:
|
67 |
+
def __init__(self, confusion_matrix):
|
68 |
+
self.cm = confusion_matrix
|
69 |
+
self.N_classes = len(confusion_matrix)
|
70 |
+
|
71 |
+
def accuracy(self):
|
72 |
+
Tp = torch.diagonal(self.cm, 0).sum()
|
73 |
+
N_elements = torch.sum(self.cm)
|
74 |
+
return Tp / N_elements
|
75 |
+
|
76 |
+
def precision(self):
|
77 |
+
Tp_Fp = torch.sum(self.cm, 1)
|
78 |
+
Tp_Fp[Tp_Fp == 0] = 1
|
79 |
+
return torch.diagonal(self.cm, 0) / Tp_Fp
|
80 |
+
|
81 |
+
def recall(self):
|
82 |
+
Tp_Fn = torch.sum(self.cm, 0)
|
83 |
+
Tp_Fn[Tp_Fn == 0] = 1
|
84 |
+
return torch.diagonal(self.cm, 0) / Tp_Fn
|
85 |
+
|
86 |
+
def f1_score(self):
|
87 |
+
prod = (self.precision() * self.recall())
|
88 |
+
sum = (self.precision() + self.recall())
|
89 |
+
sum[sum == 0.] = 1.
|
90 |
+
return 2 * (prod / sum)
|
91 |
+
|
92 |
+
def over_N_runs(ms, N_runs):
|
93 |
+
m, m2 = 0, 0
|
94 |
+
|
95 |
+
for i in ms:
|
96 |
+
m += i
|
97 |
+
mu = m / N_runs
|
98 |
+
|
99 |
+
for i in ms:
|
100 |
+
m2 += (i - mu)**2
|
101 |
+
|
102 |
+
sigma = torch.sqrt(m2 / (N_runs - 1))
|
103 |
+
|
104 |
+
return mu.tolist(), sigma.tolist()
|
105 |
+
|
106 |
+
|
107 |
+
class ABtesting:
|
108 |
+
def __init__(self,
|
109 |
+
dataset_name: str,
|
110 |
+
augmentation: str,
|
111 |
+
dm_train: str,
|
112 |
+
s_train: str,
|
113 |
+
dn_train: str,
|
114 |
+
dm_test: str,
|
115 |
+
s_test: str,
|
116 |
+
dn_test: str,
|
117 |
+
N_runs: int,
|
118 |
+
severity=1,
|
119 |
+
transform='identity',
|
120 |
+
download_model=False):
|
121 |
+
self.experiment_name = 'ABtesting'
|
122 |
+
self.dataset_name = dataset_name
|
123 |
+
self.augmentation = augmentation
|
124 |
+
self.dm_train = dm_train
|
125 |
+
self.s_train = s_train
|
126 |
+
self.dn_train = dn_train
|
127 |
+
self.dm_test = dm_test
|
128 |
+
self.s_test = s_test
|
129 |
+
self.dn_test = dn_test
|
130 |
+
self.N_runs = N_runs
|
131 |
+
self.severity = severity
|
132 |
+
self.transform = transform
|
133 |
+
self.download_model = download_model
|
134 |
+
|
135 |
+
def static_pip_val(self, debayer=None, sharpening=None, denoising=None, severity=None, transform=None, plot_mode=False):
|
136 |
+
|
137 |
+
if debayer == None:
|
138 |
+
debayer = self.dm_test
|
139 |
+
if sharpening == None:
|
140 |
+
sharpening = self.s_test
|
141 |
+
if denoising == None:
|
142 |
+
denoising = self.dn_test
|
143 |
+
if severity == None:
|
144 |
+
severity = self.severity
|
145 |
+
if transform == None:
|
146 |
+
transform = self.transform
|
147 |
+
|
148 |
+
dataset = get_dataset(self.dataset_name)
|
149 |
+
|
150 |
+
if self.dataset_name == "Drone" or self.dataset_name == "DroneSegmentation":
|
151 |
+
mean = torch.tensor([0.35, 0.36, 0.35])
|
152 |
+
std = torch.tensor([0.12, 0.11, 0.12])
|
153 |
+
elif self.dataset_name == "Microscopy":
|
154 |
+
mean = torch.tensor([0.91, 0.84, 0.94])
|
155 |
+
std = torch.tensor([0.08, 0.12, 0.05])
|
156 |
+
|
157 |
+
if not plot_mode:
|
158 |
+
dataset.transform = Compose([RawProcessingPipeline(
|
159 |
+
camera_parameters=dataset.camera_parameters,
|
160 |
+
debayer=debayer,
|
161 |
+
sharpening=sharpening,
|
162 |
+
denoising=denoising,
|
163 |
+
), Distortions(severity=severity, transform=transform),
|
164 |
+
Normalize(mean, std)])
|
165 |
+
else:
|
166 |
+
dataset.transform = Compose([RawProcessingPipeline(
|
167 |
+
camera_parameters=dataset.camera_parameters,
|
168 |
+
debayer=debayer,
|
169 |
+
sharpening=sharpening,
|
170 |
+
denoising=denoising,
|
171 |
+
), Distortions(severity=severity, transform=transform)])
|
172 |
+
|
173 |
+
return dataset
|
174 |
+
|
175 |
+
def ABclassification(self):
|
176 |
+
|
177 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
178 |
+
|
179 |
+
parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
|
180 |
+
|
181 |
+
print(
|
182 |
+
f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
|
183 |
+
print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')
|
184 |
+
|
185 |
+
accuracies, precisions, recalls, f1_scores = [], [], [], []
|
186 |
+
|
187 |
+
os.system('rm -r /tmp/py*')
|
188 |
+
|
189 |
+
for N_run in range(self.N_runs):
|
190 |
+
|
191 |
+
print(f"Evaluating Run {N_run}")
|
192 |
+
|
193 |
+
run_name = parent_run_name + '_' + str(N_run)
|
194 |
+
|
195 |
+
state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
|
196 |
+
download_model=self.download_model)
|
197 |
+
|
198 |
+
dataset = self.static_pip_val()
|
199 |
+
valid_set = Subset(dataset, indices=state_dict['valid_indices'])
|
200 |
+
valid_loader = DataLoader(valid_set, batch_size=1, num_workers=16, shuffle=False)
|
201 |
+
|
202 |
+
model.eval()
|
203 |
+
|
204 |
+
len_classes = len(dataset.classes)
|
205 |
+
confusion_matrix = torch.zeros((len_classes, len_classes))
|
206 |
+
|
207 |
+
for img, label in valid_loader:
|
208 |
+
|
209 |
+
prediction = model(img.to(DEVICE)).detach().cpu()
|
210 |
+
prediction = torch.argmax(prediction, dim=1)
|
211 |
+
confusion_matrix[label, prediction] += 1 # Real value rows, Declared columns
|
212 |
+
|
213 |
+
m = metrics(confusion_matrix)
|
214 |
+
|
215 |
+
accuracies.append(m.accuracy())
|
216 |
+
precisions.append(m.precision())
|
217 |
+
recalls.append(m.recall())
|
218 |
+
f1_scores.append(m.f1_score())
|
219 |
+
|
220 |
+
os.system('rm -r /tmp/t*')
|
221 |
+
|
222 |
+
accuracy = metrics.over_N_runs(accuracies, self.N_runs)
|
223 |
+
precision = metrics.over_N_runs(precisions, self.N_runs)
|
224 |
+
recall = metrics.over_N_runs(recalls, self.N_runs)
|
225 |
+
f1_score = metrics.over_N_runs(f1_scores, self.N_runs)
|
226 |
+
return dataset.classes, accuracy, precision, recall, f1_score
|
227 |
+
|
228 |
+
def ABsegmentation(self):
|
229 |
+
|
230 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
231 |
+
|
232 |
+
parent_run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}"
|
233 |
+
|
234 |
+
print(
|
235 |
+
f'\nTraining pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_train}, Sharpening: {self.s_train}, Denoiser: {self.dn_train} \n')
|
236 |
+
print(f'\nTesting pipeline:\n Dataset: {self.dataset_name}, Augmentation: {self.augmentation} \n Debayer: {self.dm_test}, Sharpening: {self.s_test}, Denoiser: {self.dn_test} \n Transform: {self.transform}, Severity: {self.severity}\n')
|
237 |
+
|
238 |
+
IoUs = []
|
239 |
+
|
240 |
+
os.system('rm -r /tmp/py*')
|
241 |
+
|
242 |
+
for N_run in range(self.N_runs):
|
243 |
+
|
244 |
+
print(f"Evaluating Run {N_run}")
|
245 |
+
|
246 |
+
run_name = parent_run_name + '_' + str(N_run)
|
247 |
+
|
248 |
+
state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name,
|
249 |
+
download_model=self.download_model)
|
250 |
+
|
251 |
+
dataset = self.static_pip_val()
|
252 |
+
|
253 |
+
valid_set = Subset(dataset, indices=state_dict['valid_indices'])
|
254 |
+
valid_loader = DataLoader(valid_set, batch_size=1, num_workers=16, shuffle=False)
|
255 |
+
|
256 |
+
model.eval()
|
257 |
+
|
258 |
+
IoU = 0
|
259 |
+
|
260 |
+
for img, label in valid_loader:
|
261 |
+
|
262 |
+
prediction = model(img.to(DEVICE)).detach().cpu()
|
263 |
+
prediction = F.logsigmoid(prediction).exp().squeeze()
|
264 |
+
IoU += smp.utils.metrics.IoU()(prediction, label)
|
265 |
+
|
266 |
+
IoU = IoU / len(valid_loader)
|
267 |
+
IoUs.append(IoU.item())
|
268 |
+
|
269 |
+
os.system('rm -r /tmp/t*')
|
270 |
+
|
271 |
+
IoU = metrics.over_N_runs(torch.tensor(IoUs), self.N_runs)
|
272 |
+
return IoU
|
273 |
+
|
274 |
+
def ABShowImages(self):
|
275 |
+
|
276 |
+
path = 'results/ABtesting/imgs/'
|
277 |
+
if not os.path.exists(path):
|
278 |
+
os.makedirs(path)
|
279 |
+
|
280 |
+
path = os.path.join(
|
281 |
+
path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.dm_test[:2]}{self.s_test[0]}{self.dn_test[:2]}')
|
282 |
+
|
283 |
+
if not os.path.exists(path):
|
284 |
+
os.makedirs(path)
|
285 |
+
|
286 |
+
run_name = f"{self.dataset_name}_{self.dm_train}_{self.s_train}_{self.dn_train}_{self.augmentation}" + \
|
287 |
+
'_' + str(0)
|
288 |
+
|
289 |
+
state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=self.download_model)
|
290 |
+
|
291 |
+
model.augmentation = None
|
292 |
+
|
293 |
+
for t in ([self.dm_train, self.s_train, self.dn_train, 'train_img'],
|
294 |
+
[self.dm_test, self.s_test, self.dn_test, 'test_img']):
|
295 |
+
|
296 |
+
debayer, sharpening, denoising, img_type = t[0], t[1], t[2], t[3]
|
297 |
+
|
298 |
+
dataset = self.static_pip_val(debayer=debayer, sharpening=sharpening, denoising=denoising, plot_mode=True)
|
299 |
+
valid_set = Subset(dataset, indices=state_dict['valid_indices'])
|
300 |
+
|
301 |
+
img, _ = next(iter(valid_set))
|
302 |
+
|
303 |
+
plt.figure()
|
304 |
+
plt.imshow(img.permute(1, 2, 0))
|
305 |
+
if img_type == 'train_img':
|
306 |
+
plt.title('Train Image')
|
307 |
+
plt.savefig(os.path.join(path, f'img_train.png'))
|
308 |
+
imgA = img
|
309 |
+
else:
|
310 |
+
plt.title('Test Image')
|
311 |
+
plt.savefig(os.path.join(path, f'img_test.png'))
|
312 |
+
|
313 |
+
for c, color in enumerate(['Red', 'Green', 'Blue']):
|
314 |
+
diff = torch.abs(imgA - img)
|
315 |
+
plt.figure()
|
316 |
+
# plt.imshow(diff.permute(1,2,0))
|
317 |
+
plt.imshow(diff[c, 50:200, 50:200], cmap=f'{color}s')
|
318 |
+
plt.title(f'|Train Image - Test Image| - {color}')
|
319 |
+
plt.colorbar()
|
320 |
+
plt.savefig(os.path.join(path, f'diff_{color}.png'))
|
321 |
+
plt.figure()
|
322 |
+
diff[diff == 0.] = 1e-5
|
323 |
+
# plt.imshow(torch.log(diff.permute(1,2,0)))
|
324 |
+
plt.imshow(torch.log(diff)[c])
|
325 |
+
plt.title(f'log(|Train Image - Test Image|) - color')
|
326 |
+
plt.colorbar()
|
327 |
+
plt.savefig(os.path.join(path, f'logdiff_{color}.png'))
|
328 |
+
|
329 |
+
if self.dataset_name == 'DroneSegmentation':
|
330 |
+
plt.figure()
|
331 |
+
plt.imshow(model(img[None].cuda()).detach().cpu().squeeze())
|
332 |
+
if img_type == 'train_img':
|
333 |
+
plt.savefig(os.path.join(path, f'mask_train.png'))
|
334 |
+
else:
|
335 |
+
plt.savefig(os.path.join(path, f'mask_test.png'))
|
336 |
+
|
337 |
+
def ABShowAllImages(self):
|
338 |
+
if not os.path.exists('results/ABtesting'):
|
339 |
+
os.makedirs('results/ABtesting')
|
340 |
+
|
341 |
+
demosaicings = ['bilinear', 'malvar2004', 'menon2007']
|
342 |
+
sharpenings = ['sharpening_filter', 'unsharp_masking']
|
343 |
+
denoisings = ['median_denoising', 'gaussian_denoising']
|
344 |
+
|
345 |
+
fig = plt.figure()
|
346 |
+
columns = 4
|
347 |
+
rows = 3
|
348 |
+
|
349 |
+
i = 1
|
350 |
+
|
351 |
+
for dm in demosaicings:
|
352 |
+
for s in sharpenings:
|
353 |
+
for dn in denoisings:
|
354 |
+
|
355 |
+
dataset = self.static_pip_val(self.dm_test, self.s_test,
|
356 |
+
self.dn_test, plot_mode=True)
|
357 |
+
|
358 |
+
img, _ = dataset[0]
|
359 |
+
|
360 |
+
fig.add_subplot(rows, columns, i)
|
361 |
+
plt.imshow(img.permute(1, 2, 0))
|
362 |
+
plt.title(f'{dm}\n{s}\n{dn}', fontsize=8)
|
363 |
+
plt.xticks([])
|
364 |
+
plt.yticks([])
|
365 |
+
plt.tight_layout()
|
366 |
+
|
367 |
+
i += 1
|
368 |
+
|
369 |
+
plt.show()
|
370 |
+
plt.savefig(f'results/ABtesting/ABpipelines.png')
|
371 |
+
|
372 |
+
def CShowImages(self):
|
373 |
+
|
374 |
+
path = 'results/Ctesting/imgs/'
|
375 |
+
if not os.path.exists(path):
|
376 |
+
os.makedirs(path)
|
377 |
+
|
378 |
+
run_name = f"{self.dataset_name}_{self.dm_test}_{self.s_test}_{self.dn_test}_{self.augmentation}" + '_' + str(0)
|
379 |
+
|
380 |
+
state_dict, model = get_mlflow_model_by_name(self.experiment_name, run_name, download_model=True)
|
381 |
+
|
382 |
+
model.augmentation = None
|
383 |
+
|
384 |
+
dataset = self.static_pip_val(self.dm_test, self.s_test, self.dn_test,
|
385 |
+
self.severity, self.transform, plot_mode=True)
|
386 |
+
valid_set = Subset(dataset, indices=state_dict['valid_indices'])
|
387 |
+
|
388 |
+
img, _ = next(iter(valid_set))
|
389 |
+
|
390 |
+
plt.figure()
|
391 |
+
plt.imshow(img.permute(1, 2, 0))
|
392 |
+
plt.savefig(os.path.join(
|
393 |
+
path, f'{self.dataset_name}_{self.augmentation}_{self.dm_train[:2]}{self.s_train[0]}{self.dn_train[:2]}_{self.transform}_sev{self.severity}'))
|
394 |
+
|
395 |
+
def CShowAllImages(self):
|
396 |
+
if not os.path.exists('results/Cimages'):
|
397 |
+
os.makedirs('results/Cimages')
|
398 |
+
|
399 |
+
transforms = ['identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
|
400 |
+
'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
|
401 |
+
|
402 |
+
for i, t in enumerate(transforms):
|
403 |
+
|
404 |
+
fig = plt.figure(figsize=(10, 6))
|
405 |
+
columns = 5
|
406 |
+
rows = 1
|
407 |
+
|
408 |
+
for sev in range(1, 6):
|
409 |
+
|
410 |
+
dataset = self.static_pip_val(severity=sev, transform=t, plot_mode=True)
|
411 |
+
|
412 |
+
img, _ = dataset[0]
|
413 |
+
|
414 |
+
fig.add_subplot(rows, columns, sev)
|
415 |
+
plt.imshow(img.permute(1, 2, 0))
|
416 |
+
plt.title(f'Severity: {sev}')
|
417 |
+
plt.xticks([])
|
418 |
+
plt.yticks([])
|
419 |
+
plt.tight_layout()
|
420 |
+
|
421 |
+
if '_' in t:
|
422 |
+
t = t.replace('_', ' ')
|
423 |
+
t = t[0].upper() + t[1:]
|
424 |
+
|
425 |
+
fig.suptitle(f'{t}', x=0.5, y=0.8, fontsize=24)
|
426 |
+
plt.show()
|
427 |
+
plt.savefig(f'results/Cimages/{i+1}_{t.lower()}.png')
|
428 |
+
|
429 |
+
|
430 |
+
def ABMakeTable(dataset_name: str, augmentation: str,
|
431 |
+
N_runs: int, download_model: bool):
|
432 |
+
|
433 |
+
demosaicings = ['bilinear', 'malvar2004', 'menon2007']
|
434 |
+
sharpenings = ['sharpening_filter', 'unsharp_masking']
|
435 |
+
denoisings = ['median_denoising', 'gaussian_denoising']
|
436 |
+
|
437 |
+
path = 'results/ABtesting/tables'
|
438 |
+
if not os.path.exists(path):
|
439 |
+
os.makedirs(path)
|
440 |
+
|
441 |
+
runs = {}
|
442 |
+
i = 0
|
443 |
+
|
444 |
+
for dm_train in demosaicings:
|
445 |
+
for s_train in sharpenings:
|
446 |
+
for dn_train in denoisings:
|
447 |
+
for dm_test in demosaicings:
|
448 |
+
for s_test in sharpenings:
|
449 |
+
for dn_test in denoisings:
|
450 |
+
train_pip = [dm_train, s_train, dn_train]
|
451 |
+
test_pip = [dm_test, s_test, dn_test]
|
452 |
+
runs[f'run{i}'] = {
|
453 |
+
'dataset': dataset_name,
|
454 |
+
'augmentation': augmentation,
|
455 |
+
'train_pip': train_pip,
|
456 |
+
'test_pip': test_pip,
|
457 |
+
'N_runs': N_runs
|
458 |
+
}
|
459 |
+
ABclass = ABtesting(
|
460 |
+
dataset_name=dataset_name,
|
461 |
+
augmentation=augmentation,
|
462 |
+
dm_train=dm_train,
|
463 |
+
s_train=s_train,
|
464 |
+
dn_train=dn_train,
|
465 |
+
dm_test=dm_test,
|
466 |
+
s_test=s_test,
|
467 |
+
dn_test=dn_test,
|
468 |
+
N_runs=N_runs,
|
469 |
+
download_model=download_model
|
470 |
+
)
|
471 |
+
|
472 |
+
if dataset_name == 'DroneSegmentation':
|
473 |
+
IoU = ABclass.ABsegmentation()
|
474 |
+
runs[f'run{i}']['IoU'] = IoU
|
475 |
+
else:
|
476 |
+
classes, accuracy, precision, recall, f1_score = ABclass.ABclassification()
|
477 |
+
runs[f'run{i}']['classes'] = classes
|
478 |
+
runs[f'run{i}']['accuracy'] = accuracy
|
479 |
+
runs[f'run{i}']['precision'] = precision
|
480 |
+
runs[f'run{i}']['recall'] = recall
|
481 |
+
runs[f'run{i}']['f1_score'] = f1_score
|
482 |
+
|
483 |
+
with open(os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt'), 'w') as outfile:
|
484 |
+
json.dump(runs, outfile)
|
485 |
+
|
486 |
+
i += 1
|
487 |
+
|
488 |
+
|
489 |
+
def ABShowTable(dataset_name: str, augmentation: str):
|
490 |
+
|
491 |
+
path = 'results/ABtesting/tables'
|
492 |
+
assert os.path.exists(path), 'No tables to plot'
|
493 |
+
|
494 |
+
json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt')
|
495 |
+
|
496 |
+
with open(json_file, 'r') as run_file:
|
497 |
+
runs = json.load(run_file)
|
498 |
+
|
499 |
+
metrics = torch.zeros((2, 12, 12))
|
500 |
+
classes = []
|
501 |
+
|
502 |
+
i, j = 0, 0
|
503 |
+
|
504 |
+
for r in range(len(runs)):
|
505 |
+
|
506 |
+
run = runs['run' + str(r)]
|
507 |
+
if dataset_name == 'DroneSegmentation':
|
508 |
+
acc = run['IoU']
|
509 |
+
else:
|
510 |
+
acc = run['accuracy']
|
511 |
+
if len(classes) < 12:
|
512 |
+
class_list = run['test_pip']
|
513 |
+
class_name = f'{class_list[0][:2]},{class_list[1][:1]},{class_list[2][:2]}'
|
514 |
+
classes.append(class_name)
|
515 |
+
mu, sigma = round(acc[0], 4), round(acc[1], 4)
|
516 |
+
|
517 |
+
metrics[0, j, i] = mu
|
518 |
+
metrics[1, j, i] = sigma
|
519 |
+
|
520 |
+
i += 1
|
521 |
+
|
522 |
+
if i == 12:
|
523 |
+
i = 0
|
524 |
+
j += 1
|
525 |
+
|
526 |
+
differences = torch.zeros_like(metrics)
|
527 |
+
|
528 |
+
diag_mu = torch.diagonal(metrics[0], 0)
|
529 |
+
diag_sigma = torch.diagonal(metrics[1], 0)
|
530 |
+
|
531 |
+
for r in range(len(metrics[0])):
|
532 |
+
differences[0, r] = diag_mu[r] - metrics[0, r]
|
533 |
+
differences[1, r] = torch.sqrt(metrics[1, r]**2 + diag_sigma[r]**2)
|
534 |
+
|
535 |
+
# Plot with scatter
|
536 |
+
|
537 |
+
for i, img in enumerate([metrics, differences]):
|
538 |
+
|
539 |
+
x, y = torch.arange(12), torch.arange(12)
|
540 |
+
x, y = torch.meshgrid(x, y)
|
541 |
+
|
542 |
+
if i == 0:
|
543 |
+
vmin = max(0.65, round(img[0].min().item(), 2))
|
544 |
+
vmax = round(img[0].max().item(), 2)
|
545 |
+
step = 0.02
|
546 |
+
elif i == 1:
|
547 |
+
vmin = round(img[0].min().item(), 2)
|
548 |
+
if augmentation == 'none':
|
549 |
+
vmax = min(0.15, round(img[0].max().item(), 2))
|
550 |
+
if augmentation == 'weak':
|
551 |
+
vmax = min(0.08, round(img[0].max().item(), 2))
|
552 |
+
if augmentation == 'strong':
|
553 |
+
vmax = min(0.05, round(img[0].max().item(), 2))
|
554 |
+
step = 0.01
|
555 |
+
|
556 |
+
vmin = int(vmin / step) * step
|
557 |
+
vmax = int(vmax / step) * step
|
558 |
+
|
559 |
+
fig = plt.figure(figsize=(10, 6.2))
|
560 |
+
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
|
561 |
+
marker_size = 350
|
562 |
+
plt.scatter(x, y, c=torch.rot90(img[1][x, y], -1, [0, 1]), vmin=0.,
|
563 |
+
vmax=img[1].max(), cmap='viridis', s=marker_size * 2, marker='s')
|
564 |
+
ticks = torch.arange(0., img[1].max(), 0.03).tolist()
|
565 |
+
ticks = [round(tick, 2) for tick in ticks]
|
566 |
+
cba = plt.colorbar(pad=0.06)
|
567 |
+
cba.set_ticks(ticks)
|
568 |
+
cba.ax.set_yticklabels(ticks)
|
569 |
+
# cmap = plt.cm.get_cmap('tab20c').reversed()
|
570 |
+
cmap = plt.cm.get_cmap('Reds')
|
571 |
+
plt.scatter(x, y, c=torch.rot90(img[0][x, y], -1, [0, 1]), vmin=vmin,
|
572 |
+
vmax=vmax, cmap=cmap, s=marker_size, marker='s')
|
573 |
+
ticks = torch.arange(vmin, vmax, step).tolist()
|
574 |
+
ticks = [round(tick, 2) for tick in ticks]
|
575 |
+
if ticks[-1] != vmax:
|
576 |
+
ticks.append(vmax)
|
577 |
+
cbb = plt.colorbar(pad=0.06)
|
578 |
+
cbb.set_ticks(ticks)
|
579 |
+
if i == 0:
|
580 |
+
ticks[0] = f'<{str(ticks[0])}'
|
581 |
+
elif i == 1:
|
582 |
+
ticks[-1] = f'>{str(ticks[-1])}'
|
583 |
+
cbb.ax.set_yticklabels(ticks)
|
584 |
+
for x in range(12):
|
585 |
+
for y in range(12):
|
586 |
+
txt = round(torch.rot90(img[0], -1, [0, 1])[x, y].item(), 2)
|
587 |
+
if str(txt) == '-0.0':
|
588 |
+
txt = '0.00'
|
589 |
+
elif str(txt) == '0.0':
|
590 |
+
txt = '0.00'
|
591 |
+
elif len(str(txt)) == 3:
|
592 |
+
txt = str(txt) + '0'
|
593 |
+
else:
|
594 |
+
txt = str(txt)
|
595 |
+
|
596 |
+
plt.text(x - 0.25, y - 0.1, txt, color='black', fontsize='x-small')
|
597 |
+
|
598 |
+
ax.set_xticks(torch.linspace(0, 11, 12))
|
599 |
+
ax.set_xticklabels(classes)
|
600 |
+
ax.set_yticks(torch.linspace(0, 11, 12))
|
601 |
+
classes.reverse()
|
602 |
+
ax.set_yticklabels(classes)
|
603 |
+
classes.reverse()
|
604 |
+
plt.xticks(rotation=45)
|
605 |
+
plt.yticks(rotation=45)
|
606 |
+
cba.set_label('Standard Deviation')
|
607 |
+
plt.xlabel("Test pipelines")
|
608 |
+
plt.ylabel("Train pipelines")
|
609 |
+
plt.title(f'Dataset: {dataset_name}, Augmentation: {augmentation}')
|
610 |
+
if i == 0:
|
611 |
+
if dataset_name == 'DroneSegmentation':
|
612 |
+
cbb.set_label('IoU')
|
613 |
+
plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_IoU.png"))
|
614 |
+
else:
|
615 |
+
cbb.set_label('Accuracy')
|
616 |
+
plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_accuracies.png"))
|
617 |
+
elif i == 1:
|
618 |
+
if dataset_name == 'DroneSegmentation':
|
619 |
+
cbb.set_label('IoU_d-IoU')
|
620 |
+
else:
|
621 |
+
cbb.set_label('Accuracy_d - Accuracy')
|
622 |
+
plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_differences.png"))
|
623 |
+
|
624 |
+
|
625 |
+
def CMakeTable(dataset_name: str, augmentation: str, severity: int, N_runs: int, download_model: bool):
|
626 |
+
|
627 |
+
path = 'results/Ctesting/tables'
|
628 |
+
if not os.path.exists(path):
|
629 |
+
os.makedirs(path)
|
630 |
+
|
631 |
+
demosaicings = ['bilinear', 'malvar2004', 'menon2007']
|
632 |
+
sharpenings = ['sharpening_filter', 'unsharp_masking']
|
633 |
+
denoisings = ['median_denoising', 'gaussian_denoising']
|
634 |
+
|
635 |
+
transformations = ['identity', 'gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
|
636 |
+
'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
|
637 |
+
|
638 |
+
runs = {}
|
639 |
+
i = 0
|
640 |
+
|
641 |
+
for dm in demosaicings:
|
642 |
+
for s in sharpenings:
|
643 |
+
for dn in denoisings:
|
644 |
+
for t in transformations:
|
645 |
+
pip = [dm, s, dn]
|
646 |
+
runs[f'run{i}'] = {
|
647 |
+
'dataset': dataset_name,
|
648 |
+
'augmentation': augmentation,
|
649 |
+
'pipeline': pip,
|
650 |
+
'N_runs': N_runs,
|
651 |
+
'transform': t,
|
652 |
+
'severity': severity,
|
653 |
+
}
|
654 |
+
ABclass = ABtesting(
|
655 |
+
dataset_name=dataset_name,
|
656 |
+
augmentation=augmentation,
|
657 |
+
dm_train=dm,
|
658 |
+
s_train=s,
|
659 |
+
dn_train=dn,
|
660 |
+
dm_test=dm,
|
661 |
+
s_test=s,
|
662 |
+
dn_test=dn,
|
663 |
+
severity=severity,
|
664 |
+
transform=t,
|
665 |
+
N_runs=N_runs,
|
666 |
+
download_model=download_model
|
667 |
+
)
|
668 |
+
|
669 |
+
if dataset_name == 'DroneSegmentation':
|
670 |
+
IoU = ABclass.ABsegmentation()
|
671 |
+
runs[f'run{i}']['IoU'] = IoU
|
672 |
+
else:
|
673 |
+
classes, accuracy, precision, recall, f1_score = ABclass.ABclassification()
|
674 |
+
runs[f'run{i}']['classes'] = classes
|
675 |
+
runs[f'run{i}']['accuracy'] = accuracy
|
676 |
+
runs[f'run{i}']['precision'] = precision
|
677 |
+
runs[f'run{i}']['recall'] = recall
|
678 |
+
runs[f'run{i}']['f1_score'] = f1_score
|
679 |
+
|
680 |
+
with open(os.path.join(path, f'{dataset_name}_{augmentation}_runs.json'), 'w') as outfile:
|
681 |
+
json.dump(runs, outfile)
|
682 |
+
|
683 |
+
i += 1
|
684 |
+
|
685 |
+
|
686 |
+
def CShowTable(dataset_name, augmentation):
|
687 |
+
|
688 |
+
path = 'results/Ctesting/tables'
|
689 |
+
assert os.path.exists(path), 'No tables to plot'
|
690 |
+
|
691 |
+
json_file = os.path.join(path, f'{dataset_name}_{augmentation}_runs.txt')
|
692 |
+
|
693 |
+
transforms = ['identity', 'gauss_noise', 'shot', 'impulse', 'speckle',
|
694 |
+
'gauss_blur', 'zoom', 'contrast', 'brightness', 'saturate', 'elastic']
|
695 |
+
|
696 |
+
pip = []
|
697 |
+
|
698 |
+
demosaicings = ['bilinear', 'malvar2004', 'menon2007']
|
699 |
+
sharpenings = ['sharpening_filter', 'unsharp_masking']
|
700 |
+
denoisings = ['median_denoising', 'gaussian_denoising']
|
701 |
+
|
702 |
+
for dm in demosaicings:
|
703 |
+
for s in sharpenings:
|
704 |
+
for dn in denoisings:
|
705 |
+
pip.append(f'{dm[:2]},{s[0]},{dn[2]}')
|
706 |
+
|
707 |
+
with open(json_file, 'r') as run_file:
|
708 |
+
runs = json.load(run_file)
|
709 |
+
|
710 |
+
metrics = torch.zeros((2, len(pip), len(transforms)))
|
711 |
+
|
712 |
+
i, j = 0, 0
|
713 |
+
|
714 |
+
for r in range(len(runs)):
|
715 |
+
|
716 |
+
run = runs['run' + str(r)]
|
717 |
+
if dataset_name == 'DroneSegmentation':
|
718 |
+
acc = run['IoU']
|
719 |
+
else:
|
720 |
+
acc = run['accuracy']
|
721 |
+
mu, sigma = round(acc[0], 4), round(acc[1], 4)
|
722 |
+
|
723 |
+
metrics[0, j, i] = mu
|
724 |
+
metrics[1, j, i] = sigma
|
725 |
+
|
726 |
+
i += 1
|
727 |
+
|
728 |
+
if i == len(transforms):
|
729 |
+
i = 0
|
730 |
+
j += 1
|
731 |
+
|
732 |
+
# Plot with scatter
|
733 |
+
|
734 |
+
img = metrics
|
735 |
+
|
736 |
+
vmin = 0.
|
737 |
+
vmax = 1.
|
738 |
+
|
739 |
+
x, y = torch.arange(12), torch.arange(11)
|
740 |
+
x, y = torch.meshgrid(x, y)
|
741 |
+
|
742 |
+
fig = plt.figure(figsize=(10, 6.2))
|
743 |
+
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
|
744 |
+
marker_size = 350
|
745 |
+
plt.scatter(x, y, c=torch.rot90(img[1][x, y], -1, [0, 1]), vmin=0.,
|
746 |
+
vmax=img[1].max(), cmap='viridis', s=marker_size * 2, marker='s')
|
747 |
+
ticks = torch.arange(0., img[1].max(), 0.03).tolist()
|
748 |
+
ticks = [round(tick, 2) for tick in ticks]
|
749 |
+
cba = plt.colorbar(pad=0.06)
|
750 |
+
cba.set_ticks(ticks)
|
751 |
+
cba.ax.set_yticklabels(ticks)
|
752 |
+
# cmap = plt.cm.get_cmap('tab20c').reversed()
|
753 |
+
cmap = plt.cm.get_cmap('Reds')
|
754 |
+
plt.scatter(x, y, c=torch.rot90(img[0][x, y], -1, [0, 1]), vmin=vmin,
|
755 |
+
vmax=vmax, cmap=cmap, s=marker_size, marker='s')
|
756 |
+
ticks = torch.arange(vmin, vmax, step).tolist()
|
757 |
+
ticks = [round(tick, 2) for tick in ticks]
|
758 |
+
if ticks[-1] != vmax:
|
759 |
+
ticks.append(vmax)
|
760 |
+
cbb = plt.colorbar(pad=0.06)
|
761 |
+
cbb.set_ticks(ticks)
|
762 |
+
if i == 0:
|
763 |
+
ticks[0] = f'<{str(ticks[0])}'
|
764 |
+
elif i == 1:
|
765 |
+
ticks[-1] = f'>{str(ticks[-1])}'
|
766 |
+
cbb.ax.set_yticklabels(ticks)
|
767 |
+
for x in range(12):
|
768 |
+
for y in range(12):
|
769 |
+
txt = round(torch.rot90(img[0], -1, [0, 1])[x, y].item(), 2)
|
770 |
+
if str(txt) == '-0.0':
|
771 |
+
txt = '0.00'
|
772 |
+
elif str(txt) == '0.0':
|
773 |
+
txt = '0.00'
|
774 |
+
elif len(str(txt)) == 3:
|
775 |
+
txt = str(txt) + '0'
|
776 |
+
else:
|
777 |
+
txt = str(txt)
|
778 |
+
|
779 |
+
plt.text(x - 0.25, y - 0.1, txt, color='black', fontsize='x-small')
|
780 |
+
|
781 |
+
ax.set_xticks(torch.linspace(0, 11, 12))
|
782 |
+
ax.set_xticklabels(transforms)
|
783 |
+
ax.set_yticks(torch.linspace(0, 11, 12))
|
784 |
+
pip.reverse()
|
785 |
+
ax.set_yticklabels(pip)
|
786 |
+
pip.reverse()
|
787 |
+
plt.xticks(rotation=45)
|
788 |
+
plt.yticks(rotation=45)
|
789 |
+
cba.set_label('Standard Deviation')
|
790 |
+
plt.xlabel("Pipelines")
|
791 |
+
plt.ylabel("Distortions")
|
792 |
+
if dataset_name == 'DroneSegmentation':
|
793 |
+
cbb.set_label('IoU')
|
794 |
+
plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_IoU.png"))
|
795 |
+
else:
|
796 |
+
cbb.set_label('Accuracy')
|
797 |
+
plt.savefig(os.path.join(path, f"{dataset_name}_{augmentation}_accuracies.png"))
|
798 |
+
|
799 |
+
|
800 |
+
if __name__ == '__main__':
|
801 |
+
|
802 |
+
if args.mode == 'ABMakeTable':
|
803 |
+
ABMakeTable(args.dataset_name, args.augmentation, args.N_runs, args.download_model)
|
804 |
+
elif args.mode == 'ABShowTable':
|
805 |
+
ABShowTable(args.dataset_name, args.augmentation)
|
806 |
+
elif args.mode == 'ABShowImages':
|
807 |
+
ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
|
808 |
+
args.s_train, args.dn_train, args.dm_test, args.s_test,
|
809 |
+
args.dn_test, args.N_runs, download_model=args.download_model)
|
810 |
+
ABclass.ABShowImages()
|
811 |
+
elif args.mode == 'ABShowAllImages':
|
812 |
+
ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
|
813 |
+
args.s_train, args.dn_train, args.dm_test, args.s_test,
|
814 |
+
args.dn_test, args.N_runs, download_model=args.download_model)
|
815 |
+
ABclass.ABShowAllImages()
|
816 |
+
elif args.mode == 'CMakeTable':
|
817 |
+
CMakeTable(args.dataset_name, args.augmentation, args.severity, args.N_runs, args.download_model)
|
818 |
+
elif args.mode == 'CShowTable': # TODO test it
|
819 |
+
CShowTable(args.dataset_name, args.augmentation, args.severity)
|
820 |
+
elif args.mode == 'CShowImages':
|
821 |
+
ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
|
822 |
+
args.s_train, args.dn_train, args.dm_test, args.s_test,
|
823 |
+
args.dn_test, args.N_runs, args.severity, args.transform,
|
824 |
+
download_model=args.download_model)
|
825 |
+
ABclass.CShowImages()
|
826 |
+
elif args.mode == 'CShowAllImages':
|
827 |
+
ABclass = ABtesting(args.dataset_name, args.augmentation, args.dm_train,
|
828 |
+
args.s_train, args.dn_train, args.dm_test, args.s_test,
|
829 |
+
args.dn_test, args.N_runs, args.severity, args.transform,
|
830 |
+
download_model=args.download_model)
|
831 |
+
ABclass.CShowAllImages()
|
figures/figure1.sh
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python figures.py \
|
2 |
+
--experiment_name track-test \
|
3 |
+
--run_name track-all \
|
4 |
+
--representation gradients \
|
5 |
+
--step gamma_correct \
|
6 |
+
--gif_name gradient \
|
7 |
+
--output gif \
|
figures/figure2.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python figures.py \
|
2 |
+
--experiment_name track-test \
|
3 |
+
--run_name track-all \
|
4 |
+
--output train_vs_val_loss \
|
figures/figures.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import mlflow
|
2 |
+
from mlflow.tracking import MlflowClient
|
3 |
+
from mlflow.entities import ViewType
|
4 |
+
import argparse
|
5 |
+
#gif
|
6 |
+
import os
|
7 |
+
import pathlib
|
8 |
+
import shutil
|
9 |
+
import imageio
|
10 |
+
#plot
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
# -1. parse args
|
15 |
+
parser = argparse.ArgumentParser(description="results_analysis")
|
16 |
+
parser.add_argument("--tracking_uri", type=str,
|
17 |
+
default="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com", help='URI of the mlflow server on AWS')
|
18 |
+
parser.add_argument("--experiment_name", type=str, default=None,
|
19 |
+
help='Name of the experiment on the mlflow server, e.g. "processing_comparison"')
|
20 |
+
parser.add_argument("--run_name", type=str, default=None,
|
21 |
+
help='Name of the run on the mlflow server, e.g. "proc_nn"')
|
22 |
+
parser.add_argument("--representation", type=str, default=None,
|
23 |
+
choices=["processing", "gradients"], help='The representation form you want retrieve("processing" or "gradients")')
|
24 |
+
parser.add_argument("--step", type=str, default=None,
|
25 |
+
choices=["pre_debayer", "demosaic", "color_correct", "sharpening", "gaussian", "clipped", "gamma_correct", "rgb"],
|
26 |
+
help='The processing step you want to track ("pre_debayer" or "rgb")') #TODO: include predictions and ground truths
|
27 |
+
parser.add_argument("--gif_name", type=str, default=None,
|
28 |
+
help='Name of the gif that will be saved. Note: .gif will be added later by script') #TODO: option to include filepath where result should be written
|
29 |
+
#TODO: option to write results to existing run on mlflow
|
30 |
+
parser.add_argument("--local_dir", type=str, default=None,
|
31 |
+
help='Name of the local dir to be created to store mlflow data')
|
32 |
+
parser.add_argument("--cleanup", type=bool, default=True,
|
33 |
+
help='Whether to delete the local dir again after the script was run')
|
34 |
+
parser.add_argument("--output", type=str, default=None,
|
35 |
+
choices=["gif", "train_vs_val_loss"],
|
36 |
+
help='Which output to generate') #TODO: make this cleaner, atm it is confusing because each figure may need different set of args and it is not clear how to manage that
|
37 |
+
#TODO: idea -> fix the types of args for each figure which define the figure type but parametrize those things that can reasonably vary
|
38 |
+
args = parser.parse_args()
|
39 |
+
|
40 |
+
# 0. mlflow basics
|
41 |
+
mlflow.set_tracking_uri(args.tracking_uri)
|
42 |
+
|
43 |
+
# 1. specify experiment_name, run_name, representation and step
|
44 |
+
#is done via parse_args
|
45 |
+
|
46 |
+
# 2. use get_experiment_by_name to get experiment object
|
47 |
+
experiment = mlflow.get_experiment_by_name(args.experiment_name)
|
48 |
+
|
49 |
+
# 3. extract experiment_id
|
50 |
+
#experiment.experiment_id
|
51 |
+
|
52 |
+
# 4. use search_runs with experiment_id and run_name for string search query
|
53 |
+
filter_string = "tags.mlflow.runName = '{}'".format(args.run_name) #create the filter string with using the runName tag to query mlflow
|
54 |
+
runs = mlflow.search_runs(experiment.experiment_id, filter_string=filter_string) #returns a pandas data frame where each row is a run (if several exist under that name)
|
55 |
+
client = MlflowClient() #TODO: look more into the options of client
|
56 |
+
|
57 |
+
if args.output == "gif": #TODO: outsource these options to functions which are then loaded and can be called
|
58 |
+
# 5. extract run from list
|
59 |
+
#TODO: parent run and cv option for analysis
|
60 |
+
if args.local_dir:
|
61 |
+
local_dir = args.local_dir+"/artifacts"
|
62 |
+
else: #use the current working dir and make a subdir "artifacts" to store the data from mlflow
|
63 |
+
local_dir = str(pathlib.Path().resolve())+"/artifacts"
|
64 |
+
if not os.path.isdir('artifacts'):
|
65 |
+
os.mkdir(local_dir) #create the local_dir if it does not exist, yet #TODO: more advanced catching of existing files etc
|
66 |
+
dir = client.download_artifacts(runs["run_id"][0], "results", local_dir) #TODO: parametrize this number [0] so the right run is selected
|
67 |
+
|
68 |
+
# 6. get filenames in chronological sequence and write them to gif
|
69 |
+
dirs = [x[0] for x in os.walk(dir)]
|
70 |
+
dirs = sorted(dirs, key=str.lower)[1:] #sort chronologically and remove parent dir from list
|
71 |
+
|
72 |
+
with imageio.get_writer(args.gif_name+'.gif', mode='I') as writer: #https://imageio.readthedocs.io/en/stable/index.html#
|
73 |
+
for epoch in dirs: #extract the right file from each epoch
|
74 |
+
for _, _, files in os.walk(epoch): #
|
75 |
+
for name in files:
|
76 |
+
if args.representation in name and args.step in name and "png" in name:
|
77 |
+
image = imageio.imread(epoch+"/"+name)
|
78 |
+
writer.append_data(image)
|
79 |
+
|
80 |
+
# 7. cleanup the downloaded artifacts from client file system
|
81 |
+
if args.cleanup:
|
82 |
+
shutil.rmtree(local_dir) #delete the files downloaded from mlflow
|
83 |
+
|
84 |
+
elif args.output == "train_vs_val_loss":
|
85 |
+
train_loss = client.get_metric_history(runs["run_id"][0], "train_loss") #returns a list of metric entities https://www.mlflow.org/docs/latest/_modules/mlflow/entities/metric.html
|
86 |
+
val_loss = client.get_metric_history(runs["run_id"][0], "val_loss") #TODO: parametrize this number [0] so the right run is selected
|
87 |
+
train_loss = sorted(train_loss, key=lambda m: m.step) #sort the metric objects in list according to step property
|
88 |
+
val_loss = sorted(val_loss, key=lambda m: m.step)
|
89 |
+
plt.figure()
|
90 |
+
for m_train, m_val in zip(train_loss, val_loss):
|
91 |
+
plt.scatter(m_train.value, m_val.value, alpha=1/(m_train.step+1), color='blue')
|
92 |
+
plt.savefig("scatter.png") #TODO: parametrize filename
|
figures/train.sh
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# # Parametrized Training
|
4 |
+
# 100 epochs, frozen_processor: http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com/#/experiments/49/runs/2803f44514e34a0f87d591520706e876
|
5 |
+
# model_uri="s3://mlflow-artifacts-601883093460/49/2803f44514e34a0f87d591520706e876/artifacts/model"
|
6 |
+
|
7 |
+
# used for training current model to 100% train and 80% val accuracy
|
8 |
+
# python train.py \
|
9 |
+
# --experiment_name parametrized \
|
10 |
+
# --classifier_uri "${model_uri}" \
|
11 |
+
# --run_name par_full_kurt \
|
12 |
+
# --dataset Microscopy \
|
13 |
+
# --lr 1e-5 \
|
14 |
+
# --epochs 50 \
|
15 |
+
# --freeze_classifier \
|
16 |
+
|
17 |
+
# --freeze_processor \
|
18 |
+
|
19 |
+
# # Adversarial Training
|
20 |
+
|
21 |
+
# python train.py \
|
22 |
+
# --experiment_name adversarial \
|
23 |
+
# --run_name adv_frozen_processor \
|
24 |
+
# --classifier_uri "${model_uri}" \
|
25 |
+
# --dataset Microscopy \
|
26 |
+
# --adv_training \
|
27 |
+
# --lr 1e-3 \
|
28 |
+
# --epochs 7 \
|
29 |
+
# --freeze_classifier \
|
30 |
+
# --track_processing \
|
31 |
+
# --track_every_epoch \
|
32 |
+
# --log_model=False \
|
33 |
+
# --adv_aux_weight=0.1 \
|
34 |
+
# --adv_aux_loss "l2" \
|
35 |
+
|
36 |
+
# --adv_aux_weight=2e-5 \
|
37 |
+
# --adv_aux_weight=2e-5 \
|
38 |
+
# --adv_aux_weight=1.9e-5 \
|
39 |
+
|
40 |
+
# Cross pipeline training (Segmentation/Classification)
|
41 |
+
|
42 |
+
# Static Pipeline Script
|
43 |
+
|
44 |
+
# datasets="Microscopy Drone DroneSegmentation"
|
45 |
+
datasets="DroneSegmentation"
|
46 |
+
augmentations="weak strong none"
|
47 |
+
|
48 |
+
demosaicings="bilinear malvar2004 menon2007"
|
49 |
+
sharpenings="sharpening_filter unsharp_masking"
|
50 |
+
denoisings="median_denoising gaussian_denoising"
|
51 |
+
|
52 |
+
for augment in $augmentations
|
53 |
+
do
|
54 |
+
for data in $datasets
|
55 |
+
do
|
56 |
+
for demosaicing in $demosaicings
|
57 |
+
do
|
58 |
+
for sharpening in $sharpenings
|
59 |
+
do
|
60 |
+
for denoising in $denoisings
|
61 |
+
do
|
62 |
+
|
63 |
+
python train.py \
|
64 |
+
--experiment_name ABtesting \
|
65 |
+
--run_name "$data"_"$demosaicing"_"$sharpening"_"$denoising"_"$augment" \
|
66 |
+
--dataset "$data" \
|
67 |
+
--batch_size 4 \
|
68 |
+
--lr 1e-5 \
|
69 |
+
--epochs 100 \
|
70 |
+
--sp_debayer "$demosaicing" \
|
71 |
+
--sp_sharpening "$sharpening" \
|
72 |
+
--sp_denoising "$denoising" \
|
73 |
+
--processing_mode "static" \
|
74 |
+
--augmentation "$augment" \
|
75 |
+
--n_split 5 \
|
76 |
+
|
77 |
+
done
|
78 |
+
done
|
79 |
+
done
|
80 |
+
done
|
81 |
+
done
|
model.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import defaultdict
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.optim
|
6 |
+
from torchvision.models import resnet18
|
7 |
+
from torchvision.utils import make_grid, save_image
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
|
12 |
+
import mlflow.pytorch
|
13 |
+
|
14 |
+
|
15 |
+
def resnet_model(model=resnet18, pretrained=True, in_channels=3, fc_out_features=2):
|
16 |
+
resnet = model(pretrained=pretrained)
|
17 |
+
# if not pretrained: # TODO: add case for in_channels=4
|
18 |
+
# resnet.conv1 = torch.nn.Conv2d(channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
19 |
+
resnet.fc = torch.nn.Linear(in_features=512, out_features=fc_out_features, bias=True)
|
20 |
+
return resnet
|
21 |
+
|
22 |
+
|
23 |
+
class LitModel(pl.LightningModule):
|
24 |
+
|
25 |
+
def __init__(self,
|
26 |
+
classifier,
|
27 |
+
loss,
|
28 |
+
lr=1e-3,
|
29 |
+
weight_decay=0,
|
30 |
+
loss_aux=None,
|
31 |
+
adv_training=False,
|
32 |
+
adv_parameters='all',
|
33 |
+
metrics=None,
|
34 |
+
processor=None,
|
35 |
+
augmentation=None,
|
36 |
+
is_segmentation_task=False,
|
37 |
+
augmentation_on_eval=False,
|
38 |
+
metrics_on_training=True,
|
39 |
+
freeze_classifier=False,
|
40 |
+
freeze_processor=False,
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
|
44 |
+
self.classifier = classifier
|
45 |
+
self.processor = processor
|
46 |
+
|
47 |
+
self.lr = lr
|
48 |
+
self.weight_decay = weight_decay
|
49 |
+
self.loss_fn = loss
|
50 |
+
self.loss_aux_fn = loss_aux
|
51 |
+
self.adv_training = adv_training
|
52 |
+
self.metrics = metrics
|
53 |
+
self.augmentation = augmentation
|
54 |
+
self.is_segmentation_task = is_segmentation_task
|
55 |
+
self.augmentation_on_eval = augmentation_on_eval
|
56 |
+
self.metrics_on_training = metrics_on_training
|
57 |
+
|
58 |
+
self.freeze_classifier = freeze_classifier
|
59 |
+
self.freeze_processor = freeze_processor
|
60 |
+
|
61 |
+
self.unfreeze()
|
62 |
+
if freeze_classifier:
|
63 |
+
pl.LightningModule.freeze(self.classifier)
|
64 |
+
if freeze_processor:
|
65 |
+
pl.LightningModule.freeze(self.processor)
|
66 |
+
|
67 |
+
if adv_training and adv_parameters != 'all':
|
68 |
+
if adv_parameters != 'all':
|
69 |
+
pl.LightningModule.freeze(self.processor)
|
70 |
+
for name, p in self.processor.named_parameters():
|
71 |
+
if adv_parameters in name:
|
72 |
+
p.requires_grad = True
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
x = self.processor(x)
|
76 |
+
apply_augmentation_step = self.training or self.augmentation_on_eval
|
77 |
+
if self.augmentation is not None and apply_augmentation_step:
|
78 |
+
x = self.augmentation(x, retain_state=self.is_segmentation_task)
|
79 |
+
x = self.classifier(x)
|
80 |
+
return x
|
81 |
+
|
82 |
+
def update_step(self, batch, step_name):
|
83 |
+
x, y = batch
|
84 |
+
# debug(self.processor)
|
85 |
+
# debug(self.processor.parameters())
|
86 |
+
# debug.pause()
|
87 |
+
# print('type', type(self.processor).__name__)
|
88 |
+
|
89 |
+
logits = self(x)
|
90 |
+
|
91 |
+
apply_augmentation_mask = self.is_segmentation_task and (self.training or self.augmentation_on_eval)
|
92 |
+
if self.augmentation is not None and apply_augmentation_mask:
|
93 |
+
y = self.augmentation(y, mask_transform=True).contiguous()
|
94 |
+
|
95 |
+
loss = self.loss_fn(logits, y)
|
96 |
+
|
97 |
+
if self.loss_aux_fn is not None:
|
98 |
+
loss_aux = self.loss_aux_fn(x)
|
99 |
+
loss += loss_aux
|
100 |
+
|
101 |
+
self.log(f'{step_name}_loss', loss, on_step=False, on_epoch=True)
|
102 |
+
if self.loss_aux_fn is not None:
|
103 |
+
self.log(f'{step_name}_loss_aux', loss_aux, on_step=False, on_epoch=True)
|
104 |
+
|
105 |
+
if self.is_segmentation_task:
|
106 |
+
y_hat = F.logsigmoid(logits).exp().squeeze()
|
107 |
+
else:
|
108 |
+
y_hat = torch.argmax(logits, dim=1)
|
109 |
+
|
110 |
+
if self.metrics is not None:
|
111 |
+
for metric in self.metrics:
|
112 |
+
metric_name = metric.__name__ if hasattr(metric, '__name__') else type(metric).__name__
|
113 |
+
if metric_name == 'accuracy' or not self.training or self.metrics_on_training:
|
114 |
+
m = metric(y_hat.cpu().detach(), y.cpu())
|
115 |
+
self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
|
116 |
+
prog_bar=self.training or metric_name == 'accuracy')
|
117 |
+
if metric_name == 'iou_score' or not self.training or self.metrics_on_training:
|
118 |
+
m = metric(y_hat.cpu().detach(), y.cpu())
|
119 |
+
self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
|
120 |
+
prog_bar=self.training or metric_name == 'iou_score')
|
121 |
+
elif metric_name == 'accuracy' or not self.training or self.metrics_on_training:
|
122 |
+
m = metric(y_hat.cpu().detach(), y.cpu())
|
123 |
+
self.log(f'{step_name}_{metric_name}', m, on_step=False, on_epoch=True,
|
124 |
+
prog_bar=self.training or metric_name == 'accuracy')
|
125 |
+
|
126 |
+
return loss
|
127 |
+
|
128 |
+
def training_step(self, batch, batch_idx):
|
129 |
+
return self.update_step(batch, 'train')
|
130 |
+
|
131 |
+
def validation_step(self, batch, batch_idx):
|
132 |
+
return self.update_step(batch, 'val')
|
133 |
+
|
134 |
+
def test_step(self, batch, batch_idx):
|
135 |
+
return self.update_step(batch, 'test')
|
136 |
+
|
137 |
+
def train(self, mode=True):
|
138 |
+
self.training = mode
|
139 |
+
|
140 |
+
# don't update batchnorm in adversarial training
|
141 |
+
self.processor.train(mode=mode and not self.freeze_processor and not self.adv_training)
|
142 |
+
self.classifier.train(mode=mode and not self.freeze_classifier)
|
143 |
+
return self
|
144 |
+
|
145 |
+
def configure_optimizers(self):
|
146 |
+
self.optimizer = torch.optim.Adam(self.parameters(), self.lr, weight_decay=self.weight_decay)
|
147 |
+
return self.optimizer
|
148 |
+
|
149 |
+
def get_progress_bar_dict(self):
|
150 |
+
items = super().get_progress_bar_dict()
|
151 |
+
items.pop('v_num')
|
152 |
+
return items
|
153 |
+
|
154 |
+
|
155 |
+
class TrackImagesCallback(pl.callbacks.base.Callback):
|
156 |
+
def __init__(self, data_loader, reference_processor=None, track_every_epoch=False, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True):
|
157 |
+
super().__init__()
|
158 |
+
self.data_loader = data_loader
|
159 |
+
|
160 |
+
self.track_every_epoch = track_every_epoch
|
161 |
+
|
162 |
+
self.track_processing = track_processing
|
163 |
+
self.track_gradients = track_gradients
|
164 |
+
self.track_predictions = track_predictions
|
165 |
+
self.save_tensors = save_tensors
|
166 |
+
|
167 |
+
self.reference_processor = reference_processor
|
168 |
+
|
169 |
+
def callback_track_images(self, model, save_loc):
|
170 |
+
track_images(model,
|
171 |
+
self.data_loader,
|
172 |
+
reference_processor=self.reference_processor,
|
173 |
+
track_processing=self.track_processing,
|
174 |
+
track_gradients=self.track_gradients,
|
175 |
+
track_predictions=self.track_predictions,
|
176 |
+
save_tensors=self.save_tensors,
|
177 |
+
save_loc=save_loc,
|
178 |
+
)
|
179 |
+
|
180 |
+
def on_fit_end(self, trainer, pl_module):
|
181 |
+
if not self.track_every_epoch:
|
182 |
+
save_loc = 'results'
|
183 |
+
self.callback_track_images(trainer.model, save_loc)
|
184 |
+
|
185 |
+
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
186 |
+
if self.track_every_epoch:
|
187 |
+
save_loc = f'results/epoch_{trainer.current_epoch + 1:04d}'
|
188 |
+
self.callback_track_images(trainer.model, save_loc)
|
189 |
+
|
190 |
+
|
191 |
+
from utils.debug import debug
|
192 |
+
|
193 |
+
|
194 |
+
# @debug
|
195 |
+
def log_tensor(batch, path, save_tensors=True, nrow=8):
|
196 |
+
if save_tensors:
|
197 |
+
torch.save(batch, path)
|
198 |
+
mlflow.log_artifact(path, os.path.dirname(path))
|
199 |
+
|
200 |
+
img_path = path.replace('.pt', '.png')
|
201 |
+
split = img_path.split('/')
|
202 |
+
img_path = '/'.join(split[:-1]) + '/img_' + split[-1] # insert 'img_'; make it easier to find in mlflow
|
203 |
+
|
204 |
+
grid = make_grid(batch, nrow=nrow).squeeze()
|
205 |
+
save_image(grid, img_path)
|
206 |
+
mlflow.log_artifact(img_path, os.path.dirname(path))
|
207 |
+
|
208 |
+
|
209 |
+
def track_images(model, data_loader, reference_processor=None, track_processing=True, track_gradients=True, track_predictions=True, save_tensors=True, save_loc='results'):
|
210 |
+
|
211 |
+
device = model.device
|
212 |
+
processor = model.processor
|
213 |
+
classifier = model.classifier
|
214 |
+
|
215 |
+
if not hasattr(processor, 'stages'): # 'static' or 'none' pipeline
|
216 |
+
return
|
217 |
+
|
218 |
+
os.makedirs(save_loc, exist_ok=True)
|
219 |
+
|
220 |
+
# TODO: implement track_predictions
|
221 |
+
|
222 |
+
# inputs_full = []
|
223 |
+
labels_full = []
|
224 |
+
logits_full = []
|
225 |
+
stages_full = defaultdict(list)
|
226 |
+
grads_full = defaultdict(list)
|
227 |
+
diffs_full = defaultdict(list)
|
228 |
+
|
229 |
+
track_differences = reference_processor is not None
|
230 |
+
|
231 |
+
for inputs, labels in data_loader:
|
232 |
+
|
233 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
234 |
+
inputs.requires_grad = True
|
235 |
+
|
236 |
+
processed_rgb = processor(inputs)
|
237 |
+
|
238 |
+
if track_differences:
|
239 |
+
# debug(processor)
|
240 |
+
processed_rgb_ref = reference_processor(inputs)
|
241 |
+
|
242 |
+
if track_gradients or track_predictions:
|
243 |
+
logits = classifier(processed_rgb)
|
244 |
+
|
245 |
+
# NOTE: should zero grads for good measure
|
246 |
+
loss = model.loss_fn(logits, labels)
|
247 |
+
loss.backward()
|
248 |
+
|
249 |
+
if track_predictions:
|
250 |
+
labels_full.append(labels.cpu().detach())
|
251 |
+
logits_full.append(logits.cpu().detach())
|
252 |
+
# inputs_full.append(inputs.cpu().detach())
|
253 |
+
|
254 |
+
for stage, batch in processor.stages.items():
|
255 |
+
stages_full[stage].append(batch.cpu().detach())
|
256 |
+
if track_differences:
|
257 |
+
diffs_full[stage].append((reference_processor.stages[stage] - batch).cpu().detach())
|
258 |
+
if track_gradients:
|
259 |
+
grads_full[stage].append(batch.grad.cpu().detach())
|
260 |
+
|
261 |
+
with torch.no_grad():
|
262 |
+
|
263 |
+
stages = stages_full
|
264 |
+
grads = grads_full
|
265 |
+
diffs = diffs_full
|
266 |
+
|
267 |
+
if track_processing:
|
268 |
+
for stage, batch in stages.items():
|
269 |
+
stages[stage] = torch.cat(batch)
|
270 |
+
|
271 |
+
if track_differences:
|
272 |
+
for stage, batch in diffs.items():
|
273 |
+
diffs[stage] = torch.cat(batch)
|
274 |
+
|
275 |
+
if track_gradients:
|
276 |
+
for stage, batch in grads.items():
|
277 |
+
grads[stage] = torch.cat(batch)
|
278 |
+
|
279 |
+
for stage_nr, stage_name in enumerate(stages):
|
280 |
+
if track_processing:
|
281 |
+
batch = stages[stage_name]
|
282 |
+
log_tensor(batch, os.path.join(save_loc, f'processing_{stage_nr}_{stage_name}.pt'), save_tensors)
|
283 |
+
|
284 |
+
if track_differences:
|
285 |
+
batch = diffs[stage_name]
|
286 |
+
log_tensor(batch, os.path.join(save_loc, f'diffs_{stage_nr}_{stage_name}.pt'), False)
|
287 |
+
|
288 |
+
if track_gradients:
|
289 |
+
batch_grad = grads[stage_name]
|
290 |
+
batch_grad = batch_grad.abs()
|
291 |
+
batch_grad = (batch_grad - batch_grad.min()) / (batch_grad.max() - batch_grad.min())
|
292 |
+
log_tensor(batch_grad, os.path.join(
|
293 |
+
save_loc, f'gradients_{stage_nr}_{stage_name}.pt'), save_tensors)
|
294 |
+
|
295 |
+
# inputs = torch.cat(inputs_full)
|
296 |
+
|
297 |
+
if track_predictions: # and model.is_segmentation_task:
|
298 |
+
labels = torch.cat(labels_full)
|
299 |
+
logits = torch.cat(logits_full)
|
300 |
+
masks = labels.unsqueeze(1)
|
301 |
+
predictions = logits # torch.sigmoid(logits).unsqueeze(1)
|
302 |
+
#mask_vis = torch.cat((masks, predictions, masks * predictions), dim=1)
|
303 |
+
#log_tensor(mask_vis, os.path.join(save_loc, f'masks.pt'), save_tensors)
|
304 |
+
log_tensor(masks, os.path.join(save_loc, f'targets.pt'), save_tensors)
|
305 |
+
log_tensor(predictions, os.path.join(save_loc, f'preds.pt'), save_tensors)
|
processing/pipeline_numpy.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Raw Image Pipeline
|
3 |
+
"""
|
4 |
+
__author__ = "Marco Aversa"
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from rawpy import * # XXX: no * imports!
|
9 |
+
from scipy import ndimage
|
10 |
+
from scipy import fftpack
|
11 |
+
from scipy.signal import convolve2d
|
12 |
+
|
13 |
+
from skimage.filters import unsharp_mask
|
14 |
+
from skimage.color import rgb2yuv, yuv2rgb, rgb2hsv, hsv2rgb
|
15 |
+
from skimage.restoration import denoise_tv_chambolle, denoise_tv_bregman, denoise_nl_means, denoise_bilateral, denoise_wavelet, estimate_sigma
|
16 |
+
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
+
|
19 |
+
from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
|
20 |
+
demosaicing_CFA_Bayer_Malvar2004,
|
21 |
+
demosaicing_CFA_Bayer_Menon2007)
|
22 |
+
|
23 |
+
import torch
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
from dataset import Subset
|
27 |
+
from torch.utils.data import DataLoader
|
28 |
+
|
29 |
+
from colour_demosaicing import (demosaicing_CFA_Bayer_bilinear,
|
30 |
+
demosaicing_CFA_Bayer_Malvar2004,
|
31 |
+
demosaicing_CFA_Bayer_Menon2007)
|
32 |
+
|
33 |
+
import matplotlib.pyplot as plt
|
34 |
+
|
35 |
+
|
36 |
+
class RawProcessingPipeline(object):
|
37 |
+
|
38 |
+
"""Applies the raw-processing pipeline from pipeline.py"""
|
39 |
+
|
40 |
+
def __init__(self, camera_parameters, debayer='bilinear', sharpening='unsharp_masking', denoising='gaussian'):
|
41 |
+
'''
|
42 |
+
Args:
|
43 |
+
camera_parameters (tuple): (black_level, white_balance, colour_matrix)
|
44 |
+
debayer (str): specifies the algorithm used as debayer; choose from {'bilinear','malvar2004','menon2007'}
|
45 |
+
sharpening (str): specifies the algorithm used for sharpening; choose from {'sharpening_filter','unsharp_masking'}
|
46 |
+
denoising (str): specifies the algorithm used for denoising; choose from choose from {'gaussian_denoising','median_denoising','fft_denoising'}
|
47 |
+
'''
|
48 |
+
|
49 |
+
self.camera_parameters = camera_parameters
|
50 |
+
|
51 |
+
self.debayer = debayer
|
52 |
+
self.sharpening = sharpening
|
53 |
+
self.denoising = denoising
|
54 |
+
|
55 |
+
def __call__(self, img):
|
56 |
+
"""
|
57 |
+
Args:
|
58 |
+
img (ndarry of dtype float.32): image of size (H,W)
|
59 |
+
return:
|
60 |
+
img (tensor of dtype float): image of size (3,H,W)
|
61 |
+
"""
|
62 |
+
black_level, white_balance, colour_matrix = self.camera_parameters
|
63 |
+
img = processing(img, black_level, white_balance, colour_matrix,
|
64 |
+
debayer=self.debayer, sharpening=self.sharpening, denoising=self.denoising)
|
65 |
+
img = img.transpose(2, 0, 1)
|
66 |
+
|
67 |
+
return torch.Tensor(img)
|
68 |
+
|
69 |
+
|
70 |
+
def processing(img, black_level, white_balance, colour_matrix, debayer="bilinear", sharpening="unsharp_masking",
|
71 |
+
sharp_radius=1.0, sharp_amount=1.0, denoising="median_filter", median_kernel_size=3,
|
72 |
+
gaussian_sigma=0.5, fft_fraction=0.3, weight_chambolle=0.01, weight_bregman=100,
|
73 |
+
sigma_bilateral=0.6, gamma=2.2, bits=16):
|
74 |
+
"""Apply pipeline on a raw image
|
75 |
+
|
76 |
+
Args:
|
77 |
+
rawImg (ndarray): raw image
|
78 |
+
debayer (str): debayer algorithm
|
79 |
+
white_balance (None, ndarray): white balance array (if None it will take the default camera white balance array)
|
80 |
+
colour_matrix (None, ndarray): colour matrix (if None it will take the default camera colour matrix) - Size: 3x3
|
81 |
+
gamma (float): exponent for the non linear gamma correction.
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
img (ndarray): post-processed image
|
85 |
+
|
86 |
+
"""
|
87 |
+
|
88 |
+
# Remove Black Level
|
89 |
+
img = remove_blacklv(img, black_level)
|
90 |
+
|
91 |
+
# Apply demosaicing - We don't have access to these 3 functions
|
92 |
+
if debayer == "bilinear":
|
93 |
+
img = demosaicing_CFA_Bayer_bilinear(img)
|
94 |
+
if debayer == "malvar2004":
|
95 |
+
img = demosaicing_CFA_Bayer_Malvar2004(img)
|
96 |
+
if debayer == "menon2007":
|
97 |
+
img = demosaicing_CFA_Bayer_Menon2007(img)
|
98 |
+
|
99 |
+
# White Balance Correction
|
100 |
+
|
101 |
+
# Sunny images white balance array -> 2<r<2.8, g=1.0, 1.3<b<1.6
|
102 |
+
# Tungsten images white balance array -> 1.3<r<1.7, g=1.0, 2.2<b<2.8
|
103 |
+
# Shade images white balance array -> 2.4<r<3.2, g=1.0, 1.1<b<1.3
|
104 |
+
|
105 |
+
img = wb_correction(img, white_balance)
|
106 |
+
|
107 |
+
# Colour Correction
|
108 |
+
img = colour_correction(img, colour_matrix)
|
109 |
+
|
110 |
+
# Sharpening
|
111 |
+
if sharpening == "sharpening_filter": # Fixed sharpening
|
112 |
+
img = sharpening_filter(img)
|
113 |
+
if sharpening == "unsharp_masking": # Higher is radius and amount, higher is the sharpening
|
114 |
+
img = unsharp_masking(img, radius=sharp_radius, amount=sharp_amount, multichannel=True)
|
115 |
+
|
116 |
+
# Denoising
|
117 |
+
if denoising == "median_denoising":
|
118 |
+
img = median_denoising(img, size=median_kernel_size)
|
119 |
+
if denoising == "gaussian_denoising":
|
120 |
+
img = gaussian_denoising(img, sigma=gaussian_sigma)
|
121 |
+
if denoising == "fft_denoising": # fft_fraction = [0.0001,0.5]
|
122 |
+
img = fft_denoising(img, keep_fraction=fft_fraction, row_cut=False, column_cut=True)
|
123 |
+
|
124 |
+
# We don't have access to these 3 functions
|
125 |
+
if denoising == "tv_chambolle": # lower is weight, less is the denoising
|
126 |
+
img = denoise_tv_chambolle(img, weight=weight_chambolle, eps=0.0002, n_iter_max=200, multichannel=True)
|
127 |
+
if denoising == "tv_bregman": # lower is weight, more is the denoising
|
128 |
+
img = denoise_tv_bregman(img, weight=weight_bregman, max_iter=100,
|
129 |
+
eps=0.001, isotropic=True, multichannel=True)
|
130 |
+
# if denoising == "wavelet":
|
131 |
+
# img = denoise_wavelet(img.copy(), sigma=None, wavelet='db1', mode='soft', wavelet_levels=None, multichannel=True,
|
132 |
+
# convert2ycbcr=False, method='BayesShrink', rescale_sigma=True)
|
133 |
+
if denoising == "bilateral": # higher is sigma_spatial, more is the denoising
|
134 |
+
img = denoise_bilateral(img, win_size=None, sigma_color=None, sigma_spatial=sigma_bilateral,
|
135 |
+
bins=10000, mode='constant', cval=0, multichannel=True)
|
136 |
+
|
137 |
+
# Gamma Correction
|
138 |
+
img = np.clip(img, 0, 1)
|
139 |
+
img = adjust_gamma(img, gamma=gamma)
|
140 |
+
|
141 |
+
return img
|
142 |
+
|
143 |
+
|
144 |
+
def get_camera_parameters(rawpyImg):
|
145 |
+
black_level = rawpyImg.black_level_per_channel
|
146 |
+
white_balance = rawpyImg.camera_whitebalance[:3]
|
147 |
+
colour_matrix = rawpyImg.color_matrix[:, :3].flatten().tolist()
|
148 |
+
|
149 |
+
return black_level, white_balance, colour_matrix
|
150 |
+
|
151 |
+
|
152 |
+
def remove_blacklv(rawImg, black_level):
|
153 |
+
rawImg[0::2, 0::2] -= black_level[0] # R
|
154 |
+
rawImg[0::2, 1::2] -= black_level[1] # G
|
155 |
+
rawImg[1::2, 0::2] -= black_level[2] # G
|
156 |
+
rawImg[1::2, 1::2] -= black_level[3] # B
|
157 |
+
|
158 |
+
return rawImg
|
159 |
+
|
160 |
+
|
161 |
+
def wb_correction(img, white_balance):
|
162 |
+
return img * white_balance
|
163 |
+
|
164 |
+
|
165 |
+
def colour_correction(img, colour_matrix):
|
166 |
+
colour_matrix = np.array(colour_matrix).reshape(3, 3)
|
167 |
+
return np.einsum('ijk,lk->ijl', img, colour_matrix)
|
168 |
+
|
169 |
+
|
170 |
+
def unsharp_masking(img, radius=1.0, amount=1.0,
|
171 |
+
multichannel=False, preserve_range=True):
|
172 |
+
|
173 |
+
img = rgb2yuv(img)
|
174 |
+
img[:, :, 0] = unsharp_mask(img[:, :, 0], radius=radius, amount=amount,
|
175 |
+
multichannel=multichannel, preserve_range=preserve_range)
|
176 |
+
img = yuv2rgb(img)
|
177 |
+
return img
|
178 |
+
|
179 |
+
|
180 |
+
def sharpening_filter(image, iterations=1, kernel=np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])):
|
181 |
+
|
182 |
+
# https://towardsdatascience.com/image-processing-with-python-blurring-and-sharpening-for-beginners-3bcebec0583a
|
183 |
+
|
184 |
+
img_yuv = rgb2yuv(image)
|
185 |
+
|
186 |
+
for i in range(iterations):
|
187 |
+
img_yuv[:, :, 0] = convolve2d(img_yuv[:, :, 0], kernel, 'same', boundary='fill', fillvalue=0)
|
188 |
+
|
189 |
+
final_image = yuv2rgb(img_yuv)
|
190 |
+
|
191 |
+
return final_image
|
192 |
+
|
193 |
+
|
194 |
+
def median_denoising(img, size=3):
|
195 |
+
|
196 |
+
img = rgb2yuv(img)
|
197 |
+
img[:, :, 0] = ndimage.median_filter(img[:, :, 0], size)
|
198 |
+
img = yuv2rgb(img)
|
199 |
+
|
200 |
+
return img
|
201 |
+
|
202 |
+
|
203 |
+
def gaussian_denoising(img, sigma=0.5):
|
204 |
+
|
205 |
+
img = rgb2yuv(img)
|
206 |
+
img[:, :, 0] = ndimage.gaussian_filter(img[:, :, 0], sigma)
|
207 |
+
img = yuv2rgb(img)
|
208 |
+
|
209 |
+
return img
|
210 |
+
|
211 |
+
|
212 |
+
def fft_denoising(img, keep_fraction=0.3, row_cut=False, column_cut=True):
|
213 |
+
""" keep_fraction = 0.5 --> same image as input
|
214 |
+
keep_fraction --> 0 --> remove all details """
|
215 |
+
# http://scipy-lectures.org/intro/scipy/auto_examples/solutions/plot_fft_image_denoise.html
|
216 |
+
|
217 |
+
im_fft = fftpack.fft2(img)
|
218 |
+
|
219 |
+
# Call ff a copy of the original transform. Numpy arrays have a copy
|
220 |
+
# method for this purpose.
|
221 |
+
im_fft2 = im_fft
|
222 |
+
|
223 |
+
# Set r and c to be the number of rows and columns of the array.
|
224 |
+
r, c, _ = im_fft2.shape
|
225 |
+
|
226 |
+
# Set to zero all rows with indices between r*keep_fraction and r*(1-keep_fraction):
|
227 |
+
if row_cut == True:
|
228 |
+
im_fft2[int(r * keep_fraction):int(r * (1 - keep_fraction))] = 0
|
229 |
+
|
230 |
+
# Similarly with the columns:
|
231 |
+
if column_cut == True:
|
232 |
+
im_fft2[:, int(c * keep_fraction):int(c * (1 - keep_fraction))] = 0
|
233 |
+
|
234 |
+
# Reconstruct the denoised image from the filtered spectrum, keep only the
|
235 |
+
# real part for display.
|
236 |
+
im_new = fftpack.ifft2(im_fft2).real
|
237 |
+
|
238 |
+
return im_new
|
239 |
+
|
240 |
+
|
241 |
+
def adjust_gamma(img, gamma=1.0):
|
242 |
+
invGamma = 1.0 / gamma
|
243 |
+
img = (img ** invGamma)
|
244 |
+
return img
|
245 |
+
|
246 |
+
|
247 |
+
def show_img(img, title="no_title", size=12, histo=True, bins=300, bits=16, x_range=-1):
|
248 |
+
"""Plot image and its histogram
|
249 |
+
|
250 |
+
Args:
|
251 |
+
img (ndarray): image to plot
|
252 |
+
title (str): title of the plot
|
253 |
+
histo (bool): True - Plot histrograms per channel of the image. False - Plot the curve of histogram in a continue way
|
254 |
+
bins (int): number of bins of the histogram
|
255 |
+
size (int): figure size
|
256 |
+
bits (int): number of bits per pixel in the ndarray
|
257 |
+
x_range (list): maximum x range of the histogram (if -1 it will be take all x values)
|
258 |
+
"""
|
259 |
+
shape = img.shape
|
260 |
+
|
261 |
+
fig = plt.figure(figsize=(size, size))
|
262 |
+
|
263 |
+
# show original image
|
264 |
+
fig.add_subplot(221)
|
265 |
+
if len(shape) > 2 and img.max() > 255:
|
266 |
+
img_to_show = (img.copy() * 255. / (2**bits - 1)).astype(int)
|
267 |
+
else:
|
268 |
+
img_to_show = img.copy().astype(int)
|
269 |
+
plt.imshow(img_to_show)
|
270 |
+
if title != "no_title":
|
271 |
+
plt.title(title)
|
272 |
+
|
273 |
+
fig.add_subplot(222)
|
274 |
+
|
275 |
+
if len(shape) > 2:
|
276 |
+
if histo == True:
|
277 |
+
plt.hist(img[:, :, 0].flatten(), bins=bins, label="Channel1", color="red", alpha=0.5)
|
278 |
+
plt.hist(img[:, :, 1].flatten(), bins=bins, label="Channel2", color="green", alpha=0.5)
|
279 |
+
plt.hist(img[:, :, 2].flatten(), bins=bins, label="Channel3", color="blue", alpha=0.5)
|
280 |
+
if x_range != -1:
|
281 |
+
plt.xlim([x_range[0], x_range[1]])
|
282 |
+
else:
|
283 |
+
h1, b1 = np.histogram(img[:, :, 0].flatten(), bins=bins)
|
284 |
+
h2, b2 = np.histogram(img[:, :, 1].flatten(), bins=bins)
|
285 |
+
h3, b3 = np.histogram(img[:, :, 2].flatten(), bins=bins)
|
286 |
+
plt.plot(b1[:-1], h1, label="Channel1", color="red", alpha=0.5)
|
287 |
+
plt.plot(b2[:-1], h2, label="Channel2", color="green", alpha=0.5)
|
288 |
+
plt.plot(b3[:-1], h3, label="Channel3", color="blue", alpha=0.5)
|
289 |
+
|
290 |
+
plt.legend()
|
291 |
+
else:
|
292 |
+
if histo == True:
|
293 |
+
plt.hist(img.flatten(), bins=bins)
|
294 |
+
if x_range != -1:
|
295 |
+
plt.xlim([x_range[0], x_range[1]])
|
296 |
+
else:
|
297 |
+
h, b = np.histogram(img.flatten(), bins=bins)
|
298 |
+
plt.plot(b[:-1], h)
|
299 |
+
|
300 |
+
plt.xlabel("Intensities")
|
301 |
+
plt.ylabel("Counts")
|
302 |
+
|
303 |
+
plt.show()
|
304 |
+
|
305 |
+
|
306 |
+
def get_statistics(dataset, train_indices, transform=None):
|
307 |
+
"""Calculates the mean and the standard deviation of a given sub train set of dataset
|
308 |
+
|
309 |
+
Args:
|
310 |
+
dataset (Subset of DroneDataset):
|
311 |
+
train_indices (tensor): indicies correponding to a subset of the dataset
|
312 |
+
transform (Compose): list of transformations compatible with Compose to be applied before calculations
|
313 |
+
return:
|
314 |
+
mean (tensor of dtype float): size (C,1,1)
|
315 |
+
std (tensor of dtype float): size (C,1,1)
|
316 |
+
"""
|
317 |
+
|
318 |
+
trainset = Subset(dataset, indices=train_indices, transform=transform)
|
319 |
+
dataloader = DataLoader(trainset, batch_size=len(trainset), shuffle=False)
|
320 |
+
dataiter = iter(dataloader)
|
321 |
+
|
322 |
+
images, labels = dataiter.next()
|
323 |
+
|
324 |
+
if len(images.shape) == 3:
|
325 |
+
mean, std = torch.mean(images, axis=(0, 1, 2)), torch.std(images, axis=(0, 1, 2))
|
326 |
+
return mean, std
|
327 |
+
else:
|
328 |
+
mean, std = torch.mean(images, axis=(0, 2, 3))[:, None, None], torch.std(images, axis=(0, 2, 3))[:, None, None]
|
329 |
+
return mean, std
|
processing/pipeline_torch.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from numpy.lib.function_base import interp
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
if not os.path.exists('README.md'):
|
6 |
+
os.chdir('..')
|
7 |
+
|
8 |
+
from processing.pipeline_numpy import processing as default_processing
|
9 |
+
from utils.base import np2torch, torch2np
|
10 |
+
|
11 |
+
import segmentation_models_pytorch as smp
|
12 |
+
|
13 |
+
from utils.debug import debug
|
14 |
+
|
15 |
+
K_G = torch.Tensor([[0, 1, 0],
|
16 |
+
[1, 4, 1],
|
17 |
+
[0, 1, 0]]) / 4
|
18 |
+
|
19 |
+
K_RB = torch.Tensor([[1, 2, 1],
|
20 |
+
[2, 4, 2],
|
21 |
+
[1, 2, 1]]) / 4
|
22 |
+
|
23 |
+
M_RGB_2_YUV = torch.Tensor([[0.299, 0.587, 0.114],
|
24 |
+
[-0.14714119, -0.28886916, 0.43601035],
|
25 |
+
[0.61497538, -0.51496512, -0.10001026]])
|
26 |
+
M_YUV_2_RGB = torch.Tensor([[1.0000000000e+00, -4.1827794561e-09, 1.1398830414e+00],
|
27 |
+
[1.0000000000e+00, -3.9464232326e-01, -5.8062183857e-01],
|
28 |
+
[1.0000000000e+00, 2.0320618153e+00, -1.2232658220e-09]])
|
29 |
+
|
30 |
+
K_BLUR = torch.Tensor([[6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08],
|
31 |
+
[2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
|
32 |
+
[2.0755e-04, 8.3731e-02, 6.1869e-01, 8.3731e-02, 2.0755e-04],
|
33 |
+
[2.8089e-05, 1.1332e-02, 8.3731e-02, 1.1332e-02, 2.8089e-05],
|
34 |
+
[6.9625e-08, 2.8089e-05, 2.0755e-04, 2.8089e-05, 6.9625e-08]])
|
35 |
+
K_SHARP = torch.Tensor([[0, -1, 0],
|
36 |
+
[-1, 5, -1],
|
37 |
+
[0, -1, 0]])
|
38 |
+
DEFAULT_CAMERA_PARAMS = (
|
39 |
+
[0., 0., 0., 0.],
|
40 |
+
[1., 1., 1.],
|
41 |
+
[1., 0., 0., 0., 1., 0., 0., 0., 1.],
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
class RawToRGB(nn.Module):
|
46 |
+
def __init__(self, reduce_size=True, out_channels=3, track_stages=False, normalize_mosaic=None):
|
47 |
+
super().__init__()
|
48 |
+
self.stages = None
|
49 |
+
self.buffer = None
|
50 |
+
self.reduce_size = reduce_size
|
51 |
+
self.out_channels = out_channels
|
52 |
+
self.track_stages = track_stages
|
53 |
+
self.normalize_mosaic = normalize_mosaic
|
54 |
+
|
55 |
+
def forward(self, raw):
|
56 |
+
self.stages = {}
|
57 |
+
self.buffer = {}
|
58 |
+
|
59 |
+
rgb = raw2rgb(raw, reduce_size=self.reduce_size, out_channels=self.out_channels)
|
60 |
+
self.stages['demosaic'] = rgb
|
61 |
+
if self.normalize_mosaic:
|
62 |
+
rgb = self.normalize_mosaic(rgb)
|
63 |
+
|
64 |
+
if self.track_stages and raw.requires_grad:
|
65 |
+
for stage in self.stages.values():
|
66 |
+
stage.retain_grad()
|
67 |
+
|
68 |
+
self.buffer['processed_rgb'] = rgb
|
69 |
+
|
70 |
+
return rgb
|
71 |
+
|
72 |
+
|
73 |
+
class NNProcessing(nn.Module):
|
74 |
+
def __init__(self, track_stages=False, normalize_mosaic=None, batch_norm_output=True):
|
75 |
+
super().__init__()
|
76 |
+
self.stages = None
|
77 |
+
self.buffer = None
|
78 |
+
self.track_stages = track_stages
|
79 |
+
self.model = smp.UnetPlusPlus(
|
80 |
+
encoder_name='resnet34',
|
81 |
+
encoder_depth=3,
|
82 |
+
decoder_channels=[256, 128, 64],
|
83 |
+
in_channels=3,
|
84 |
+
classes=3,
|
85 |
+
)
|
86 |
+
self.batch_norm = None if not batch_norm_output else nn.BatchNorm2d(3, affine=False)
|
87 |
+
self.normalize_mosaic = normalize_mosaic
|
88 |
+
|
89 |
+
def forward(self, raw):
|
90 |
+
self.stages = {}
|
91 |
+
self.buffer = {}
|
92 |
+
# self.stages['raw'] = raw
|
93 |
+
rgb = raw2rgb(raw)
|
94 |
+
if self.normalize_mosaic:
|
95 |
+
rgb = self.normalize_mosaic(rgb)
|
96 |
+
self.stages['demosaic'] = rgb
|
97 |
+
rgb = self.model(rgb)
|
98 |
+
if self.batch_norm is not None:
|
99 |
+
rgb = self.batch_norm(rgb)
|
100 |
+
self.stages['rgb'] = rgb
|
101 |
+
|
102 |
+
if self.track_stages and raw.requires_grad:
|
103 |
+
for stage in self.stages.values():
|
104 |
+
stage.retain_grad()
|
105 |
+
|
106 |
+
self.buffer['processed_rgb'] = rgb
|
107 |
+
|
108 |
+
return rgb
|
109 |
+
|
110 |
+
|
111 |
+
def add_additive_layer(processor):
|
112 |
+
processor.additive_layer = nn.Parameter(torch.zeros((1, 3, 256, 256)))
|
113 |
+
# processor.additive_layer = nn.Parameter(0.001 * torch.randn((1, 3, 256, 256)))
|
114 |
+
|
115 |
+
|
116 |
+
class ParametrizedProcessing(nn.Module):
|
117 |
+
def __init__(self, camera_parameters, track_stages=False, batch_norm_output=True):
|
118 |
+
super().__init__()
|
119 |
+
self.stages = None
|
120 |
+
self.buffer = None
|
121 |
+
self.track_stages = track_stages
|
122 |
+
|
123 |
+
black_level, white_balance, colour_matrix = camera_parameters
|
124 |
+
|
125 |
+
self.black_level = nn.Parameter(torch.as_tensor(black_level))
|
126 |
+
self.white_balance = nn.Parameter(torch.as_tensor(white_balance).reshape(1, 3))
|
127 |
+
self.colour_correction = nn.Parameter(torch.as_tensor(colour_matrix).reshape(3, 3))
|
128 |
+
|
129 |
+
self.gamma_correct = nn.Parameter(torch.Tensor([2.2]))
|
130 |
+
|
131 |
+
self.debayer = Debayer()
|
132 |
+
|
133 |
+
self.sharpening_filter = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
|
134 |
+
self.sharpening_filter.weight.data[0][0] = K_SHARP.clone()
|
135 |
+
|
136 |
+
self.gaussian_blur = nn.Conv2d(1, 1, kernel_size=5, padding=2, padding_mode='reflect', bias=False)
|
137 |
+
self.gaussian_blur.weight.data[0][0] = K_BLUR.clone()
|
138 |
+
|
139 |
+
self.batch_norm = nn.BatchNorm2d(3, affine=False) if batch_norm_output else None
|
140 |
+
|
141 |
+
self.register_buffer('M_RGB_2_YUV', M_RGB_2_YUV.clone())
|
142 |
+
self.register_buffer('M_YUV_2_RGB', M_YUV_2_RGB.clone())
|
143 |
+
|
144 |
+
self.additive_layer = None # this can be added in later
|
145 |
+
|
146 |
+
def forward(self, raw):
|
147 |
+
assert raw.ndim == 3, f"needs dims (B, H, W), got {raw.shape}"
|
148 |
+
|
149 |
+
self.stages = {}
|
150 |
+
self.buffer = {}
|
151 |
+
|
152 |
+
# self.stages['raw'] = raw
|
153 |
+
|
154 |
+
rgb = raw2rgb(raw, black_level=self.black_level, reduce_size=False)
|
155 |
+
rgb = rgb.contiguous()
|
156 |
+
self.stages['demosaic'] = rgb
|
157 |
+
|
158 |
+
rgb = self.debayer(rgb)
|
159 |
+
# self.stages['debayer'] = rgb
|
160 |
+
|
161 |
+
rgb = torch.einsum('bchw,kc->bchw', rgb, self.white_balance).contiguous()
|
162 |
+
rgb = torch.einsum('bchw,kc->bkhw', rgb, self.colour_correction).contiguous()
|
163 |
+
self.stages['color_correct'] = rgb
|
164 |
+
|
165 |
+
yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()
|
166 |
+
yuv[:, [0], ...] = self.sharpening_filter(yuv[:, [0], ...])
|
167 |
+
|
168 |
+
if self.track_stages: # keep stage in computational graph for grad information
|
169 |
+
rgb = torch.einsum('bchw,kc->bkhw', yuv.clone(), self.M_YUV_2_RGB).contiguous()
|
170 |
+
self.stages['sharpening'] = rgb
|
171 |
+
yuv = torch.einsum('bchw,kc->bkhw', rgb, self.M_RGB_2_YUV).contiguous()
|
172 |
+
|
173 |
+
yuv[:, [0], ...] = self.gaussian_blur(yuv[:, [0], ...])
|
174 |
+
rgb = torch.einsum('bchw,kc->bkhw', yuv, self.M_YUV_2_RGB).contiguous()
|
175 |
+
self.stages['gaussian'] = rgb
|
176 |
+
|
177 |
+
rgb = torch.clip(rgb, 1e-5, 1)
|
178 |
+
self.stages['clipped'] = rgb
|
179 |
+
|
180 |
+
rgb = torch.exp((1 / self.gamma_correct) * torch.log(rgb))
|
181 |
+
self.stages['gamma_correct'] = rgb
|
182 |
+
|
183 |
+
if self.additive_layer is not None:
|
184 |
+
rgb = rgb + self.additive_layer
|
185 |
+
self.stages['noise'] = rgb
|
186 |
+
|
187 |
+
if self.batch_norm is not None:
|
188 |
+
rgb = self.batch_norm(rgb)
|
189 |
+
|
190 |
+
if self.track_stages and raw.requires_grad:
|
191 |
+
for stage in self.stages.values():
|
192 |
+
stage.retain_grad()
|
193 |
+
|
194 |
+
self.buffer['processed_rgb'] = rgb
|
195 |
+
|
196 |
+
return rgb
|
197 |
+
|
198 |
+
|
199 |
+
class Debayer(nn.Conv2d):
|
200 |
+
def __init__(self):
|
201 |
+
super().__init__(3, 3, kernel_size=3, padding=1, padding_mode='reflect', bias=False) # default_pipeline uses 'replicate'
|
202 |
+
self.weight.data.fill_(0)
|
203 |
+
self.weight.data[0, 0] = K_RB.clone()
|
204 |
+
self.weight.data[1, 1] = K_G.clone()
|
205 |
+
self.weight.data[2, 2] = K_RB.clone()
|
206 |
+
|
207 |
+
|
208 |
+
def raw2rgb(raw, black_level=None, reduce_size=True, out_channels=3):
|
209 |
+
"""transform raw image with 1 channel to rgb with 3 channels
|
210 |
+
Args:
|
211 |
+
raw (Tensor): raw Tensor of shape (B, H, W)
|
212 |
+
black_level (iterable, optional): RGGB black level values to subtract
|
213 |
+
reduce_size (bool, optional): if False, the output image will have the same height and width
|
214 |
+
as the raw input, i.e. (B, C, H, W), empty values are filled with zeros.
|
215 |
+
if True, the output dimensions are reduced by half (B, C, H//2, W//2),
|
216 |
+
the two green channels are averaged.
|
217 |
+
out_channels (int, optional): number of output channels. One of {3, 4}.
|
218 |
+
"""
|
219 |
+
assert out_channels in [3, 4]
|
220 |
+
if black_level is None:
|
221 |
+
black_level = [0, 0, 0, 0]
|
222 |
+
Bch, H, W = raw.shape
|
223 |
+
R = raw[:, 0::2, 0::2] - black_level[0] # R
|
224 |
+
G1 = raw[:, 0::2, 1::2] - black_level[1] # G
|
225 |
+
G2 = raw[:, 1::2, 0::2] - black_level[2] # G
|
226 |
+
B = raw[:, 1::2, 1::2] - black_level[3] # B
|
227 |
+
if reduce_size:
|
228 |
+
rgb = torch.zeros((Bch, out_channels, H // 2, W // 2), device=raw.device)
|
229 |
+
if out_channels == 3:
|
230 |
+
rgb[:, 0, :, :] = R
|
231 |
+
rgb[:, 1, :, :] = (G1 + G2) / 2
|
232 |
+
rgb[:, 2, :, :] = B
|
233 |
+
elif out_channels == 4:
|
234 |
+
rgb[:, 0, :, :] = R
|
235 |
+
rgb[:, 1, :, :] = G1
|
236 |
+
rgb[:, 2, :, :] = G2
|
237 |
+
rgb[:, 3, :, :] = B
|
238 |
+
else:
|
239 |
+
rgb = torch.zeros((Bch, out_channels, H, W), device=raw.device)
|
240 |
+
if out_channels == 3:
|
241 |
+
rgb[:, 0, 0::2, 0::2] = R
|
242 |
+
rgb[:, 1, 0::2, 1::2] = G1
|
243 |
+
rgb[:, 1, 1::2, 0::2] = G2
|
244 |
+
rgb[:, 2, 1::2, 1::2] = B
|
245 |
+
elif out_channels == 4:
|
246 |
+
rgb[:, 0, 0::2, 0::2] = R
|
247 |
+
rgb[:, 1, 0::2, 1::2] = G1
|
248 |
+
rgb[:, 2, 1::2, 0::2] = G2
|
249 |
+
rgb[:, 3, 1::2, 1::2] = B
|
250 |
+
return rgb
|
251 |
+
|
252 |
+
|
253 |
+
# pipeline validation
|
254 |
+
if __name__ == "__main__":
|
255 |
+
|
256 |
+
import torch
|
257 |
+
import numpy as np
|
258 |
+
|
259 |
+
if not os.path.exists('README.md'):
|
260 |
+
os.chdir('..')
|
261 |
+
|
262 |
+
import matplotlib.pyplot as plt
|
263 |
+
from dataset import get_dataset
|
264 |
+
from utils.base import np2torch, torch2np
|
265 |
+
|
266 |
+
from utils.debug import debug
|
267 |
+
from processing.pipeline_numpy import processing as default_processing
|
268 |
+
|
269 |
+
raw_dataset = get_dataset('DS')
|
270 |
+
loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1)
|
271 |
+
batch_raw, batch_mask = next(iter(loader))
|
272 |
+
|
273 |
+
# torch proc
|
274 |
+
camera_parameters = raw_dataset.camera_parameters
|
275 |
+
black_level = camera_parameters[0]
|
276 |
+
|
277 |
+
proc = ParametrizedProcessing(camera_parameters)
|
278 |
+
|
279 |
+
batch_rgb = proc(batch_raw)
|
280 |
+
rgb = batch_rgb[0]
|
281 |
+
|
282 |
+
# numpy proc
|
283 |
+
raw_img = batch_raw[0]
|
284 |
+
numpy_raw = torch2np(raw_img)
|
285 |
+
|
286 |
+
default_rgb = default_processing(numpy_raw, *camera_parameters,
|
287 |
+
sharpening='sharpening_filter', denoising='gaussian_denoising')
|
288 |
+
|
289 |
+
rgb_valid = np2torch(default_rgb)
|
290 |
+
|
291 |
+
print("pipeline norm difference:", (rgb - rgb_valid).norm().item())
|
292 |
+
|
293 |
+
rgb_mosaic = raw2rgb(batch_raw, reduce_size=False).squeeze()
|
294 |
+
rgb_reduced = raw2rgb(batch_raw, reduce_size=True).squeeze()
|
295 |
+
|
296 |
+
plt.figure(figsize=(16, 8))
|
297 |
+
plt.subplot(151)
|
298 |
+
plt.title('Raw')
|
299 |
+
plt.imshow(torch2np(raw_img))
|
300 |
+
plt.subplot(152)
|
301 |
+
plt.title('RGB Mosaic')
|
302 |
+
plt.imshow(torch2np(rgb_mosaic))
|
303 |
+
plt.subplot(153)
|
304 |
+
plt.title('RGB Reduced')
|
305 |
+
plt.imshow(torch2np(rgb_reduced))
|
306 |
+
plt.subplot(154)
|
307 |
+
plt.title('Torch Pipeline')
|
308 |
+
plt.imshow(torch2np(rgb))
|
309 |
+
plt.subplot(155)
|
310 |
+
plt.title('Default Pipeline')
|
311 |
+
plt.imshow(torch2np(rgb_valid))
|
312 |
+
plt.show()
|
313 |
+
|
314 |
+
# assert rgb.allclose(rgb_valid)
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
scikit_image
|
2 |
+
pandas
|
3 |
+
scipy
|
4 |
+
numpy
|
5 |
+
matplotlib
|
6 |
+
b2sdk
|
7 |
+
colour_demosaicing
|
8 |
+
gradio
|
9 |
+
ipython
|
10 |
+
mlflow
|
11 |
+
Pillow
|
12 |
+
pytorch_toolbelt
|
13 |
+
rawpy
|
14 |
+
scikit_learn
|
15 |
+
segmentation_models_pytorch
|
16 |
+
tifffile
|
17 |
+
torch==1.9.0
|
18 |
+
torchvision==0.10.0
|
train.py
ADDED
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import copy
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import optim
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
import mlflow.pytorch
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from torchvision.models import resnet18
|
13 |
+
import torchvision.transforms as T
|
14 |
+
from pytorch_lightning.metrics.functional import accuracy
|
15 |
+
import pytorch_lightning as pl
|
16 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
17 |
+
|
18 |
+
from utils.base import display_mlflow_run_info, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std
|
19 |
+
from utils.debug import debug
|
20 |
+
from utils.dataset_utils import k_fold
|
21 |
+
from utils.augmentation import get_augmentation
|
22 |
+
from dataset import Subset, get_dataset
|
23 |
+
|
24 |
+
from processing.pipeline_numpy import RawProcessingPipeline
|
25 |
+
from processing.pipeline_torch import add_additive_layer, raw2rgb, RawToRGB, ParametrizedProcessing, NNProcessing
|
26 |
+
|
27 |
+
from model import log_tensor, resnet_model, LitModel, TrackImagesCallback
|
28 |
+
|
29 |
+
import segmentation_models_pytorch as smp
|
30 |
+
|
31 |
+
from utils.ssim import SSIM
|
32 |
+
|
33 |
+
# args to set up task
|
34 |
+
parser = argparse.ArgumentParser(description="classification_task")
|
35 |
+
parser.add_argument("--tracking_uri", type=str,
|
36 |
+
default="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com", help='URI of the mlflow server on AWS')
|
37 |
+
parser.add_argument("--processor_uri", type=str, default=None,
|
38 |
+
help='URI of the processing model (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/processing-model)')
|
39 |
+
parser.add_argument("--classifier_uri", type=str, default=None,
|
40 |
+
help='URI of the net (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/prediction-model)')
|
41 |
+
parser.add_argument("--state_dict_uri", type=str,
|
42 |
+
default=None, help='URI of the indices you want to load (e.g. s3://mlflow-artifacts-601883093460/7/4326da05aca54107be8c554de0674a14/artifacts/training')
|
43 |
+
|
44 |
+
parser.add_argument("--experiment_name", type=str,
|
45 |
+
default='classification learnable pipeline', help='Specify the experiment you are running, e.g. end2end segmentation')
|
46 |
+
parser.add_argument("--run_name", type=str,
|
47 |
+
default='test run', help='Specify the name of your run')
|
48 |
+
|
49 |
+
parser.add_argument("--log_model", type=str2bool, default=True, help='Enables model logging')
|
50 |
+
parser.add_argument("--save_locally", action='store_true',
|
51 |
+
help='Model will be saved locally if action is taken') # TODO: bypass mlflow
|
52 |
+
|
53 |
+
parser.add_argument("--track_processing", action='store_true',
|
54 |
+
help='Save images after each trasformation of the pipeline for the test set')
|
55 |
+
parser.add_argument("--track_processing_gradients", action='store_true',
|
56 |
+
help='Save images of gradients after each trasformation of the pipeline for the test set')
|
57 |
+
parser.add_argument("--track_save_tensors", action='store_true',
|
58 |
+
help='Save the torch tensors after each trasformation of the pipeline for the test set')
|
59 |
+
parser.add_argument("--track_predictions", action='store_true',
|
60 |
+
help='Save images after each trasformation of the pipeline for the test set + input gradient')
|
61 |
+
parser.add_argument("--track_n_images", default=5,
|
62 |
+
help='Track the n first elements of dataset. Only used for args.track_processing=True')
|
63 |
+
parser.add_argument("--track_every_epoch", action='store_true', help='Track images every epoch or once after training')
|
64 |
+
|
65 |
+
# args to create dataset
|
66 |
+
parser.add_argument("--seed", type=int, default=1, help='Global seed')
|
67 |
+
parser.add_argument("--dataset", type=str, default='Microscopy',
|
68 |
+
choices=["Drone", "DroneSegmentation", "Microscopy"], help='Select dataset')
|
69 |
+
|
70 |
+
parser.add_argument("--n_splits", type=int, default=1, help='Number of splits used for training')
|
71 |
+
parser.add_argument("--train_size", type=float, default=0.8, help='Fraction of training points in dataset')
|
72 |
+
|
73 |
+
# args for training
|
74 |
+
parser.add_argument("--lr", type=float, default=1e-5, help="learning rate used for training")
|
75 |
+
parser.add_argument("--epochs", type=int, default=3, help="numper of epochs")
|
76 |
+
parser.add_argument("--batch_size", type=int, default=32, help="Training batch size")
|
77 |
+
parser.add_argument("--augmentation", type=str, default='none',
|
78 |
+
choices=["none", "weak", "strong"], help="Applies augmentation to training")
|
79 |
+
parser.add_argument("--augmentation_on_valid_epoch", action='store_true',
|
80 |
+
help='Track images every epoch or once after training') # TODO: implement, actually should be disabled by default for 'val' and 'test
|
81 |
+
parser.add_argument("--check_val_every_n_epoch", type=int, default=1)
|
82 |
+
|
83 |
+
# args to specify the processing
|
84 |
+
parser.add_argument("--processing_mode", type=str, default="parametrized",
|
85 |
+
choices=["parametrized", "static", "neural_network", "none"],
|
86 |
+
help="Which type of raw to rgb processing should be used")
|
87 |
+
|
88 |
+
# args to specify model
|
89 |
+
parser.add_argument("--classifier_network", type=str, default='ResNet18',
|
90 |
+
help='Type of pretrained network') # TODO: implement different choices
|
91 |
+
parser.add_argument("--classifier_pretrained", action='store_true',
|
92 |
+
help='Whether to use a pre-trained model or not')
|
93 |
+
parser.add_argument("--smp_encoder", type=str, default='resnet34', help='segmentation model encoder')
|
94 |
+
|
95 |
+
parser.add_argument("--freeze_processor", action='store_true', help="Freeze raw to rgb processing model weights")
|
96 |
+
parser.add_argument("--freeze_classifier", action='store_true', help="Freeze classification model weights")
|
97 |
+
|
98 |
+
# args to specify static pipeline transformations
|
99 |
+
parser.add_argument("--sp_debayer", type=str, default='bilinear',
|
100 |
+
choices=['bilinear', 'malvar2004', 'menon2007'], help="Specify algorithm used as debayer")
|
101 |
+
parser.add_argument("--sp_sharpening", type=str, default='sharpening_filter',
|
102 |
+
choices=['sharpening_filter', 'unsharp_masking'], help="Specify algorithm used for sharpening")
|
103 |
+
parser.add_argument("--sp_denoising", type=str, default='gaussian_denoising',
|
104 |
+
choices=['gaussian_denoising', 'median_denoising', 'fft_denoising'], help="Specify algorithm used for denoising")
|
105 |
+
|
106 |
+
# args to choose training mode
|
107 |
+
parser.add_argument("--adv_training", action='store_true', help="Enable adversarial training")
|
108 |
+
parser.add_argument("--adv_aux_weight", type=float, default=1, help="Weighting of the adversarial auxilliary loss")
|
109 |
+
parser.add_argument("--adv_aux_loss", type=str, default='ssim', choices=['l2', 'ssim'],
|
110 |
+
help="Type of adversarial auxilliary regularization loss")
|
111 |
+
parser.add_argument("--adv_noise_layer", action='store_true', help="Adds an additive layer to Parametrized Processing")
|
112 |
+
parser.add_argument("--adv_track_differences", action='store_true', help='Save difference to default pipeline')
|
113 |
+
parser.add_argument('--adv_parameters', choices=['all', 'black_level', 'white_balance',
|
114 |
+
'colour_correction', 'gamma_correct', 'sharpening_filter', 'gaussian_blur', 'additive_layer'])
|
115 |
+
|
116 |
+
parser.add_argument("--cache_downloaded_models", type=str2bool, default=True)
|
117 |
+
|
118 |
+
parser.add_argument('--test_run', action='store_true')
|
119 |
+
|
120 |
+
|
121 |
+
args = parser.parse_args()
|
122 |
+
|
123 |
+
os.makedirs('results', exist_ok=True)
|
124 |
+
|
125 |
+
|
126 |
+
def run_train(args):
|
127 |
+
|
128 |
+
print(args)
|
129 |
+
|
130 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
131 |
+
training_mode = 'adversarial' if args.adv_training else 'default'
|
132 |
+
|
133 |
+
# set tracking uri, this is the address of the mlflow server where light experimental data will be stored
|
134 |
+
mlflow.set_tracking_uri(args.tracking_uri)
|
135 |
+
mlflow.set_experiment(args.experiment_name)
|
136 |
+
os.environ["AWS_ACCESS_KEY_ID"] = "#TODO: fill in your aws access key id for mlflow server here"
|
137 |
+
os.environ["AWS_SECRET_ACCESS_KEY"] = "#TODO: fill in your aws secret access key for mlflow server here"
|
138 |
+
|
139 |
+
# dataset
|
140 |
+
|
141 |
+
dataset = get_dataset(args.dataset)
|
142 |
+
|
143 |
+
print(f'dataset: {type(dataset).__name__}[{len(dataset)}]')
|
144 |
+
print(f'task: {dataset.task}')
|
145 |
+
print(f'mode: {training_mode} training')
|
146 |
+
print(f'# cross-validation subsets: {args.n_splits}')
|
147 |
+
pl.seed_everything(args.seed)
|
148 |
+
idxs_kfold = k_fold(dataset, n_splits=args.n_splits, seed=args.seed, train_size=args.train_size)
|
149 |
+
|
150 |
+
with mlflow.start_run(run_name=args.run_name) as parent_run:
|
151 |
+
|
152 |
+
for k_iter, idxs in enumerate(idxs_kfold):
|
153 |
+
|
154 |
+
print(f"K_fold subset: {k_iter+1}/{args.n_splits}")
|
155 |
+
|
156 |
+
if args.processing_mode == 'static':
|
157 |
+
if args.dataset == "Drone" or args.dataset == "DroneSegmentation":
|
158 |
+
mean = torch.tensor([0.35, 0.36, 0.35])
|
159 |
+
std = torch.tensor([0.12, 0.11, 0.12])
|
160 |
+
elif args.dataset == "Microscopy":
|
161 |
+
mean = torch.tensor([0.91, 0.84, 0.94])
|
162 |
+
std = torch.tensor([0.08, 0.12, 0.05])
|
163 |
+
|
164 |
+
dataset.transform = T.Compose([RawProcessingPipeline(
|
165 |
+
camera_parameters=dataset.camera_parameters,
|
166 |
+
debayer=args.sp_debayer,
|
167 |
+
sharpening=args.sp_sharpening,
|
168 |
+
denoising=args.sp_denoising,
|
169 |
+
), T.Normalize(mean, std)])
|
170 |
+
# XXX: Not clean
|
171 |
+
|
172 |
+
processor = nn.Identity()
|
173 |
+
|
174 |
+
if args.processor_uri is not None and args.processing_mode != 'none':
|
175 |
+
print('Fetching processor: ', end='')
|
176 |
+
model = fetch_from_mlflow(args.processor_uri, use_cache=args.cache_downloaded_models)
|
177 |
+
processor = model.processor
|
178 |
+
for param in processor.parameters():
|
179 |
+
param.requires_grad = True
|
180 |
+
model.processor = None
|
181 |
+
del model
|
182 |
+
else:
|
183 |
+
print(f'processing_mode: {args.processing_mode}')
|
184 |
+
normalize_mosaic = None # normalize after raw has been passed to raw2rgb
|
185 |
+
if args.dataset == "Microscopy":
|
186 |
+
mosaic_mean = [0.5663, 0.1401, 0.0731]
|
187 |
+
mosaic_std = [0.097, 0.0423, 0.008]
|
188 |
+
normalize_mosaic = T.Normalize(mosaic_mean, mosaic_std)
|
189 |
+
|
190 |
+
track_stages = args.track_processing or args.track_processing_gradients
|
191 |
+
if args.processing_mode == 'parametrized':
|
192 |
+
processor = ParametrizedProcessing(
|
193 |
+
camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True,
|
194 |
+
# noise_layer=args.adv_noise_layer, # this has to be added manually afterwards for when a model is loaded that doesn't have one yet
|
195 |
+
)
|
196 |
+
|
197 |
+
elif args.processing_mode == 'neural_network':
|
198 |
+
processor = NNProcessing(track_stages=track_stages,
|
199 |
+
normalize_mosaic=normalize_mosaic, batch_norm_output=True)
|
200 |
+
elif args.processing_mode == 'none':
|
201 |
+
processor = RawToRGB(reduce_size=True, out_channels=3, track_stages=track_stages,
|
202 |
+
normalize_mosaic=normalize_mosaic)
|
203 |
+
|
204 |
+
if args.classifier_uri: # fetch classifier
|
205 |
+
print('Fetching classifier: ', end='')
|
206 |
+
model = fetch_from_mlflow(args.classifier_uri, use_cache=args.cache_downloaded_models)
|
207 |
+
classifier = model.classifier
|
208 |
+
model.classifier = None
|
209 |
+
del model
|
210 |
+
else:
|
211 |
+
if dataset.task == 'classification':
|
212 |
+
classifier = resnet_model(
|
213 |
+
model=resnet18,
|
214 |
+
pretrained=args.classifier_pretrained,
|
215 |
+
in_channels=3,
|
216 |
+
fc_out_features=len(dataset.classes)
|
217 |
+
)
|
218 |
+
else:
|
219 |
+
# XXX: add other network choices to args.smp_network (FPN) and args.network
|
220 |
+
classifier = smp.UnetPlusPlus(
|
221 |
+
encoder_name=args.smp_encoder,
|
222 |
+
encoder_depth=5,
|
223 |
+
encoder_weights='imagenet',
|
224 |
+
in_channels=3,
|
225 |
+
classes=1,
|
226 |
+
activation=None,
|
227 |
+
)
|
228 |
+
|
229 |
+
if args.freeze_processor and len(list(iter(processor.parameters()))) == 0:
|
230 |
+
print('Note: freezing processor without parameters.')
|
231 |
+
assert not (args.freeze_processor and args.freeze_classifier), 'Likely no parameters to train.'
|
232 |
+
|
233 |
+
if dataset.task == 'classification':
|
234 |
+
loss = nn.CrossEntropyLoss()
|
235 |
+
metrics = [accuracy]
|
236 |
+
else:
|
237 |
+
# loss = utils.base.smp_get_loss(args.smp_loss) # XXX: add other losses to args.smp_loss
|
238 |
+
loss = smp.losses.DiceLoss(mode='binary', from_logits=True)
|
239 |
+
metrics = [smp.utils.metrics.IoU()]
|
240 |
+
|
241 |
+
loss_aux = None
|
242 |
+
|
243 |
+
if args.adv_training:
|
244 |
+
|
245 |
+
assert args.processing_mode == 'parametrized', f"Processing mode ({args.processing_mode}) should be set to 'parametrized' for adversarial training"
|
246 |
+
assert args.freeze_classifier, "Classifier should be frozen for adversarial training"
|
247 |
+
assert not args.freeze_processor, "Processor should not be frozen for adversarial training"
|
248 |
+
|
249 |
+
processor_default = copy.deepcopy(processor)
|
250 |
+
processor_default.track_stages = args.track_processing
|
251 |
+
processor_default.eval()
|
252 |
+
processor_default.to(DEVICE)
|
253 |
+
# debug(processor_default)
|
254 |
+
for p in processor_default.parameters():
|
255 |
+
p.requires_grad = False
|
256 |
+
|
257 |
+
if args.adv_noise_layer:
|
258 |
+
add_additive_layer(processor)
|
259 |
+
|
260 |
+
def l2_regularization(x, y):
|
261 |
+
return ((x - y) ** 2).sum()
|
262 |
+
# return (x - y).norm()
|
263 |
+
|
264 |
+
if args.adv_aux_loss == 'l2':
|
265 |
+
regularization = l2_regularization
|
266 |
+
elif args.adv_aux_loss == 'ssim':
|
267 |
+
regularization = SSIM(window_size=11)
|
268 |
+
else:
|
269 |
+
NotImplementedError(args.adv_aux_loss)
|
270 |
+
|
271 |
+
class AuxLoss(nn.Module):
|
272 |
+
def __init__(self, loss_aux, weight=1):
|
273 |
+
super().__init__()
|
274 |
+
self.loss_aux = loss_aux
|
275 |
+
self.weight = weight
|
276 |
+
|
277 |
+
def forward(self, x):
|
278 |
+
with torch.no_grad():
|
279 |
+
x_reference = processor_default(x)
|
280 |
+
x_processed = processor.buffer['processed_rgb']
|
281 |
+
return self.weight * self.loss_aux(x_reference, x_processed)
|
282 |
+
|
283 |
+
class WeightedLoss(nn.Module):
|
284 |
+
def __init__(self, loss, weight=1):
|
285 |
+
super().__init__()
|
286 |
+
self.loss = loss
|
287 |
+
self.weight = weight
|
288 |
+
|
289 |
+
def forward(self, x, y):
|
290 |
+
return self.weight * self.loss(x, y)
|
291 |
+
|
292 |
+
def __repr__(self):
|
293 |
+
return f'{self.weight} * {get_name(self.loss)}'
|
294 |
+
|
295 |
+
loss = WeightedLoss(loss=loss, weight=-1)
|
296 |
+
# loss = WeightedLoss(loss=nn.CrossEntropyLoss(), weight=0)
|
297 |
+
loss_aux = AuxLoss(
|
298 |
+
loss_aux=regularization,
|
299 |
+
weight=args.adv_aux_weight,
|
300 |
+
)
|
301 |
+
|
302 |
+
augmentation = get_augmentation(args.augmentation)
|
303 |
+
|
304 |
+
model = LitModel(
|
305 |
+
classifier=classifier,
|
306 |
+
processor=processor,
|
307 |
+
loss=loss,
|
308 |
+
lr=args.lr,
|
309 |
+
loss_aux=loss_aux,
|
310 |
+
adv_training=args.adv_training,
|
311 |
+
adv_parameters=args.adv_parameters,
|
312 |
+
metrics=metrics,
|
313 |
+
augmentation=augmentation,
|
314 |
+
is_segmentation_task=dataset.task == 'segmentation',
|
315 |
+
freeze_classifier=args.freeze_classifier,
|
316 |
+
freeze_processor=args.freeze_processor,
|
317 |
+
)
|
318 |
+
|
319 |
+
# get train_set_dict
|
320 |
+
if args.state_dict_uri:
|
321 |
+
state_dict = mlflow.pytorch.load_state_dict(args.state_dict_uri)
|
322 |
+
train_indices = state_dict['train_indices']
|
323 |
+
valid_indices = state_dict['valid_indices']
|
324 |
+
else:
|
325 |
+
train_indices = idxs[0]
|
326 |
+
valid_indices = idxs[1]
|
327 |
+
state_dict = vars(args).copy()
|
328 |
+
|
329 |
+
track_indices = list(range(args.track_n_images))
|
330 |
+
|
331 |
+
if dataset.task == 'classification':
|
332 |
+
state_dict['classes'] = dataset.classes
|
333 |
+
state_dict['device'] = DEVICE
|
334 |
+
state_dict['train_indices'] = train_indices
|
335 |
+
state_dict['valid_indices'] = valid_indices
|
336 |
+
state_dict['elements in train set'] = len(train_indices)
|
337 |
+
state_dict['elements in test set'] = len(valid_indices)
|
338 |
+
|
339 |
+
if args.test_run:
|
340 |
+
train_indices = train_indices[:args.batch_size]
|
341 |
+
valid_indices = valid_indices[:args.batch_size]
|
342 |
+
|
343 |
+
train_set = Subset(dataset, indices=train_indices)
|
344 |
+
valid_set = Subset(dataset, indices=valid_indices)
|
345 |
+
track_set = Subset(dataset, indices=track_indices)
|
346 |
+
|
347 |
+
train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=16, shuffle=True)
|
348 |
+
valid_loader = DataLoader(valid_set, batch_size=args.batch_size, num_workers=16, shuffle=False)
|
349 |
+
track_loader = DataLoader(track_set, batch_size=args.batch_size, num_workers=16, shuffle=False)
|
350 |
+
|
351 |
+
with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:
|
352 |
+
|
353 |
+
# mlflow.pytorch.autolog(silent=True)
|
354 |
+
|
355 |
+
if k_iter == 0:
|
356 |
+
display_mlflow_run_info(child_run)
|
357 |
+
|
358 |
+
mlflow.pytorch.log_state_dict(state_dict, artifact_path=None)
|
359 |
+
|
360 |
+
hparams = {
|
361 |
+
'dataset': args.dataset,
|
362 |
+
'processing_mode': args.processing_mode,
|
363 |
+
'training_mode': training_mode,
|
364 |
+
}
|
365 |
+
if training_mode == 'adversarial':
|
366 |
+
hparams['adv_aux_weight'] = args.adv_aux_weight
|
367 |
+
hparams['adv_aux_loss'] = args.adv_aux_loss
|
368 |
+
|
369 |
+
mlflow.log_params(hparams)
|
370 |
+
|
371 |
+
with open('results/state_dict.txt', 'w') as f:
|
372 |
+
f.write('python ' + ' '.join(sys.argv) + '\n')
|
373 |
+
f.write('\n'.join([f'{k}={v}' for k, v in state_dict.items()]))
|
374 |
+
mlflow.log_artifact('results/state_dict.txt', artifact_path=None)
|
375 |
+
|
376 |
+
mlf_logger = pl.loggers.MLFlowLogger(experiment_name=args.experiment_name,
|
377 |
+
tracking_uri=args.tracking_uri,)
|
378 |
+
mlf_logger._run_id = child_run.info.run_id
|
379 |
+
|
380 |
+
reference_processor = processor_default if args.adv_training and args.adv_track_differences else None
|
381 |
+
|
382 |
+
callbacks = []
|
383 |
+
if args.track_processing:
|
384 |
+
callbacks += [TrackImagesCallback(track_loader,
|
385 |
+
reference_processor,
|
386 |
+
track_every_epoch=args.track_every_epoch,
|
387 |
+
track_processing=args.track_processing,
|
388 |
+
track_gradients=args.track_processing_gradients,
|
389 |
+
track_predictions=args.track_predictions,
|
390 |
+
save_tensors=args.track_save_tensors)]
|
391 |
+
|
392 |
+
# if True: #args.save_best:
|
393 |
+
# if dataset.task == 'classification':
|
394 |
+
#checkpoint_callback = ModelCheckpoint(pathmonitor="val_accuracy", mode='max')
|
395 |
+
# checkpoint_callback = ModelCheckpoint(dirpath=args.tracking_uri, save_top_k=1, verbose=True, monitor="val_accuracy", mode="max") #dirpath=args.tracking_uri,
|
396 |
+
# else:
|
397 |
+
# checkpoint_callback = ModelCheckpoint(monitor="val_iou_score")
|
398 |
+
#callbacks += [checkpoint_callback]
|
399 |
+
|
400 |
+
trainer = pl.Trainer(
|
401 |
+
gpus=1 if DEVICE == 'cuda' else 0,
|
402 |
+
min_epochs=args.epochs,
|
403 |
+
max_epochs=args.epochs,
|
404 |
+
logger=mlf_logger,
|
405 |
+
callbacks=callbacks,
|
406 |
+
check_val_every_n_epoch=args.check_val_every_n_epoch,
|
407 |
+
# checkpoint_callback=True,
|
408 |
+
)
|
409 |
+
|
410 |
+
if args.log_model:
|
411 |
+
mlflow.pytorch.autolog(log_every_n_epoch=10)
|
412 |
+
print(f'model_uri="{mlflow.get_artifact_uri()}/model"')
|
413 |
+
|
414 |
+
t = trainer.fit(
|
415 |
+
model,
|
416 |
+
train_dataloader=train_loader,
|
417 |
+
val_dataloaders=valid_loader,
|
418 |
+
)
|
419 |
+
|
420 |
+
globals().update(locals()) # for convenient access
|
421 |
+
|
422 |
+
return model
|
423 |
+
|
424 |
+
|
425 |
+
if __name__ == '__main__':
|
426 |
+
model = run_train(args)
|
utils/augmentation.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torchvision.transforms as T
|
6 |
+
|
7 |
+
|
8 |
+
class RandomRotate90(): # Note: not the same as T.RandomRotation(90)
|
9 |
+
def __call__(self, x):
|
10 |
+
x = x.rot90(random.randint(0, 3), dims=(-1, -2))
|
11 |
+
return x
|
12 |
+
|
13 |
+
def __repr__(self):
|
14 |
+
return self.__class__.__name__
|
15 |
+
|
16 |
+
|
17 |
+
class AddGaussianNoise():
|
18 |
+
def __init__(self, std=0.01):
|
19 |
+
self.std = std
|
20 |
+
|
21 |
+
def __call__(self, x):
|
22 |
+
# noise = torch.randn_like(x) * self.std
|
23 |
+
# out = x + noise
|
24 |
+
# debug(x)
|
25 |
+
# debug(noise)
|
26 |
+
# debug(out)
|
27 |
+
return x + torch.randn_like(x) * self.std
|
28 |
+
|
29 |
+
def __repr__(self):
|
30 |
+
return self.__class__.__name__ + f'(std={self.std})'
|
31 |
+
|
32 |
+
|
33 |
+
def set_global_seed(seed):
|
34 |
+
torch.random.manual_seed(seed)
|
35 |
+
np.random.seed(seed % (2**32 - 1))
|
36 |
+
random.seed(seed)
|
37 |
+
|
38 |
+
|
39 |
+
class ComposeState(T.Compose):
|
40 |
+
def __init__(self, transforms):
|
41 |
+
self.transforms = []
|
42 |
+
self.mask_transforms = []
|
43 |
+
|
44 |
+
for t in transforms:
|
45 |
+
apply_for_mask = True
|
46 |
+
if isinstance(t, tuple):
|
47 |
+
t, apply_for_mask = t
|
48 |
+
self.transforms.append(t)
|
49 |
+
if apply_for_mask:
|
50 |
+
self.mask_transforms.append(t)
|
51 |
+
|
52 |
+
self.seed = None
|
53 |
+
|
54 |
+
# @debug
|
55 |
+
def __call__(self, x, retain_state=False, mask_transform=False):
|
56 |
+
if self.seed is not None: # retain previous state
|
57 |
+
set_global_seed(self.seed)
|
58 |
+
if retain_state: # save state for next call
|
59 |
+
self.seed = self.seed or torch.seed()
|
60 |
+
set_global_seed(self.seed)
|
61 |
+
else:
|
62 |
+
self.seed = None # reset / ignore state
|
63 |
+
|
64 |
+
transforms = self.transforms if not mask_transform else self.mask_transforms
|
65 |
+
for t in transforms:
|
66 |
+
x = t(x)
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
augmentation_weak = ComposeState([
|
71 |
+
T.RandomHorizontalFlip(),
|
72 |
+
T.RandomVerticalFlip(),
|
73 |
+
RandomRotate90(),
|
74 |
+
])
|
75 |
+
|
76 |
+
|
77 |
+
augmentation_strong = ComposeState([
|
78 |
+
T.RandomHorizontalFlip(p=0.5),
|
79 |
+
T.RandomVerticalFlip(p=0.5),
|
80 |
+
T.RandomApply([T.RandomRotation(90)], p=0.5),
|
81 |
+
# (transform, apply_to_mask=True)
|
82 |
+
(T.RandomApply([AddGaussianNoise(std=0.0005)], p=0.5), False),
|
83 |
+
(T.RandomAdjustSharpness(0.5, p=0.5), False),
|
84 |
+
])
|
85 |
+
|
86 |
+
|
87 |
+
def get_augmentation(type):
|
88 |
+
if type == 'none':
|
89 |
+
return None
|
90 |
+
if type == 'weak':
|
91 |
+
return augmentation_weak
|
92 |
+
if type == 'strong':
|
93 |
+
return augmentation_strong
|
94 |
+
|
95 |
+
|
96 |
+
if __name__ == '__main__':
|
97 |
+
import os
|
98 |
+
if not os.path.exists('README.md'):
|
99 |
+
os.chdir('..')
|
100 |
+
|
101 |
+
# from utils.debug import debug
|
102 |
+
from dataset import get_dataset
|
103 |
+
import matplotlib.pyplot as plt
|
104 |
+
|
105 |
+
dataset = get_dataset('DS') # drone segmentation
|
106 |
+
img, mask = dataset[10]
|
107 |
+
mask = (mask + 0.2) / 1.2
|
108 |
+
|
109 |
+
plt.figure(figsize=(14, 8))
|
110 |
+
plt.subplot(121)
|
111 |
+
plt.imshow(img)
|
112 |
+
plt.subplot(122)
|
113 |
+
plt.imshow(mask)
|
114 |
+
plt.suptitle('no augmentation')
|
115 |
+
plt.show()
|
116 |
+
|
117 |
+
from utils.base import np2torch, torch2np
|
118 |
+
img, mask = np2torch(img), np2torch(mask)
|
119 |
+
|
120 |
+
# from utils.augmentation import get_augmentation
|
121 |
+
augmentation = get_augmentation('strong')
|
122 |
+
|
123 |
+
set_global_seed(1)
|
124 |
+
|
125 |
+
for i in range(1, 4):
|
126 |
+
plt.figure(figsize=(14, 8))
|
127 |
+
plt.subplot(121)
|
128 |
+
plt.imshow(torch2np(augmentation(img.unsqueeze(0), retain_state=True)).squeeze())
|
129 |
+
plt.subplot(122)
|
130 |
+
plt.imshow(torch2np(augmentation(mask.unsqueeze(0), mask_transform=True)).squeeze())
|
131 |
+
plt.suptitle(f'augmentation test {i}')
|
132 |
+
plt.show()
|
utils/base.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utilities for other scripts
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import shutil
|
7 |
+
|
8 |
+
import random
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import mlflow
|
12 |
+
from mlflow.tracking import MlflowClient
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
from IPython.display import display, Markdown
|
16 |
+
|
17 |
+
from b2sdk.v1 import *
|
18 |
+
|
19 |
+
import argparse
|
20 |
+
|
21 |
+
|
22 |
+
class SmartFormatter(argparse.HelpFormatter):
|
23 |
+
|
24 |
+
def _split_lines(self, text, width):
|
25 |
+
if text.startswith('R|'):
|
26 |
+
return text[2:].splitlines()
|
27 |
+
# this is the RawTextHelpFormatter._split_lines
|
28 |
+
return argparse.HelpFormatter._split_lines(self, text, width)
|
29 |
+
|
30 |
+
|
31 |
+
def str2bool(string):
|
32 |
+
return string == 'True'
|
33 |
+
|
34 |
+
|
35 |
+
def np2torch(nparray):
|
36 |
+
"""Convert numpy array to torch tensor
|
37 |
+
For array with more than 3 channels, it is better to use an input array in the format BxHxWxC
|
38 |
+
|
39 |
+
Args:
|
40 |
+
numpy array (ndarray) BxHxWxC
|
41 |
+
Returns:
|
42 |
+
torch tensor (tensor) BxCxHxW"""
|
43 |
+
|
44 |
+
tensor = torch.Tensor(nparray)
|
45 |
+
|
46 |
+
if tensor.ndim == 2:
|
47 |
+
return tensor
|
48 |
+
if tensor.ndim == 3:
|
49 |
+
height, width, channels = tensor.shape
|
50 |
+
if channels <= 3: # Single image with more channels (HxWxC)
|
51 |
+
return tensor.permute(2, 0, 1)
|
52 |
+
|
53 |
+
if tensor.ndim == 4: # More images with more channels (BxHxWxC)
|
54 |
+
return tensor.permute(0, 3, 1, 2)
|
55 |
+
|
56 |
+
return tensor
|
57 |
+
|
58 |
+
|
59 |
+
def torch2np(torchtensor):
|
60 |
+
"""Convert torch tensor to numpy array
|
61 |
+
For tensor with more than 3 channels or batch, it is better to use an input tensor in the format BxCxHxW
|
62 |
+
|
63 |
+
Args:
|
64 |
+
torch tensor (tensor) BxCxHxW
|
65 |
+
Returns:
|
66 |
+
numpy array (ndarray) BxHxWxC"""
|
67 |
+
|
68 |
+
ndarray = torchtensor.detach().cpu().numpy().astype(np.float32)
|
69 |
+
|
70 |
+
if ndarray.ndim == 3: # Single image with more channels (CxHxW)
|
71 |
+
channels, height, width = ndarray.shape
|
72 |
+
if channels <= 3:
|
73 |
+
return ndarray.transpose(1, 2, 0)
|
74 |
+
|
75 |
+
if ndarray.ndim == 4: # More images with more channels (BxCxHxW)
|
76 |
+
return ndarray.transpose(0, 2, 3, 1)
|
77 |
+
|
78 |
+
return ndarray
|
79 |
+
|
80 |
+
|
81 |
+
def set_random_seed(seed):
|
82 |
+
np.random.seed(seed) # cpu vars
|
83 |
+
torch.manual_seed(seed) # cpu vars
|
84 |
+
random.seed(seed) # Python
|
85 |
+
if torch.cuda.is_available():
|
86 |
+
torch.cuda.manual_seed(seed)
|
87 |
+
torch.cuda.manual_seed_all(seed) # gpu vars
|
88 |
+
torch.backends.cudnn.deterministic = True # needed
|
89 |
+
torch.backends.cudnn.benchmark = False
|
90 |
+
|
91 |
+
|
92 |
+
def normalize(img):
|
93 |
+
"""Normalize images
|
94 |
+
|
95 |
+
Args:
|
96 |
+
imgs (ndarray): image to normalize --> size: (Height,Width,Channels)
|
97 |
+
Returns:
|
98 |
+
normalized (ndarray): normalized image
|
99 |
+
mu (ndarray): mean
|
100 |
+
sigma (ndarray): standard deviation
|
101 |
+
"""
|
102 |
+
|
103 |
+
img = img.astype(float)
|
104 |
+
|
105 |
+
if len(img.shape) == 2:
|
106 |
+
img = img[:, :, np.newaxis]
|
107 |
+
|
108 |
+
height, width, channels = img.shape
|
109 |
+
|
110 |
+
mu, sigma = np.empty(channels), np.empty(channels)
|
111 |
+
|
112 |
+
for ch in range(channels):
|
113 |
+
temp_mu = img[:, :, ch].mean()
|
114 |
+
temp_sigma = img[:, :, ch].std()
|
115 |
+
|
116 |
+
img[:, :, ch] = (img[:, :, ch] - temp_mu) / (temp_sigma + 1e-4)
|
117 |
+
|
118 |
+
mu[ch] = temp_mu
|
119 |
+
sigma[ch] = temp_sigma
|
120 |
+
|
121 |
+
return img, mu, sigma
|
122 |
+
|
123 |
+
|
124 |
+
def b2_list_files(folder=''):
|
125 |
+
bucket = get_b2_bucket()
|
126 |
+
for file_info, _ in bucket.ls(folder, show_versions=False):
|
127 |
+
print(file_info.file_name)
|
128 |
+
|
129 |
+
|
130 |
+
def get_b2_bucket():
|
131 |
+
bucket_name = 'perturbed-minds'
|
132 |
+
application_key_id = '003d6b042de536a0000000008'
|
133 |
+
application_key = 'K003HMNxnoa91Dy9c0V8JVCKNUnwR9U'
|
134 |
+
info = InMemoryAccountInfo()
|
135 |
+
b2_api = B2Api(info)
|
136 |
+
b2_api.authorize_account('production', application_key_id, application_key)
|
137 |
+
bucket = b2_api.get_bucket_by_name(bucket_name)
|
138 |
+
return bucket
|
139 |
+
|
140 |
+
|
141 |
+
def b2_download_folder(b2_dir, local_dir, force_download=False, mirror_folder=True):
|
142 |
+
"""Downloads a folder from the b2 bucket and optionally cleans
|
143 |
+
up files that are no longer on the server
|
144 |
+
|
145 |
+
Args:
|
146 |
+
b2_dir (str): path to folder on the b2 server
|
147 |
+
local_dir (str): path to folder on the local machine
|
148 |
+
force_download (bool, optional): force the download, if set to `False`,
|
149 |
+
files with matching names on the local machine will be skipped
|
150 |
+
mirror_folder (bool, optional): if set to `True`, files that are found in
|
151 |
+
the local directory, but are not on the server will be deleted
|
152 |
+
"""
|
153 |
+
bucket = get_b2_bucket()
|
154 |
+
|
155 |
+
if not os.path.exists(local_dir):
|
156 |
+
os.makedirs(local_dir)
|
157 |
+
elif not force_download:
|
158 |
+
return
|
159 |
+
|
160 |
+
download_files = [file_info.file_name.split(b2_dir + '/')[-1]
|
161 |
+
for file_info, _ in bucket.ls(b2_dir, show_versions=False)]
|
162 |
+
|
163 |
+
for file_name in download_files:
|
164 |
+
if file_name.endswith('/.bzEmpty'): # subdirectory, download recursively
|
165 |
+
subdir = file_name.replace('/.bzEmpty', '')
|
166 |
+
if len(subdir) > 0:
|
167 |
+
b2_subdir = os.path.join(b2_dir, subdir)
|
168 |
+
local_subdir = os.path.join(local_dir, subdir)
|
169 |
+
if b2_subdir != b2_dir:
|
170 |
+
b2_download_folder(b2_subdir, local_subdir, force_download=force_download,
|
171 |
+
mirror_folder=mirror_folder)
|
172 |
+
else: # file
|
173 |
+
b2_file = os.path.join(b2_dir, file_name)
|
174 |
+
local_file = os.path.join(local_dir, file_name)
|
175 |
+
if not os.path.exists(local_file) or force_download:
|
176 |
+
print(f"downloading b2://{b2_file} -> {local_file}")
|
177 |
+
bucket.download_file_by_name(b2_file, DownloadDestLocalFile(local_file))
|
178 |
+
|
179 |
+
if mirror_folder: # remove all files that are not on the b2 server anymore
|
180 |
+
for i, file in enumerate(download_files):
|
181 |
+
if file.endswith('/.bzEmpty'): # subdirectory, download recursively
|
182 |
+
download_files[i] = file.replace('/.bzEmpty', '')
|
183 |
+
for file_name in os.listdir(local_dir):
|
184 |
+
if file_name not in download_files:
|
185 |
+
local_file = os.path.join(local_dir, file_name)
|
186 |
+
print(f"deleting {local_file}")
|
187 |
+
if os.path.isdir(local_file):
|
188 |
+
shutil.rmtree(local_file)
|
189 |
+
else:
|
190 |
+
os.remove(local_file)
|
191 |
+
|
192 |
+
|
193 |
+
def get_name(obj):
|
194 |
+
return obj.__name__ if hasattr(obj, '__name__') else type(obj).__name__
|
195 |
+
|
196 |
+
|
197 |
+
def get_mlflow_model_by_name(experiment_name, run_name,
|
198 |
+
tracking_uri="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
|
199 |
+
download_model=True):
|
200 |
+
|
201 |
+
# 0. mlflow basics
|
202 |
+
mlflow.set_tracking_uri(tracking_uri)
|
203 |
+
os.environ["AWS_ACCESS_KEY_ID"] = "#TODO: add your AWS access key if you want to write your results to our collaborative lab server"
|
204 |
+
os.environ["AWS_SECRET_ACCESS_KEY"] = "#TODO: add your AWS seceret access key if you want to write your results to our collaborative lab server"
|
205 |
+
|
206 |
+
# # 1. use get_experiment_by_name to get experiment objec
|
207 |
+
experiment = mlflow.get_experiment_by_name(experiment_name)
|
208 |
+
|
209 |
+
# # 2. use search_runs with experiment_id for string search query
|
210 |
+
if os.path.isfile('cache/runs_names.pkl'):
|
211 |
+
runs = pd.read_pickle('cache/runs_names.pkl')
|
212 |
+
if runs['tags.mlflow.runName'][runs['tags.mlflow.runName'] == run_name].empty:
|
213 |
+
# returns a pandas data frame where each row is a run (if several exist under that name)
|
214 |
+
runs = fetch_runs_list_mlflow(experiment)
|
215 |
+
else:
|
216 |
+
# returns a pandas data frame where each row is a run (if several exist under that name)
|
217 |
+
runs = fetch_runs_list_mlflow(experiment)
|
218 |
+
|
219 |
+
# 3. get the selected run between all runs inside the selected experiment
|
220 |
+
run = runs.loc[runs['tags.mlflow.runName'] == run_name]
|
221 |
+
|
222 |
+
# 4. check if there is only a run with that name
|
223 |
+
assert len(run) == 1, "More runs with this name"
|
224 |
+
index_run = run.index[0]
|
225 |
+
artifact_uri = run.loc[index_run, 'artifact_uri']
|
226 |
+
|
227 |
+
# 5. load state_dict of your run
|
228 |
+
state_dict = mlflow.pytorch.load_state_dict(artifact_uri)
|
229 |
+
|
230 |
+
# 6. load model of your run
|
231 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
232 |
+
# model = mlflow.pytorch.load_model(os.path.join(
|
233 |
+
# artifact_uri, "model"), map_location=torch.device(DEVICE))
|
234 |
+
model = fetch_from_mlflow(os.path.join(
|
235 |
+
artifact_uri, "model"), use_cache=True, download_model=download_model)
|
236 |
+
|
237 |
+
return state_dict, model
|
238 |
+
|
239 |
+
|
240 |
+
def data_loader_mean_and_std(data_loader, transform=None):
|
241 |
+
means = []
|
242 |
+
stds = []
|
243 |
+
for x, y in data_loader:
|
244 |
+
if transform is not None:
|
245 |
+
x = transform(x)
|
246 |
+
means.append(x.mean(dim=(0, 2, 3)).unsqueeze(0))
|
247 |
+
stds.append(x.std(dim=(0, 2, 3)).unsqueeze(0))
|
248 |
+
return torch.cat(means).mean(dim=0), torch.cat(stds).mean(dim=0)
|
249 |
+
|
250 |
+
|
251 |
+
def fetch_runs_list_mlflow(experiment):
|
252 |
+
runs = mlflow.search_runs(experiment.experiment_id)
|
253 |
+
runs.to_pickle('cache/runs_names.pkl') # where to save it, usually as a .pkl
|
254 |
+
return runs
|
255 |
+
|
256 |
+
|
257 |
+
def fetch_from_mlflow(uri, use_cache=True, download_model=True):
|
258 |
+
cache_loc = os.path.join('cache', uri.split('//')[1]) + '.pt'
|
259 |
+
if use_cache and os.path.exists(cache_loc):
|
260 |
+
print(f'loading cached model from {cache_loc} ...')
|
261 |
+
return torch.load(cache_loc)
|
262 |
+
else:
|
263 |
+
print(f'fetching model from {uri} ...')
|
264 |
+
model = mlflow.pytorch.load_model(uri)
|
265 |
+
os.makedirs(os.path.dirname(cache_loc), exist_ok=True)
|
266 |
+
if download_model:
|
267 |
+
torch.save(model, cache_loc, pickle_module=mlflow.pytorch.pickle_module)
|
268 |
+
return model
|
269 |
+
|
270 |
+
|
271 |
+
def display_mlflow_run_info(run):
|
272 |
+
uri = mlflow.get_tracking_uri()
|
273 |
+
experiment_id = run.info.experiment_id
|
274 |
+
experiment_name = mlflow.get_experiment(experiment_id).name
|
275 |
+
run_id = run.info.run_id
|
276 |
+
run_name = run.data.tags['mlflow.runName']
|
277 |
+
experiment_url = f'{uri}/#/experiments/{experiment_id}'
|
278 |
+
run_url = f'{experiment_url}/runs/{run_id}'
|
279 |
+
|
280 |
+
print(f'view results at {run_url}')
|
281 |
+
display(Markdown(
|
282 |
+
f"[<a href='{experiment_url}'>experiment {experiment_id} '{experiment_name}'</a>]"
|
283 |
+
f" > "
|
284 |
+
f"[<a href='{run_url}'>run '{run_name}' {run_id}</a>]"
|
285 |
+
))
|
286 |
+
print('')
|
287 |
+
|
288 |
+
|
289 |
+
def get_train_test_indices_drone(df, frac, seed=None):
|
290 |
+
""" Split indices of a DataFrame with binary and balanced labels into balanced subindices
|
291 |
+
|
292 |
+
Args:
|
293 |
+
df (pd.DataFrame): {0,1}-labeled data
|
294 |
+
frac (float): fraction of indicies in first subset
|
295 |
+
random_seed (int): random seed used as random state in np.random and as argument for random.seed()
|
296 |
+
Returns:
|
297 |
+
train_indices (torch.tensor): balanced subset of indices corresponding to rows in the DataFrame
|
298 |
+
test_indices (torch.tensor): balanced subset of indices corresponding to rows in the DataFrame
|
299 |
+
"""
|
300 |
+
|
301 |
+
split_idx = int(len(df) * frac / 2)
|
302 |
+
df_with = df[df['label'] == 1]
|
303 |
+
df_without = df[df['label'] == 0]
|
304 |
+
|
305 |
+
np.random.seed(seed)
|
306 |
+
df_with_train = df_with.sample(n=split_idx, random_state=seed)
|
307 |
+
df_with_test = df_with.drop(df_with_train.index)
|
308 |
+
|
309 |
+
df_without_train = df_without.sample(n=split_idx, random_state=seed)
|
310 |
+
df_without_test = df_without.drop(df_without_train.index)
|
311 |
+
|
312 |
+
train_indices = list(df_without_train.index) + list(df_with_train.index)
|
313 |
+
test_indices = list(df_without_test.index) + list(df_with_test.index)
|
314 |
+
|
315 |
+
""""
|
316 |
+
print('fraction of 1-label in train set: {}'.format(len(df_with_train)/(len(df_with_train) + len(df_without_train))))
|
317 |
+
print('fraction of 1-label in test set: {}'.format(len(df_with_test)/(len(df_with_test) + len(df_with_test))))
|
318 |
+
"""
|
319 |
+
|
320 |
+
return train_indices, test_indices
|
321 |
+
|
322 |
+
|
323 |
+
def smp_get_loss(loss):
|
324 |
+
if loss == "Dice":
|
325 |
+
return smp.losses.DiceLoss(mode='binary', from_logits=True)
|
326 |
+
if loss == "BCE":
|
327 |
+
return nn.BCELoss()
|
328 |
+
elif loss == "BCEWithLogits":
|
329 |
+
return smp.losses.BCEWithLogitsLoss()
|
330 |
+
elif loss == "DicyBCE":
|
331 |
+
from pytorch_toolbelt import losses as ptbl
|
332 |
+
return ptbl.JointLoss(ptbl.DiceLoss(mode='binary', from_logits=False),
|
333 |
+
nn.BCELoss(),
|
334 |
+
first_weight=args.dice_weight,
|
335 |
+
second_weight=args.bce_weight)
|
utils/dataset_utils.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from skimage.util.shape import view_as_windows
|
8 |
+
|
9 |
+
|
10 |
+
def load_image(path):
|
11 |
+
file_type = path.split('.')[-1].lower()
|
12 |
+
if file_type == 'dng':
|
13 |
+
img = rawpy.imread(path).raw_image_visible
|
14 |
+
elif file_type == 'tiff' or file_type == 'tif':
|
15 |
+
img = np.array(tiff.imread(path), dtype=np.float32)
|
16 |
+
else:
|
17 |
+
img = np.array(Image.open(path), dtype=np.float32)
|
18 |
+
return img
|
19 |
+
|
20 |
+
|
21 |
+
def list_images_in_dir(path):
|
22 |
+
image_list = [os.path.join(path, img_name)
|
23 |
+
for img_name in sorted(os.listdir(path))
|
24 |
+
if img_name.split('.')[-1].lower() in IMAGE_FILE_TYPES]
|
25 |
+
return image_list
|
26 |
+
|
27 |
+
|
28 |
+
def k_fold(dataset, n_splits: int, seed: int, train_size: float):
|
29 |
+
"""Split dataset in subsets for cross-validation
|
30 |
+
|
31 |
+
Args:
|
32 |
+
dataset (class): dataset to split
|
33 |
+
n_split (int): Number of re-shuffling & splitting iterations.
|
34 |
+
seed (int): seed for k_fold splitting
|
35 |
+
train_size (float): should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the train split.
|
36 |
+
Returns:
|
37 |
+
idxs (list): indeces for splitting the dataset. The list contain n_split pair of train/test indeces.
|
38 |
+
"""
|
39 |
+
if hasattr(dataset, 'labels'):
|
40 |
+
x = dataset.images
|
41 |
+
y = dataset.labels
|
42 |
+
elif hasattr(dataset, 'masks'):
|
43 |
+
x = dataset.images
|
44 |
+
y = dataset.masks
|
45 |
+
|
46 |
+
idxs = []
|
47 |
+
|
48 |
+
if dataset.task == 'classification':
|
49 |
+
sss = StratifiedShuffleSplit(n_splits=n_splits, train_size=train_size, random_state=seed)
|
50 |
+
|
51 |
+
for idxs_train, idxs_test in sss.split(x, y):
|
52 |
+
idxs.append((idxs_train.tolist(), idxs_test.tolist()))
|
53 |
+
|
54 |
+
elif dataset.task == 'segmentation':
|
55 |
+
for n in range(n_splits):
|
56 |
+
split_idx = int(len(dataset) * train_size)
|
57 |
+
indices = np.random.permutation(len(dataset))
|
58 |
+
idxs.append((indices[:split_idx].tolist(), indices[split_idx:].tolist()))
|
59 |
+
|
60 |
+
return idxs
|
61 |
+
|
62 |
+
|
63 |
+
def split_img(imgs, ROIs=(3, 3), step=(1, 1)):
|
64 |
+
"""Split the imgs in regions of size ROIs.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
imgs (ndarray): images which you want to split
|
68 |
+
ROIs (tuple): size of sub-regions splitted (ROIs=region of interests)
|
69 |
+
step (tuple): step path from one sub-region to the next one (in the x,y axis)
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
ndarray: splitted subimages.
|
73 |
+
The size is (x_num_subROIs*y_num_subROIs, **) where:
|
74 |
+
x_num_subROIs = ( imgs.shape[1]-int(ROIs[1]/2)*2 )/step[1]
|
75 |
+
y_num_subROIs = ( imgs.shape[0]-int(ROIs[0]/2)*2 )/step[0]
|
76 |
+
|
77 |
+
Example:
|
78 |
+
>>> from dataset_generator import split
|
79 |
+
>>> imgs_splitted = split(imgs, ROI_size = (5,5), step=(2,3))
|
80 |
+
"""
|
81 |
+
|
82 |
+
if len(ROIs) > 2:
|
83 |
+
return print("ROIs is a 2 element list")
|
84 |
+
|
85 |
+
if len(step) > 2:
|
86 |
+
return print("step is a 2 element list")
|
87 |
+
|
88 |
+
if type(imgs) != type(np.array(1)):
|
89 |
+
return print("imgs should be a ndarray")
|
90 |
+
|
91 |
+
if len(imgs.shape) == 2: # Single image with one channel (HxW)
|
92 |
+
splitted = view_as_windows(imgs, (ROIs[0], ROIs[1]), (step[0], step[1]))
|
93 |
+
return splitted.reshape((-1, ROIs[0], ROIs[1]))
|
94 |
+
|
95 |
+
if len(imgs.shape) == 3:
|
96 |
+
_, _, channels = imgs.shape
|
97 |
+
if channels <= 3: # Single image more channels (HxWxC)
|
98 |
+
splitted = view_as_windows(imgs, (ROIs[0], ROIs[1], channels), (step[0], step[1], channels))
|
99 |
+
return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
|
100 |
+
else: # More images with 1 channel
|
101 |
+
splitted = view_as_windows(imgs, (1, ROIs[0], ROIs[1]), (1, step[0], step[1]))
|
102 |
+
return splitted.reshape((-1, ROIs[0], ROIs[1]))
|
103 |
+
|
104 |
+
if len(imgs.shape) == 4: # More images with more channels(BxHxWxC)
|
105 |
+
_, _, _, channels = imgs.shape
|
106 |
+
splitted = view_as_windows(imgs, (1, ROIs[0], ROIs[1], channels), (1, step[0], step[1], channels))
|
107 |
+
return splitted.reshape((-1, ROIs[0], ROIs[1], channels))
|
108 |
+
|
109 |
+
|
110 |
+
def join_blocks(splitted, final_shape):
|
111 |
+
"""Join blocks to reobtain a splitted image
|
112 |
+
|
113 |
+
Attribute:
|
114 |
+
splitted (tensor) = image splitted in blocks, size = (N_blocks, Channels, Height, Width)
|
115 |
+
final_shape (tuple) = size of the final image reconstructed (Height, Width)
|
116 |
+
Return:
|
117 |
+
tensor: image restored from blocks. size = (Channels, Height, Width)
|
118 |
+
|
119 |
+
"""
|
120 |
+
n_blocks, channels, ROI_height, ROI_width = splitted.shape
|
121 |
+
|
122 |
+
rows = final_shape[0] // ROI_height
|
123 |
+
columns = final_shape[1] // ROI_width
|
124 |
+
|
125 |
+
final_img = torch.empty(rows, channels, ROI_height, ROI_width * columns)
|
126 |
+
for r in range(rows):
|
127 |
+
stackblocks = splitted[r * columns]
|
128 |
+
for c in range(1, columns):
|
129 |
+
stackblocks = torch.cat((stackblocks, splitted[r * columns + c]), axis=2)
|
130 |
+
final_img[r] = stackblocks
|
131 |
+
|
132 |
+
joined_img = final_img[0]
|
133 |
+
|
134 |
+
for i in np.arange(1, len(final_img)):
|
135 |
+
joined_img = torch.cat((joined_img, final_img[i]), axis=1)
|
136 |
+
|
137 |
+
return joined_img
|
138 |
+
|
139 |
+
|
140 |
+
def random_ROI(X, Y, ROIs=(512, 512)):
|
141 |
+
""" Return a random region for each input/target pair images of the dataset
|
142 |
+
Args:
|
143 |
+
Y (ndarray): target of your dataset --> size: (BxHxWxC)
|
144 |
+
X (ndarray): input of your dataset --> size: (BxHxWxC)
|
145 |
+
ROIs (tuple): size of random region (ROIs=region of interests)
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
For each pair images (input/target) of the dataset, return respectively random ROIs
|
149 |
+
Y_cut (ndarray): target of your dataset --> size: (Batch,Channels,ROIs[0],ROIs[1])
|
150 |
+
X_cut (ndarray): input of your dataset --> size: (Batch,Channels,ROIs[0],ROIs[1])
|
151 |
+
|
152 |
+
Example:
|
153 |
+
>>> from dataset_generator import random_ROI
|
154 |
+
>>> X,Y = random_ROI(X,Y, ROIs = (10,10))
|
155 |
+
"""
|
156 |
+
|
157 |
+
batch, channels, height, width = X.shape
|
158 |
+
|
159 |
+
X_cut = np.empty((batch, ROIs[0], ROIs[1], channels))
|
160 |
+
Y_cut = np.empty((batch, ROIs[0], ROIs[1], channels))
|
161 |
+
|
162 |
+
for i in np.arange(len(X)):
|
163 |
+
x_size = int(random.random() * (height - (ROIs[0] + 1)))
|
164 |
+
y_size = int(random.random() * (width - (ROIs[1] + 1)))
|
165 |
+
X_cut[i] = X[i, x_size:x_size + ROIs[0], y_size:y_size + ROIs[1], :]
|
166 |
+
Y_cut[i] = Y[i, x_size:x_size + ROIs[0], y_size:y_size + ROIs[1], :]
|
167 |
+
return X_cut, Y_cut
|
168 |
+
|
169 |
+
|
170 |
+
def one2many_random_ROI(X, Y, datasize=1000, ROIs=(512, 512)):
|
171 |
+
""" Return a dataset of N subimages obtained from random regions of the same image
|
172 |
+
Args:
|
173 |
+
Y (ndarray): target of your dataset --> size: (1,H,W,C)
|
174 |
+
X (ndarray): input of your dataset --> size: (1,H,W,C)
|
175 |
+
datasize = number of random ROIs to generate
|
176 |
+
ROIs (tuple): size of random region (ROIs=region of interests)
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
Y_cut (ndarray): target of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
|
180 |
+
X_cut (ndarray): input of your dataset --> size: (Datasize,ROIs[0],ROIs[1],Channels)
|
181 |
+
"""
|
182 |
+
|
183 |
+
batch, channels, height, width = X.shape
|
184 |
+
|
185 |
+
X_cut = np.empty((datasize, ROIs[0], ROIs[1], channels))
|
186 |
+
Y_cut = np.empty((datasize, ROIs[0], ROIs[1], channels))
|
187 |
+
|
188 |
+
for i in np.arange(datasize):
|
189 |
+
X_cut[i], Y_cut[i] = random_ROI(X, Y, ROIs)
|
190 |
+
return X_cut, Y_cut
|
utils/hendrycks_robustness.py
ADDED
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Code extracted from the paper:
|
3 |
+
|
4 |
+
@articlehendrycks2019robustness,
|
5 |
+
title=Benchmarking Neural Network Robustness to Common Corruptions and Perturbations,
|
6 |
+
author=Dan Hendrycks and Thomas Dietterich,
|
7 |
+
journal=Proceedings of the International Conference on Learning Representations,
|
8 |
+
year=2019
|
9 |
+
}
|
10 |
+
|
11 |
+
The code is modified to fit with our model
|
12 |
+
'''
|
13 |
+
|
14 |
+
import os
|
15 |
+
from PIL import Image
|
16 |
+
import os.path
|
17 |
+
import time
|
18 |
+
import torch
|
19 |
+
import torchvision.datasets as dset
|
20 |
+
import torchvision.transforms as trn
|
21 |
+
import torch.utils.data as data
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
from PIL import Image
|
25 |
+
|
26 |
+
|
27 |
+
# /////////////// Distortion Helpers ///////////////
|
28 |
+
|
29 |
+
import skimage as sk
|
30 |
+
from skimage.filters import gaussian
|
31 |
+
from io import BytesIO
|
32 |
+
from wand.image import Image as WandImage
|
33 |
+
from wand.api import library as wandlibrary
|
34 |
+
import wand.color as WandColor
|
35 |
+
import ctypes
|
36 |
+
from PIL import Image as PILImage
|
37 |
+
import cv2
|
38 |
+
from scipy.ndimage import zoom as scizoom
|
39 |
+
from scipy.ndimage.interpolation import map_coordinates
|
40 |
+
import warnings
|
41 |
+
|
42 |
+
warnings.simplefilter("ignore", UserWarning)
|
43 |
+
|
44 |
+
|
45 |
+
def disk(radius, alias_blur=0.1, dtype=np.float32):
|
46 |
+
if radius <= 8:
|
47 |
+
L = np.arange(-8, 8 + 1)
|
48 |
+
ksize = (3, 3)
|
49 |
+
else:
|
50 |
+
L = np.arange(-radius, radius + 1)
|
51 |
+
ksize = (5, 5)
|
52 |
+
X, Y = np.meshgrid(L, L)
|
53 |
+
aliased_disk = np.array((X ** 2 + Y ** 2) <= radius ** 2, dtype=dtype)
|
54 |
+
aliased_disk /= np.sum(aliased_disk)
|
55 |
+
|
56 |
+
# supersample disk to antialias
|
57 |
+
return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur)
|
58 |
+
|
59 |
+
|
60 |
+
# Tell Python about the C method
|
61 |
+
wandlibrary.MagickMotionBlurImage.argtypes = (ctypes.c_void_p, # wand
|
62 |
+
ctypes.c_double, # radius
|
63 |
+
ctypes.c_double, # sigma
|
64 |
+
ctypes.c_double) # angle
|
65 |
+
|
66 |
+
|
67 |
+
# Extend wand.image.Image class to include method signature
|
68 |
+
class MotionImage(WandImage):
|
69 |
+
def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
|
70 |
+
wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)
|
71 |
+
|
72 |
+
|
73 |
+
# modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py
|
74 |
+
def plasma_fractal(mapsize=32, wibbledecay=3):
|
75 |
+
"""
|
76 |
+
Generate a heightmap using diamond-square algorithm.
|
77 |
+
Return square 2d array, side length 'mapsize', of floats in range 0-255.
|
78 |
+
'mapsize' must be a power of two.
|
79 |
+
"""
|
80 |
+
assert (mapsize & (mapsize - 1) == 0)
|
81 |
+
maparray = np.empty((mapsize, mapsize), dtype=np.float_)
|
82 |
+
maparray[0, 0] = 0
|
83 |
+
stepsize = mapsize
|
84 |
+
wibble = 100
|
85 |
+
|
86 |
+
def wibbledmean(array):
|
87 |
+
return array / 4 + wibble * np.random.uniform(-wibble, wibble, array.shape)
|
88 |
+
|
89 |
+
def fillsquares():
|
90 |
+
"""For each square of points stepsize apart,
|
91 |
+
calculate middle value as mean of points + wibble"""
|
92 |
+
cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
|
93 |
+
squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
|
94 |
+
squareaccum += np.roll(squareaccum, shift=-1, axis=1)
|
95 |
+
maparray[stepsize // 2:mapsize:stepsize,
|
96 |
+
stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum)
|
97 |
+
|
98 |
+
def filldiamonds():
|
99 |
+
"""For each diamond of points stepsize apart,
|
100 |
+
calculate middle value as mean of points + wibble"""
|
101 |
+
mapsize = maparray.shape[0]
|
102 |
+
drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize]
|
103 |
+
ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
|
104 |
+
ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
|
105 |
+
lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
|
106 |
+
ltsum = ldrsum + lulsum
|
107 |
+
maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum)
|
108 |
+
tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
|
109 |
+
tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
|
110 |
+
ttsum = tdrsum + tulsum
|
111 |
+
maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum)
|
112 |
+
|
113 |
+
while stepsize >= 2:
|
114 |
+
fillsquares()
|
115 |
+
filldiamonds()
|
116 |
+
stepsize //= 2
|
117 |
+
wibble /= wibbledecay
|
118 |
+
|
119 |
+
maparray -= maparray.min()
|
120 |
+
return maparray / maparray.max()
|
121 |
+
|
122 |
+
|
123 |
+
def clipped_zoom(img, zoom_factor):
|
124 |
+
h = img.shape[0]
|
125 |
+
# ceil crop height(= crop width)
|
126 |
+
ch = int(np.ceil(h / zoom_factor))
|
127 |
+
|
128 |
+
top = (h - ch) // 2
|
129 |
+
img = scizoom(img[top:top + ch, top:top + ch], (zoom_factor, zoom_factor, 1), order=1)
|
130 |
+
# trim off any extra pixels
|
131 |
+
trim_top = (img.shape[0] - h) // 2
|
132 |
+
|
133 |
+
return img[trim_top:trim_top + h, trim_top:trim_top + h]
|
134 |
+
|
135 |
+
|
136 |
+
# /////////////// End Distortion Helpers ///////////////
|
137 |
+
|
138 |
+
|
139 |
+
# /////////////// Distortions ///////////////
|
140 |
+
|
141 |
+
class Distortions:
|
142 |
+
def __init__(self, severity=1, transform='identity'):
|
143 |
+
self.severity = severity
|
144 |
+
self.transform = transform
|
145 |
+
|
146 |
+
def __call__(self, img):
|
147 |
+
assert torch.is_tensor(img), 'Input data need to be a torch.tensor'
|
148 |
+
assert len(img.shape) == 3, 'Input image should be RGB'
|
149 |
+
img = self.torch2np(img)
|
150 |
+
t = getattr(self, self.transform)
|
151 |
+
img = t(img, self.severity)
|
152 |
+
return self.np2torch(img).float()
|
153 |
+
|
154 |
+
def np2torch(self,x):
|
155 |
+
return torch.tensor(x).permute(2,0,1)
|
156 |
+
|
157 |
+
def torch2np(self,x):
|
158 |
+
return np.array(x.permute(1,2,0))
|
159 |
+
|
160 |
+
def identity(self,x, severity=1):
|
161 |
+
return x
|
162 |
+
|
163 |
+
def gaussian_noise(self, x, severity=1):
|
164 |
+
c = [0.04, 0.06, .08, .09, .10][severity - 1]
|
165 |
+
return np.clip(x + np.random.normal(size=x.shape, scale=c), 0, 1)
|
166 |
+
|
167 |
+
|
168 |
+
def shot_noise(self, x, severity=1):
|
169 |
+
c = [500, 250, 100, 75, 50][severity - 1]
|
170 |
+
return np.clip(np.random.poisson(x * c) / c, 0, 1)
|
171 |
+
|
172 |
+
|
173 |
+
def impulse_noise(self, x, severity=1):
|
174 |
+
c = [.01, .02, .03, .05, .07][severity - 1]
|
175 |
+
|
176 |
+
x = sk.util.random_noise(x, mode='s&p', amount=c)
|
177 |
+
return np.clip(x, 0, 1)
|
178 |
+
|
179 |
+
|
180 |
+
def speckle_noise(self, x, severity=1):
|
181 |
+
c = [.06, .1, .12, .16, .2][severity - 1]
|
182 |
+
return np.clip(x + x * np.random.normal(size=x.shape, scale=c), 0, 1)
|
183 |
+
|
184 |
+
|
185 |
+
def gaussian_blur(self, x, severity=1):
|
186 |
+
c = [.4, .6, 0.7, .8, 1][severity - 1]
|
187 |
+
|
188 |
+
x = gaussian(x, sigma=c, multichannel=True)
|
189 |
+
return np.clip(x, 0, 1)
|
190 |
+
|
191 |
+
|
192 |
+
def glass_blur(self, x, severity=1):
|
193 |
+
# sigma, max_delta, iterations
|
194 |
+
c = [(0.05,1,1), (0.25,1,1), (0.4,1,1), (0.25,1,2), (0.4,1,2)][severity - 1]
|
195 |
+
|
196 |
+
x = gaussian(x, sigma=c[0], multichannel=True)
|
197 |
+
|
198 |
+
# locally shuffle pixels
|
199 |
+
for i in range(c[2]):
|
200 |
+
for h in range(32 - c[1], c[1], -1):
|
201 |
+
for w in range(32 - c[1], c[1], -1):
|
202 |
+
dx, dy = np.random.randint(-c[1], c[1], size=(2,))
|
203 |
+
h_prime, w_prime = h + dy, w + dx
|
204 |
+
# swap
|
205 |
+
x[h, w], x[h_prime, w_prime] = x[h_prime, w_prime], x[h, w]
|
206 |
+
|
207 |
+
return np.clip(gaussian(x, sigma=c[0], multichannel=True), 0, 1)
|
208 |
+
|
209 |
+
|
210 |
+
def defocus_blur(self, x, severity=1):
|
211 |
+
c = [(0.3, 0.4), (0.4, 0.5), (0.5, 0.6), (1, 0.2), (1.5, 0.1)][severity - 1]
|
212 |
+
kernel = disk(radius=c[0], alias_blur=c[1])
|
213 |
+
|
214 |
+
channels = []
|
215 |
+
for d in range(3):
|
216 |
+
channels.append(cv2.filter2D(x[:, :, d], -1, kernel))
|
217 |
+
channels = np.array(channels).transpose((1, 2, 0)) # 3x32x32 -> 32x32x3
|
218 |
+
|
219 |
+
return np.clip(channels, 0, 1)
|
220 |
+
|
221 |
+
|
222 |
+
def motion_blur(self, x, severity=1):
|
223 |
+
c = [(6,1), (6,1.5), (6,2), (8,2), (9,2.5)][severity - 1]
|
224 |
+
|
225 |
+
output = BytesIO()
|
226 |
+
x.save(output, format='PNG')
|
227 |
+
x = MotionImage(blob=output.getvalue())
|
228 |
+
|
229 |
+
x.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45))
|
230 |
+
|
231 |
+
x = cv2.imdecode(np.fromstring(x.make_blob(), np.uint8),
|
232 |
+
cv2.IMREAD_UNCHANGED)
|
233 |
+
|
234 |
+
if x.shape != (32, 32):
|
235 |
+
return np.clip(x[..., [2, 1, 0]], 0, 1) # BGR to RGB
|
236 |
+
else: # greyscale to RGB
|
237 |
+
return np.clip(np.array([x, x, x]).transpose((1, 2, 0)), 0, 1)
|
238 |
+
|
239 |
+
|
240 |
+
def zoom_blur(self, x, severity=1):
|
241 |
+
c = [np.arange(1, 1.06, 0.01), np.arange(1, 1.11, 0.01), np.arange(1, 1.16, 0.01),
|
242 |
+
np.arange(1, 1.21, 0.01), np.arange(1, 1.26, 0.01)][severity - 1]
|
243 |
+
out = np.zeros_like(x)
|
244 |
+
for zoom_factor in c:
|
245 |
+
out += clipped_zoom(x, zoom_factor)
|
246 |
+
|
247 |
+
x = (x + out) / (len(c) + 1)
|
248 |
+
return np.clip(x, 0, 1)
|
249 |
+
|
250 |
+
|
251 |
+
def fog(self, x, severity=1):
|
252 |
+
c = [(.2,3), (.5,3), (0.75,2.5), (1,2), (1.5,1.75)][severity - 1]
|
253 |
+
max_val = x.max()
|
254 |
+
x += c[0] * plasma_fractal(wibbledecay=c[1])[:32, :32][..., np.newaxis]
|
255 |
+
return np.clip(x * max_val / (max_val + c[0]), 0, 1)
|
256 |
+
|
257 |
+
|
258 |
+
def frost(self, x, severity=1):
|
259 |
+
c = [(1, 0.2), (1, 0.3), (0.9, 0.4), (0.85, 0.4), (0.75, 0.45)][severity - 1]
|
260 |
+
idx = np.random.randint(5)
|
261 |
+
filename = ['./frost1.png', './frost2.png', './frost3.png', './frost4.jpg', './frost5.jpg', './frost6.jpg'][idx]
|
262 |
+
frost = cv2.imread(filename)
|
263 |
+
frost = cv2.resize(frost, (0, 0), fx=0.2, fy=0.2)
|
264 |
+
# randomly crop and convert to rgb
|
265 |
+
x_start, y_start = np.random.randint(0, frost.shape[0] - 32), np.random.randint(0, frost.shape[1] - 32)
|
266 |
+
frost = frost[x_start:x_start + 32, y_start:y_start + 32][..., [2, 1, 0]]
|
267 |
+
|
268 |
+
return np.clip(c[0] * np.array(x) + c[1] * frost, 0, 1)
|
269 |
+
|
270 |
+
|
271 |
+
def snow(self, x, severity=1):
|
272 |
+
c = [(0.1,0.2,1,0.6,8,3,0.95),
|
273 |
+
(0.1,0.2,1,0.5,10,4,0.9),
|
274 |
+
(0.15,0.3,1.75,0.55,10,4,0.9),
|
275 |
+
(0.25,0.3,2.25,0.6,12,6,0.85),
|
276 |
+
(0.3,0.3,1.25,0.65,14,12,0.8)][severity - 1]
|
277 |
+
|
278 |
+
snow_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1]) # [:2] for monochrome
|
279 |
+
|
280 |
+
snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2])
|
281 |
+
snow_layer[snow_layer < c[3]] = 0
|
282 |
+
|
283 |
+
snow_layer = PILImage.fromarray((np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L')
|
284 |
+
output = BytesIO()
|
285 |
+
snow_layer.save(output, format='PNG')
|
286 |
+
snow_layer = MotionImage(blob=output.getvalue())
|
287 |
+
|
288 |
+
snow_layer.motion_blur(radius=c[4], sigma=c[5], angle=np.random.uniform(-135, -45))
|
289 |
+
|
290 |
+
snow_layer = cv2.imdecode(np.fromstring(snow_layer.make_blob(), np.uint8),
|
291 |
+
cv2.IMREAD_UNCHANGED) / (2**16-1)
|
292 |
+
snow_layer = snow_layer[..., np.newaxis]
|
293 |
+
|
294 |
+
x = c[6] * x + (1 - c[6]) * np.maximum(x, cv2.cvtColor(x, cv2.COLOR_RGB2GRAY).reshape(32, 32, 1) * 1.5 + 0.5)
|
295 |
+
return np.clip(x + snow_layer + np.rot90(snow_layer, k=2), 0, 1)
|
296 |
+
|
297 |
+
|
298 |
+
def spatter(self, x, severity=1):
|
299 |
+
c = [(0.62,0.1,0.7,0.7,0.5,0),
|
300 |
+
(0.65,0.1,0.8,0.7,0.5,0),
|
301 |
+
(0.65,0.3,1,0.69,0.5,0),
|
302 |
+
(0.65,0.1,0.7,0.69,0.6,1),
|
303 |
+
(0.65,0.1,0.5,0.68,0.6,1)][severity - 1]
|
304 |
+
|
305 |
+
liquid_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1])
|
306 |
+
|
307 |
+
liquid_layer = gaussian(liquid_layer, sigma=c[2])
|
308 |
+
liquid_layer[liquid_layer < c[3]] = 0
|
309 |
+
if c[5] == 0:
|
310 |
+
liquid_layer = (liquid_layer * (2**16-1)).astype(np.uint8)
|
311 |
+
dist = (2**16-1) - cv2.Canny(liquid_layer, 50, 150)
|
312 |
+
dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
|
313 |
+
_, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
|
314 |
+
dist = cv2.blur(dist, (3, 3)).astype(np.uint8)
|
315 |
+
dist = cv2.equalizeHist(dist)
|
316 |
+
# ker = np.array([[-1,-2,-3],[-2,0,0],[-3,0,1]], dtype=np.float32)
|
317 |
+
# ker -= np.mean(ker)
|
318 |
+
ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
|
319 |
+
dist = cv2.filter2D(dist, cv2.CV_8U, ker)
|
320 |
+
dist = cv2.blur(dist, (3, 3)).astype(np.float32)
|
321 |
+
|
322 |
+
m = cv2.cvtColor(liquid_layer * dist, cv2.COLOR_GRAY2BGRA)
|
323 |
+
m /= np.max(m, axis=(0, 1))
|
324 |
+
m *= c[4]
|
325 |
+
|
326 |
+
# water is pale turqouise
|
327 |
+
color = np.concatenate((175 / 255. * np.ones_like(m[..., :1]),
|
328 |
+
238 / 255. * np.ones_like(m[..., :1]),
|
329 |
+
238 / 255. * np.ones_like(m[..., :1])), axis=2)
|
330 |
+
|
331 |
+
color = cv2.cvtColor(color, cv2.COLOR_BGR2BGRA)
|
332 |
+
x = cv2.cvtColor(x, cv2.COLOR_BGR2BGRA)
|
333 |
+
|
334 |
+
return cv2.cvtColor(np.clip(x + m * color, 0, 1), cv2.COLOR_BGRA2BGR) * (2**16-1)
|
335 |
+
else:
|
336 |
+
m = np.where(liquid_layer > c[3], 1, 0)
|
337 |
+
m = gaussian(m.astype(np.float32), sigma=c[4])
|
338 |
+
m[m < 0.8] = 0
|
339 |
+
# m = np.abs(m) ** (1/c[4])
|
340 |
+
|
341 |
+
# mud brown
|
342 |
+
color = np.concatenate((63 / 255. * np.ones_like(x[..., :1]),
|
343 |
+
42 / 255. * np.ones_like(x[..., :1]),
|
344 |
+
20 / 255. * np.ones_like(x[..., :1])), axis=2)
|
345 |
+
|
346 |
+
color *= m[..., np.newaxis]
|
347 |
+
x *= (1 - m[..., np.newaxis])
|
348 |
+
|
349 |
+
return np.clip(x + color, 0, 1)
|
350 |
+
|
351 |
+
|
352 |
+
def contrast(self, x, severity=1):
|
353 |
+
c = [.75, .5, .4, .3, 0.15][severity - 1]
|
354 |
+
means = np.mean(x, axis=(0, 1), keepdims=True)
|
355 |
+
return np.clip((x - means) * c + means, 0, 1)
|
356 |
+
|
357 |
+
|
358 |
+
def brightness(self, x, severity=1):
|
359 |
+
c = [.05, .1, .15, .2, .3][severity - 1]
|
360 |
+
|
361 |
+
x = sk.color.rgb2hsv(x)
|
362 |
+
x[:, :, 2] = np.clip(x[:, :, 2] + c, 0, 1)
|
363 |
+
x = sk.color.hsv2rgb(x)
|
364 |
+
|
365 |
+
return np.clip(x, 0, 1)
|
366 |
+
|
367 |
+
|
368 |
+
def saturate(self, x, severity=1):
|
369 |
+
c = [(0.3, 0), (0.1, 0), (1.5, 0), (2, 0.1), (2.5, 0.2)][severity - 1]
|
370 |
+
|
371 |
+
x = sk.color.rgb2hsv(x)
|
372 |
+
x[:, :, 1] = np.clip(x[:, :, 1] * c[0] + c[1], 0, 1)
|
373 |
+
x = sk.color.hsv2rgb(x)
|
374 |
+
|
375 |
+
return np.clip(x, 0, 1)
|
376 |
+
|
377 |
+
|
378 |
+
def jpeg_compression(self, x, severity=1):
|
379 |
+
c = [80, 65, 58, 50, 40][severity - 1]
|
380 |
+
|
381 |
+
output = BytesIO()
|
382 |
+
x.save(output, 'JPEG', quality=c)
|
383 |
+
x = PILImage.open(output)
|
384 |
+
|
385 |
+
return x
|
386 |
+
|
387 |
+
|
388 |
+
def pixelate(self, x, severity=1):
|
389 |
+
c = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1]
|
390 |
+
|
391 |
+
x = x.resize((int(32 * c), int(32 * c)), PILImage.BOX)
|
392 |
+
x = x.resize((32, 32), PILImage.BOX)
|
393 |
+
|
394 |
+
return x
|
395 |
+
|
396 |
+
|
397 |
+
# mod of https://gist.github.com/erniejunior/601cdf56d2b424757de5
|
398 |
+
def elastic_transform(self, image, severity=1):
|
399 |
+
IMSIZE = 32
|
400 |
+
c = [(IMSIZE*0, IMSIZE*0, IMSIZE*0.08),
|
401 |
+
(IMSIZE*0.05, IMSIZE*0.2, IMSIZE*0.07),
|
402 |
+
(IMSIZE*0.08, IMSIZE*0.06, IMSIZE*0.06),
|
403 |
+
(IMSIZE*0.1, IMSIZE*0.04, IMSIZE*0.05),
|
404 |
+
(IMSIZE*0.1, IMSIZE*0.03, IMSIZE*0.03)][severity - 1]
|
405 |
+
|
406 |
+
shape = image.shape
|
407 |
+
shape_size = shape[:2]
|
408 |
+
|
409 |
+
# random affine
|
410 |
+
center_square = np.float32(shape_size) // 2
|
411 |
+
square_size = min(shape_size) // 3
|
412 |
+
pts1 = np.float32([center_square + square_size,
|
413 |
+
[center_square[0] + square_size, center_square[1] - square_size],
|
414 |
+
center_square - square_size])
|
415 |
+
pts2 = pts1 + np.random.uniform(-c[2], c[2], size=pts1.shape).astype(np.float32)
|
416 |
+
M = cv2.getAffineTransform(pts1, pts2)
|
417 |
+
image = cv2.warpAffine(image, M, shape_size[::-1], borderMode=cv2.BORDER_REFLECT_101)
|
418 |
+
|
419 |
+
dx = (gaussian(np.random.uniform(-1, 1, size=shape[:2]),
|
420 |
+
c[1], mode='reflect', truncate=3) * c[0]).astype(np.float32)
|
421 |
+
dy = (gaussian(np.random.uniform(-1, 1, size=shape[:2]),
|
422 |
+
c[1], mode='reflect', truncate=3) * c[0]).astype(np.float32)
|
423 |
+
dx, dy = dx[..., np.newaxis], dy[..., np.newaxis]
|
424 |
+
|
425 |
+
x, y, z = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2]))
|
426 |
+
indices = np.reshape(y + dy, (-1, 1)), np.reshape(x + dx, (-1, 1)), np.reshape(z, (-1, 1))
|
427 |
+
return np.clip(map_coordinates(image, indices, order=1, mode='reflect').reshape(shape), 0, 1)
|
428 |
+
|
429 |
+
if __name__=='__main__':
|
430 |
+
import os
|
431 |
+
|
432 |
+
import numpy as np
|
433 |
+
import matplotlib.pyplot as plt
|
434 |
+
import tifffile as tiff
|
435 |
+
import torch
|
436 |
+
|
437 |
+
os.system('cd ..')
|
438 |
+
|
439 |
+
img = tiff.imread('/home/marco/perturbed-minds/perturbed-minds/data/microscopy/images/rgb_scale100/Ma190c_lame1_zone1_composite_Mcropped_1.tiff')
|
440 |
+
img = np.array(img)/(2**16-1)
|
441 |
+
img = torch.tensor(img).permute(2,0,1)
|
442 |
+
|
443 |
+
def identity(x, sev):
|
444 |
+
return x
|
445 |
+
|
446 |
+
if not os.path.exists('results/Cimages'):
|
447 |
+
os.makedirs('results/Cimages')
|
448 |
+
|
449 |
+
transformations = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'speckle_noise',
|
450 |
+
'gaussian_blur', 'zoom_blur', 'contrast', 'brightness', 'saturate', 'elastic_transform']
|
451 |
+
|
452 |
+
# glass_blur, defocus_blur, motion_blur, fog, frost, snow, spatter, jpeg_compression, pixelate,
|
453 |
+
|
454 |
+
plt.figure()
|
455 |
+
plt.imshow(img.permute(1,2,0))
|
456 |
+
plt.title('identity')
|
457 |
+
plt.show()
|
458 |
+
plt.savefig(f'results/Cimages/1_identity.png')
|
459 |
+
|
460 |
+
|
461 |
+
for i,t in enumerate(transformations):
|
462 |
+
|
463 |
+
fig = plt.figure(figsize=(25,5))
|
464 |
+
columns = 5
|
465 |
+
rows = 1
|
466 |
+
|
467 |
+
for sev in range(1,6):
|
468 |
+
dist = Distortions(severity=sev, transform=t)
|
469 |
+
fig.add_subplot(rows, columns, sev)
|
470 |
+
plt.imshow(dist(img).permute(1,2,0))
|
471 |
+
plt.title(f'{t} {sev}')
|
472 |
+
plt.xticks([], [])
|
473 |
+
plt.yticks([], [])
|
474 |
+
plt.show()
|
475 |
+
plt.savefig(f'results/Cimages/{i+2}_{t}.png')
|
utils/ssim.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""https://github.com/Po-Hsun-Su/pytorch-ssim"""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch.autograd import Variable
|
6 |
+
import numpy as np
|
7 |
+
from math import exp
|
8 |
+
|
9 |
+
def gaussian(window_size, sigma):
|
10 |
+
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
|
11 |
+
return gauss/gauss.sum()
|
12 |
+
|
13 |
+
def create_window(window_size, channel):
|
14 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
15 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
16 |
+
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
17 |
+
return window
|
18 |
+
|
19 |
+
def _ssim(img1, img2, window, window_size, channel, size_average = True):
|
20 |
+
mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
|
21 |
+
mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
|
22 |
+
|
23 |
+
mu1_sq = mu1.pow(2)
|
24 |
+
mu2_sq = mu2.pow(2)
|
25 |
+
mu1_mu2 = mu1*mu2
|
26 |
+
|
27 |
+
sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
|
28 |
+
sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
|
29 |
+
sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
|
30 |
+
|
31 |
+
C1 = 0.01**2
|
32 |
+
C2 = 0.03**2
|
33 |
+
|
34 |
+
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
|
35 |
+
|
36 |
+
if size_average:
|
37 |
+
return ssim_map.mean()
|
38 |
+
else:
|
39 |
+
return ssim_map.mean(1).mean(1).mean(1)
|
40 |
+
|
41 |
+
class SSIM(torch.nn.Module):
|
42 |
+
def __init__(self, window_size = 11, size_average = True):
|
43 |
+
super(SSIM, self).__init__()
|
44 |
+
self.window_size = window_size
|
45 |
+
self.size_average = size_average
|
46 |
+
self.channel = 1
|
47 |
+
self.window = create_window(window_size, self.channel)
|
48 |
+
|
49 |
+
def forward(self, img1, img2):
|
50 |
+
(_, channel, _, _) = img1.size()
|
51 |
+
|
52 |
+
if channel == self.channel and self.window.data.type() == img1.data.type():
|
53 |
+
window = self.window
|
54 |
+
else:
|
55 |
+
window = create_window(self.window_size, channel)
|
56 |
+
|
57 |
+
if img1.is_cuda:
|
58 |
+
window = window.cuda(img1.get_device())
|
59 |
+
window = window.type_as(img1)
|
60 |
+
|
61 |
+
self.window = window
|
62 |
+
self.channel = channel
|
63 |
+
|
64 |
+
|
65 |
+
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
|
66 |
+
|
67 |
+
def ssim(img1, img2, window_size = 11, size_average = True):
|
68 |
+
(_, channel, _, _) = img1.size()
|
69 |
+
window = create_window(window_size, channel)
|
70 |
+
|
71 |
+
if img1.is_cuda:
|
72 |
+
window = window.cuda(img1.get_device())
|
73 |
+
window = window.type_as(img1)
|
74 |
+
|
75 |
+
return _ssim(img1, img2, window, window_size, channel, size_average)
|