from torchvision import transforms def get_transforms(train=True): if train: return transforms.Compose([ transforms.Resize((150, 150)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) else: return transforms.Compose([ transforms.Resize((150, 150)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ])