File size: 5,015 Bytes
9b896f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from pathlib import Path

import torch
import torchvision
from torchvision.transforms import transforms, TrivialAugmentWide

from configs.dataset_params import normalize_params
from dataset_classes.cub200 import CUB200Class
from dataset_classes.stanfordcars import StanfordCarsClass
from dataset_classes.travelingbirds import TravelingBirds


def get_data(dataset, crop=True, img_size=448):
    """Build the train and test DataLoaders for the requested dataset.

    Args:
        dataset: One of "CUB2011", "TravelingBirds", "StanfordCars",
            "FGVCAircraft" (not implemented) or "ImageNet".
        crop: Passed through to the CUB/TravelingBirds dataset classes
            (presumably bounding-box cropping — confirm against those
            classes); when True, the resize/crop augmentation is disabled
            via ``not crop``.
        img_size: Target image side length; ImageNet only supports 224.

    Returns:
        Tuple ``(train_loader, test_loader)`` of ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
        NotImplementedError: For "FGVCAircraft", or "ImageNet" with
            ``img_size != 224``.
    """
    batchsize = 16
    if dataset == "CUB2011":
        train_transform = get_augmentation(0.1, img_size, True, not crop, True, True, normalize_params["CUB2011"])
        test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["CUB2011"])
        train_dataset = CUB200Class(True, train_transform, crop)
        test_dataset = CUB200Class(False, test_transform, crop)
    elif dataset == "TravelingBirds":
        train_transform = get_augmentation(0.1, img_size, True, not crop, True, True, normalize_params["TravelingBirds"])
        test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["TravelingBirds"])
        train_dataset = TravelingBirds(True, train_transform, crop)
        test_dataset = TravelingBirds(False, test_transform, crop)
    elif dataset == "StanfordCars":
        # StanfordCars has no bounding-box crop option, so the augmentation
        # always uses the random/center-crop pipeline.
        train_transform = get_augmentation(0.1, img_size, True, True, True, True, normalize_params["StanfordCars"])
        test_transform = get_augmentation(0.1, img_size, False, True, True, True, normalize_params["StanfordCars"])
        train_dataset = StanfordCarsClass(True, train_transform)
        test_dataset = StanfordCarsClass(False, test_transform)
    elif dataset == "FGVCAircraft":
        raise NotImplementedError
    elif dataset == "ImageNet":
        # Defaults from the robustness package
        if img_size != 224:
            raise NotImplementedError("ImageNet is setup to only work with 224x224 images")
        # Standard training data augmentation for ImageNet-scale datasets:
        # Random crop, random flip, color jitter, and the AlexNet-style PCA
        # lighting transform (see https://git.io/fhBOc).
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.1,
                contrast=0.1,
                saturation=0.1
            ),
            transforms.ToTensor(),
            Lighting(0.05, IMAGENET_PCA['eigval'],
                     IMAGENET_PCA['eigvec'])
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        imgnet_root = Path.home() / "tmp" / "Datasets" / "imagenet"
        train_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='train', transform=train_transform)
        test_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='val', transform=test_transform)
        batchsize = 64
    else:
        # Fail loudly instead of crashing later with an UnboundLocalError
        # on train_dataset when an unknown dataset name is passed.
        raise ValueError(f"Unknown dataset: {dataset}")

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=8)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=8)
    return train_loader, test_loader

def get_augmentation(jitter, size, training, random_center_crop, trivialAug, hflip, normalize):
    """Compose a torchvision transform pipeline.

    Args:
        jitter: Strength for ColorJitter (brightness/contrast/saturation);
            falsy (0/None) disables it. Training only.
        size: Target image side length.
        training: If True, use random augmentations; otherwise deterministic.
        random_center_crop: If True, resize the short side to ``size`` and
            crop (random crop with padding when training, center crop
            otherwise); if False, hard-resize to ``(size, size)``.
        trivialAug: If True, apply TrivialAugmentWide (training only).
        hflip: If True, apply RandomHorizontalFlip (training only).
        normalize: Keyword arguments for ``transforms.Normalize``.

    Returns:
        A ``transforms.Compose`` of the selected operations.
    """
    ops = []
    if random_center_crop:
        # Keep aspect ratio, then crop down to a square of side `size`.
        ops.append(transforms.Resize(size))
        if training:
            ops.append(transforms.RandomCrop(size, padding=4))
        else:
            ops.append(transforms.CenterCrop(size))
    else:
        # Direct (possibly distorting) resize to the exact square shape.
        ops.append(transforms.Resize((size, size)))
    if training:
        if hflip:
            ops.append(transforms.RandomHorizontalFlip())
        if jitter:
            ops.append(transforms.ColorJitter(jitter, jitter, jitter))
        if trivialAug:
            ops.append(TrivialAugmentWide())
    ops.append(transforms.ToTensor())
    ops.append(transforms.Normalize(**normalize))
    return transforms.Compose(ops)

class Lighting(object):
    """AlexNet-style PCA lighting noise (see https://git.io/fhBOc).

    Shifts every pixel of a CHW image tensor by a random linear
    combination of the color PCA components.
    """

    def __init__(self, alphastd, eigval, eigvec):
        # alphastd: std-dev of the random per-component coefficients.
        # eigval: (3,) tensor of PCA eigenvalues.
        # eigvec: (3, 3) tensor of PCA eigenvectors.
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        # A zero std-dev disables the transform entirely.
        if self.alphastd == 0:
            return img

        # One random coefficient per principal component, in img's dtype.
        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        weights = alpha.view(1, 3).expand(3, 3) * self.eigval.view(1, 3).expand(3, 3)
        # Per-channel RGB shift: eigvec @ (alpha * eigval).
        rgb = (self.eigvec.type_as(img).clone() * weights).sum(1).squeeze()
        # Broadcast the shift across the spatial dimensions.
        return img.add(rgb.view(3, 1, 1).expand_as(img))
# Color-PCA eigenvalues/eigenvectors used by the Lighting transform in the
# ImageNet branch of get_data (AlexNet-style lighting noise, https://git.io/fhBOc).
IMAGENET_PCA = {
    'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
    'eigvec': torch.Tensor([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ])
}