from torchvision import transforms
def get_transforms(train=True, *, size=(150, 150),
                   mean=(0.485, 0.456, 0.406),
                   std=(0.229, 0.224, 0.225)):
    """Build the torchvision preprocessing pipeline for training or evaluation.

    Both pipelines resize the image to ``size``, convert it with
    ``transforms.ToTensor()`` and normalize it with ``mean``/``std``. The
    training pipeline also applies ``transforms.RandomHorizontalFlip()``
    between the resize and the tensor conversion, as data augmentation.

    Args:
        train: If True, include the random horizontal flip. Otherwise the
            pipeline contains only the deterministic steps.
        size: Target size passed to ``transforms.Resize``. Defaults to
            ``(150, 150)``.
        mean: Per-channel means passed to ``transforms.Normalize``. The
            defaults match the ImageNet statistics documented for
            torchvision's pretrained models.
        std: Per-channel standard deviations passed to
            ``transforms.Normalize``.

    Returns:
        A ``transforms.Compose`` wrapping the selected steps in order.
    """
    steps = [transforms.Resize(size)]
    if train:
        # Random augmentation is applied only at train time, so evaluation
        # preprocessing stays deterministic.
        steps.append(transforms.RandomHorizontalFlip())
    steps += [
        transforms.ToTensor(),
        # list(...) keeps the stored mean/std the same type as before (list),
        # even though the defaults are immutable tuples.
        transforms.Normalize(mean=list(mean), std=list(std)),
    ]
    return transforms.Compose(steps)