from pathlib import Path

import torch
import torchvision
from torchvision.transforms import transforms, TrivialAugmentWide

from configs.dataset_params import normalize_params
from dataset_classes.cub200 import CUB200Class
from dataset_classes.stanfordcars import StanfordCarsClass
from dataset_classes.travelingbirds import TravelingBirds


def get_data(dataset, crop=True, img_size=448):
    """Build train/test DataLoaders for the given dataset name."""
    batchsize = 16
    if dataset == "CUB2011":
        train_transform = get_augmentation(0.1, img_size, True, not crop, True, True, normalize_params["CUB2011"])
        test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["CUB2011"])
        train_dataset = CUB200Class(True, train_transform, crop)
        test_dataset = CUB200Class(False, test_transform, crop)
    elif dataset == "TravelingBirds":
        train_transform = get_augmentation(0.1, img_size, True, not crop, True, True, normalize_params["TravelingBirds"])
        test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["TravelingBirds"])
        train_dataset = TravelingBirds(True, train_transform, crop)
        test_dataset = TravelingBirds(False, test_transform, crop)
    elif dataset == "StanfordCars":
        train_transform = get_augmentation(0.1, img_size, True, True, True, True, normalize_params["StanfordCars"])
        test_transform = get_augmentation(0.1, img_size, False, True, True, True, normalize_params["StanfordCars"])
        train_dataset = StanfordCarsClass(True, train_transform)
        test_dataset = StanfordCarsClass(False, test_transform)
    elif dataset == "FGVCAircraft":
        raise NotImplementedError
    elif dataset == "ImageNet":
        # Defaults from the robustness package.
        if img_size != 224:
            raise NotImplementedError("ImageNet is set up to only work with 224x224 images")
        # Standard training data augmentation for ImageNet-scale datasets:
        # random crop, random flip, color jitter, and lighting transform
        # (see https://git.io/fhBOc).
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.1,
                contrast=0.1,
                saturation=0.1
            ),
            transforms.ToTensor(),
            Lighting(0.05, IMAGENET_PCA['eigval'], IMAGENET_PCA['eigvec'])
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        imgnet_root = Path.home() / "tmp" / "Datasets" / "imagenet"
        train_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='train', transform=train_transform)
        test_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='val', transform=test_transform)
        batchsize = 64
    else:
        raise ValueError(f"Unknown dataset: {dataset}")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=8)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=8)
    return train_loader, test_loader


def get_augmentation(jitter, size, training, random_center_crop, trivialAug, hflip, normalize):
    """Compose the transform pipeline used for the fine-grained datasets."""
    augmentation = []
    if random_center_crop:
        # Resize the short side only; the crop below brings the image to size x size.
        augmentation.append(transforms.Resize(size))
    else:
        augmentation.append(transforms.Resize((size, size)))
    if training:
        if random_center_crop:
            augmentation.append(transforms.RandomCrop(size, padding=4))
    else:
        if random_center_crop:
            augmentation.append(transforms.CenterCrop(size))
    if training:
        if hflip:
            augmentation.append(transforms.RandomHorizontalFlip())
        if jitter:
            augmentation.append(transforms.ColorJitter(jitter, jitter, jitter))
        if trivialAug:
            augmentation.append(TrivialAugmentWide())
    augmentation.append(transforms.ToTensor())
    augmentation.append(transforms.Normalize(**normalize))
    return transforms.Compose(augmentation)
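
# For orientation, a hedged sketch of what get_augmentation builds for the
# CUB2011 training call above (crop=True, so random_center_crop is False);
# assuming the defaults shown, the result should be equivalent to:
#
#   transforms.Compose([
#       transforms.Resize((448, 448)),
#       transforms.RandomHorizontalFlip(),
#       transforms.ColorJitter(0.1, 0.1, 0.1),
#       TrivialAugmentWide(),
#       transforms.ToTensor(),
#       transforms.Normalize(**normalize_params["CUB2011"]),
#   ])
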
class Lighting(object):
    """
    Lighting noise (see https://git.io/fhBOc)
    """

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        if self.alphastd == 0:
            return img
        # Sample a per-channel scale and perturb the image along the PCA
        # directions of the ImageNet RGB distribution.
        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        rgb = self.eigvec.type_as(img).clone() \
            .mul(alpha.view(1, 3).expand(3, 3)) \
            .mul(self.eigval.view(1, 3).expand(3, 3)) \
            .sum(1).squeeze()
        return img.add(rgb.view(3, 1, 1).expand_as(img))


# PCA eigenvalues and eigenvectors of the ImageNet RGB channels, as used by
# the robustness package for the lighting transform above.
IMAGENET_PCA = {
    'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
    'eigvec': torch.Tensor([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ])
}
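

# Usage sketch (hypothetical, not part of the original module): assumes the
# CUB-200-2011 files are already on disk wherever CUB200Class expects them.
if __name__ == "__main__":
    train_loader, test_loader = get_data("CUB2011", crop=True, img_size=448)
    images, labels = next(iter(train_loader))
    # With the default batchsize of 16, expect images of shape [16, 3, 448, 448].
    print(images.shape, labels.shape)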