| from torchvision import transforms | |
def get_transforms(
    train=True,
    *,
    image_size=(150, 150),
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Build the torchvision preprocessing pipeline for training or evaluation.

    Both pipelines resize to ``image_size``, convert to a tensor, and
    normalize with ``mean``/``std``. The training pipeline additionally
    applies a random horizontal flip (after resizing, before tensor
    conversion) as data augmentation.

    Args:
        train: If true, include the random horizontal flip augmentation.
        image_size: Target size passed to ``transforms.Resize``. Defaults to
            ``(150, 150)``, matching the original hard-coded size.
        mean: Per-channel means passed to ``transforms.Normalize``. The
            defaults match the widely used ImageNet channel statistics.
        std: Per-channel standard deviations passed to
            ``transforms.Normalize``. The defaults match the widely used
            ImageNet channel statistics.

    Returns:
        A ``transforms.Compose`` applying the steps above in order.
    """
    # Shared first step; the train and eval pipelines must stay identical
    # apart from the augmentation, so it is built only once.
    steps = [transforms.Resize(image_size)]
    if train:
        # Augmentation only for training; evaluation must be deterministic.
        steps.append(transforms.RandomHorizontalFlip())
    steps.extend([
        transforms.ToTensor(),
        # list(...) keeps the argument type the original passed, while the
        # defaults stay immutable tuples.
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])
    return transforms.Compose(steps)