from pathlib import Path

import torch
import torchvision
from torchvision.transforms import transforms, TrivialAugmentWide

from configs.dataset_params import normalize_params
from dataset_classes.cub200 import CUB200Class
from dataset_classes.stanfordcars import StanfordCarsClass
from dataset_classes.travelingbirds import TravelingBirds
def get_data(dataset, crop=True, img_size=448, num_workers=8):
    """Build train/test DataLoaders for one of the supported datasets.

    Args:
        dataset: one of "CUB2011", "TravelingBirds", "StanfordCars", "ImageNet".
        crop: if True, bird datasets use their pre-cropped images and the
            transform does a plain square resize (random_center_crop = not crop).
        img_size: output image edge length (ImageNet only supports 224).
        num_workers: DataLoader worker count (previously hard-coded to 8).

    Returns:
        (train_loader, test_loader) tuple of torch DataLoaders.

    Raises:
        NotImplementedError: for "FGVCAircraft", or "ImageNet" with img_size != 224.
        ValueError: for any unrecognized dataset name.
    """
    batchsize = 16
    if dataset == "CUB2011":
        train_transform = get_augmentation(0.1, img_size, True, not crop, True, True, normalize_params["CUB2011"])
        test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["CUB2011"])
        train_dataset = CUB200Class(True, train_transform, crop)
        test_dataset = CUB200Class(False, test_transform, crop)
    elif dataset == "TravelingBirds":
        train_transform = get_augmentation(0.1, img_size, True, not crop, True, True, normalize_params["TravelingBirds"])
        test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params["TravelingBirds"])
        train_dataset = TravelingBirds(True, train_transform, crop)
        test_dataset = TravelingBirds(False, test_transform, crop)
    elif dataset == "StanfordCars":
        # Cars are never pre-cropped: random_center_crop is always True here.
        train_transform = get_augmentation(0.1, img_size, True, True, True, True, normalize_params["StanfordCars"])
        test_transform = get_augmentation(0.1, img_size, False, True, True, True, normalize_params["StanfordCars"])
        train_dataset = StanfordCarsClass(True, train_transform)
        test_dataset = StanfordCarsClass(False, test_transform)
    elif dataset == "FGVCAircraft":
        raise NotImplementedError
    elif dataset == "ImageNet":
        # Defaults from the robustness package.
        if img_size != 224:
            raise NotImplementedError("ImageNet is setup to only work with 224x224 images")
        # Standard training data augmentation for ImageNet-scale datasets: Random crop,
        # Random flip, Color Jitter, and Lighting Transform (see https://git.io/fhBOc)
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.1,
                contrast=0.1,
                saturation=0.1
            ),
            transforms.ToTensor(),
            Lighting(0.05, IMAGENET_PCA['eigval'],
                     IMAGENET_PCA['eigvec'])
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        imgnet_root = Path.home() / "tmp" / "Datasets" / "imagenet"
        train_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='train', transform=train_transform)
        test_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='val', transform=test_transform)
        batchsize = 64
    else:
        # Previously an unknown name fell through to a NameError on train_dataset.
        raise ValueError(f"Unknown dataset: {dataset}")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader
def get_augmentation(jitter, size, training, random_center_crop, trivialAug, hflip, normalize):
    """Assemble a torchvision transform pipeline.

    Args:
        jitter: ColorJitter strength for brightness/contrast/saturation
            (falsy disables jitter); only applied when training.
        size: target image edge length.
        training: if True, add the stochastic augmentations.
        random_center_crop: if True, resize the shorter side to `size` and
            crop (RandomCrop when training, CenterCrop otherwise); if False,
            force a plain (size, size) resize with no crop.
        trivialAug: add TrivialAugmentWide (training only).
        hflip: add RandomHorizontalFlip (training only).
        normalize: kwargs dict for transforms.Normalize (mean/std).

    Returns:
        A transforms.Compose of the selected operations.
    """
    ops = [
        # Aspect-preserving resize when a crop follows, square resize otherwise.
        transforms.Resize(size) if random_center_crop else transforms.Resize((size, size))
    ]
    if random_center_crop:
        ops.append(
            transforms.RandomCrop(size, padding=4) if training else transforms.CenterCrop(size)
        )
    if training:
        if hflip:
            ops.append(transforms.RandomHorizontalFlip())
        if jitter:
            ops.append(transforms.ColorJitter(jitter, jitter, jitter))
        if trivialAug:
            ops.append(TrivialAugmentWide())
    ops.append(transforms.ToTensor())
    ops.append(transforms.Normalize(**normalize))
    return transforms.Compose(ops)
class Lighting(object):
    """PCA lighting noise, AlexNet-style (see https://git.io/fhBOc).

    Adds a random linear combination of the RGB principal components
    to every pixel of a CHW tensor image.
    """

    def __init__(self, alphastd, eigval, eigvec):
        # alphastd: std-dev of the random per-component coefficients;
        # eigval / eigvec: RGB PCA eigenvalues and eigenvectors.
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        # Zero noise scale means the transform is a no-op.
        if self.alphastd == 0:
            return img
        # One Gaussian coefficient per principal component.
        alpha = img.new().resize_(3).normal_(0, self.alphastd)
        # Per-channel offset: sum_i alpha_i * eigval_i * eigvec[:, i].
        weighted = self.eigvec.type_as(img).clone()
        weighted = weighted.mul(alpha.view(1, 3).expand(3, 3))
        weighted = weighted.mul(self.eigval.view(1, 3).expand(3, 3))
        rgb = weighted.sum(1).squeeze()
        # Broadcast the 3-vector over the HxW plane of each channel.
        return img.add(rgb.view(3, 1, 1).expand_as(img))
IMAGENET_PCA = { | |
'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), | |
'eigvec': torch.Tensor([ | |
[-0.5675, 0.7192, 0.4009], | |
[-0.5808, -0.0045, -0.8140], | |
[-0.5836, -0.6948, 0.4203], | |
]) | |
} | |