from pathlib import Path

import torch
import torchvision
from torchvision import transforms
from torchvision.transforms import TrivialAugmentWide

from configs.dataset_params import normalize_params
from dataset_classes.cub200 import CUB200Class
from dataset_classes.stanfordcars import StanfordCarsClass
from dataset_classes.travelingbirds import TravelingBirds


def get_data(dataset, crop=True, img_size=448):
    batchsize = 16
    if dataset == "CUB2011":
        train_transform = get_augmentation(0.1, img_size, True, not crop, True, True, normalize_params["CUB2011"])
        test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["CUB2011"])
        train_dataset = CUB200Class(True, train_transform, crop)
        test_dataset = CUB200Class(False, test_transform, crop)
    elif dataset == "TravelingBirds":
        train_transform = get_augmentation(0.1, img_size, True, not crop, True, True, normalize_params["TravelingBirds"])
        test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["TravelingBirds"])
        train_dataset = TravelingBirds(True, train_transform, crop)
        test_dataset = TravelingBirds(False, test_transform, crop)
    elif dataset == "StanfordCars":
        train_transform = get_augmentation(0.1, img_size, True, True, True, True, normalize_params["StanfordCars"])
        test_transform = get_augmentation(0.1, img_size, False, True, True, True, normalize_params["StanfordCars"])
        train_dataset = StanfordCarsClass(True, train_transform)
        test_dataset = StanfordCarsClass(False, test_transform)
    elif dataset == "FGVCAircraft":
        raise NotImplementedError
    elif dataset == "ImageNet":
        if img_size != 224:
            raise NotImplementedError("ImageNet is set up to only work with 224x224 images")
        # Standard training data augmentation for ImageNet-scale datasets: random crop,
        # random flip, color jitter, and lighting transform (see https://git.io/fhBOc).
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.1,
                contrast=0.1,
                saturation=0.1
            ),
            transforms.ToTensor(),
            Lighting(0.05, IMAGENET_PCA['eigval'],
                     IMAGENET_PCA['eigvec'])
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        imgnet_root = Path.home() / "tmp" / "Datasets" / "imagenet"
        train_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='train', transform=train_transform)
        test_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='val', transform=test_transform)
        batchsize = 64
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=8)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=8)
    return train_loader, test_loader
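
# Illustrative usage sketch (not part of the original module; assumes the CUB200Class data
# and the normalize_params["CUB2011"] config are available, and that the dataset classes
# yield (image, label) pairs):
#
#   train_loader, test_loader = get_data("CUB2011", crop=True, img_size=448)
#   images, labels = next(iter(train_loader))
#   # with crop=True the images are resized to 448x448, so images has shape (16, 3, 448, 448)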


def get_augmentation(jitter, size, training, random_center_crop, trivialAug, hflip, normalize):
    augmentation = []
    if random_center_crop:
        # Resize the shorter side only; a random (train) or center (test) crop below makes it square.
        augmentation.append(transforms.Resize(size))
    else:
        # Resize directly to a fixed square; no cropping step follows.
        augmentation.append(transforms.Resize((size, size)))
    if training:
        if random_center_crop:
            augmentation.append(transforms.RandomCrop(size, padding=4))
    else:
        if random_center_crop:
            augmentation.append(transforms.CenterCrop(size))
    if training:
        if hflip:
            augmentation.append(transforms.RandomHorizontalFlip())
        if jitter:
            augmentation.append(transforms.ColorJitter(jitter, jitter, jitter))
        if trivialAug:
            augmentation.append(TrivialAugmentWide())
    augmentation.append(transforms.ToTensor())
    augmentation.append(transforms.Normalize(**normalize))
    return transforms.Compose(augmentation)
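
# For reference, the CUB2011 training call above, get_augmentation(0.1, 448, True, False, True, True, ...),
# builds the following pipeline (illustrative expansion of the branches; the normalize dict
# comes from configs.dataset_params.normalize_params):
#
#   transforms.Compose([
#       transforms.Resize((448, 448)),
#       transforms.RandomHorizontalFlip(),
#       transforms.ColorJitter(0.1, 0.1, 0.1),
#       TrivialAugmentWide(),
#       transforms.ToTensor(),
#       transforms.Normalize(**normalize_params["CUB2011"]),
#   ])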


class Lighting(object):
    """
    Lighting noise (see https://git.io/fhBOc)
    """

    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        if self.alphastd == 0:
            return img

        # Sample per-channel weights and project them back through the PCA basis
        # to obtain a single RGB offset that is added to every pixel.
        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        rgb = self.eigvec.type_as(img).clone() \
            .mul(alpha.view(1, 3).expand(3, 3)) \
            .mul(self.eigval.view(1, 3).expand(3, 3)) \
            .sum(1).squeeze()

        return img.add(rgb.view(3, 1, 1).expand_as(img))


# PCA eigenvalues and eigenvectors of the ImageNet RGB channel statistics, used by Lighting.
IMAGENET_PCA = {
    'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
    'eigvec': torch.Tensor([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ])
}
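

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): apply the Lighting
    # transform to a random image tensor and check that the shape is unchanged.
    dummy = torch.rand(3, 224, 224)
    lighting = Lighting(0.05, IMAGENET_PCA['eigval'], IMAGENET_PCA['eigvec'])
    out = lighting(dummy)
    print(out.shape)  # expected: torch.Size([3, 224, 224])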