# Dataset and dataloader construction utilities (CUB2011, TravelingBirds,
# StanfordCars, ImageNet) with the shared augmentation pipeline.
from pathlib import Path
import torch
import torchvision
from torchvision.transforms import transforms, TrivialAugmentWide
from configs.dataset_params import normalize_params
from dataset_classes.cub200 import CUB200Class
from dataset_classes.stanfordcars import StanfordCarsClass
from dataset_classes.travelingbirds import TravelingBirds
def get_data(dataset, crop=True, img_size=448):
    """Build train/test DataLoaders for the requested dataset.

    Args:
        dataset: One of "CUB2011", "TravelingBirds", "StanfordCars",
            "ImageNet". ("FGVCAircraft" is declared but not implemented.)
        crop: For the bird datasets, passed through to the dataset class;
            when False the transform uses random/center cropping instead of
            a fixed square resize.
        img_size: Target image edge length. ImageNet only supports 224.

    Returns:
        Tuple ``(train_loader, test_loader)`` of ``torch.utils.data.DataLoader``.

    Raises:
        NotImplementedError: For "FGVCAircraft", or ImageNet with
            ``img_size != 224``.
        ValueError: For an unrecognized dataset name (previously this fell
            through and crashed with an UnboundLocalError).
    """
    batchsize = 16
    if dataset in ("CUB2011", "TravelingBirds"):
        # Both bird datasets share the same augmentation recipe; only the
        # dataset class and the normalization statistics differ.
        dataset_cls = CUB200Class if dataset == "CUB2011" else TravelingBirds
        train_transform = get_augmentation(0.1, img_size, True, not crop, True, True, normalize_params[dataset])
        test_transform = get_augmentation(0.1, img_size, False, not crop, True, True, normalize_params[dataset])
        train_dataset = dataset_cls(True, train_transform, crop)
        test_dataset = dataset_cls(False, test_transform, crop)
    elif dataset == "StanfordCars":
        train_transform = get_augmentation(0.1, img_size, True, True, True, True, normalize_params["StanfordCars"])
        test_transform = get_augmentation(0.1, img_size, False, True, True, True, normalize_params["StanfordCars"])
        train_dataset = StanfordCarsClass(True, train_transform)
        test_dataset = StanfordCarsClass(False, test_transform)
    elif dataset == "FGVCAircraft":
        raise NotImplementedError
    elif dataset == "ImageNet":
        # Defaults from the robustness package
        if img_size != 224:
            raise NotImplementedError("ImageNet is setup to only work with 224x224 images")
        # Standard training data augmentation for ImageNet-scale datasets:
        # Random crop, random flip, color jitter, and lighting transform
        # (see https://git.io/fhBOc).
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.1,
                contrast=0.1,
                saturation=0.1
            ),
            transforms.ToTensor(),
            Lighting(0.05, IMAGENET_PCA['eigval'],
                     IMAGENET_PCA['eigvec'])
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
        imgnet_root = Path.home() / "tmp" / "Datasets" / "imagenet"
        train_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='train', transform=train_transform)
        test_dataset = torchvision.datasets.ImageNet(root=imgnet_root, split='val', transform=test_transform)
        batchsize = 64
    else:
        raise ValueError(f"Unknown dataset: {dataset!r}")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batchsize, shuffle=True, num_workers=8)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batchsize, shuffle=False, num_workers=8)
    return train_loader, test_loader
def get_augmentation(jitter, size, training, random_center_crop, trivialAug, hflip, normalize):
    """Compose the torchvision transform pipeline used by all datasets.

    Args:
        jitter: ColorJitter strength (applied to brightness/contrast/
            saturation); falsy value disables the jitter step.
        size: Target image edge length.
        training: Whether to include the stochastic training augmentations.
        random_center_crop: If True, resize the short edge then crop to a
            square (random crop while training, center crop otherwise);
            if False, resize directly to (size, size) with no crop.
        trivialAug: Include TrivialAugmentWide during training.
        hflip: Include random horizontal flips during training.
        normalize: Kwargs for transforms.Normalize (mean/std).

    Returns:
        A transforms.Compose pipeline ending in ToTensor + Normalize.
    """
    steps = []
    if random_center_crop:
        steps.append(transforms.Resize(size))
        # Crop down to the target square: stochastic while training,
        # deterministic center crop for evaluation.
        steps.append(
            transforms.RandomCrop(size, padding=4)
            if training
            else transforms.CenterCrop(size)
        )
    else:
        steps.append(transforms.Resize((size, size)))
    if training:
        if hflip:
            steps.append(transforms.RandomHorizontalFlip())
        if jitter:
            steps.append(transforms.ColorJitter(jitter, jitter, jitter))
        if trivialAug:
            steps.append(TrivialAugmentWide())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(**normalize))
    return transforms.Compose(steps)
class Lighting(object):
    """AlexNet-style PCA lighting noise (see https://git.io/fhBOc).

    Adds a random linear combination of the RGB PCA eigenvectors, weighted
    by the eigenvalues and a fresh normal draw per call, uniformly across
    all pixels of a CHW image tensor.
    """

    def __init__(self, alphastd, eigval, eigvec):
        # alphastd: std-dev of the per-channel alpha draw; 0 disables noise.
        self.alphastd = alphastd
        # eigval / eigvec: PCA statistics of the training set's RGB values.
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        if self.alphastd == 0:
            # Noise disabled: pass the image through untouched.
            return img
        # One normal sample per color channel, same dtype/device as img.
        alpha = img.new_empty(3).normal_(0, self.alphastd)
        weighted = self.eigvec.type_as(img).clone()
        weighted = weighted * alpha.view(1, 3).expand(3, 3)
        weighted = weighted * self.eigval.view(1, 3).expand(3, 3)
        rgb = weighted.sum(1).squeeze()
        # Broadcast the per-channel offset over the full image.
        return img.add(rgb.view(3, 1, 1).expand_as(img))
# PCA statistics of ImageNet RGB pixel values, consumed by the Lighting
# transform above (values from the original AlexNet lighting augmentation,
# see https://git.io/fhBOc).
IMAGENET_PCA = {
    'eigval': torch.tensor([0.2175, 0.0188, 0.0045]),
    'eigvec': torch.tensor([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ]),
}