"""Evaluation transforms and helpers for (un)doing per-channel normalization.

Defaults use timm's ImageNet mean/std constants.
"""
from typing import Sequence, Tuple, Union

import torch
from torchvision import transforms
from torchvision.transforms import Compose
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def make_test_transforms(
    image_size: Union[int, Tuple[int, int]],
    *,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> Compose:
    """Build the deterministic evaluation pipeline.

    The pipeline is resize -> center-crop -> ``ToTensor`` -> ``Normalize``.

    Args:
        image_size: Passed to both ``Resize`` and ``CenterCrop``. With an int,
            ``Resize`` matches the shorter edge to ``image_size`` and the crop
            then cuts a square. With an ``(h, w)`` pair, ``Resize`` already
            produces that exact size, so the crop is a no-op.
        mean: Per-channel mean for ``Normalize`` (ImageNet by default).
        std: Per-channel std for ``Normalize`` (ImageNet by default).

    Returns:
        A ``Compose`` that turns an input accepted by ``ToTensor`` (PIL image or
        ndarray) into a normalized float tensor.
    """
    test_transforms: Compose = transforms.Compose([
        transforms.Resize(size=image_size, antialias=True),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return test_transforms


def inverse_normalize(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    """Return a ``Normalize`` transform that undoes ``Normalize(mean, std)``.

    ``Normalize`` computes ``(x - m) / s``. Passing ``m = -mean / std`` and
    ``s = 1 / std`` yields ``(x + mean / std) * std == x * std + mean``,
    which is the inverse of the forward normalization.

    Args:
        mean: Per-channel mean used by the forward normalization.
        std: Per-channel std used by the forward normalization.

    Returns:
        A ``transforms.Normalize`` instance performing the inverse mapping.

    Raises:
        ValueError: If any entry of ``std`` is zero. The forward transform is
            not invertible there; the coefficients would be ``inf``/``nan``.
    """
    # float64 so the derived coefficients are not rounded to float32 before
    # being handed to Normalize as Python floats.
    mean_t = torch.as_tensor(mean, dtype=torch.float64)
    std_t = torch.as_tensor(std, dtype=torch.float64)
    if bool((std_t == 0).any()):
        raise ValueError("std must be non-zero in every channel to invert Normalize")
    return transforms.Normalize((-mean_t / std_t).tolist(), (1.0 / std_t).tolist())


def normalize_only(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    """Return a bare ``Normalize(mean, std)`` transform (no resize/crop/ToTensor).

    Args:
        mean: Per-channel mean (ImageNet by default).
        std: Per-channel std (ImageNet by default).
    """
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize


def inverse_normalize_w_resize(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
    resize_resolution: Union[int, Tuple[int, int]] = (256, 256),
) -> Compose:
    """Undo ``Normalize(mean, std)`` and then resize the result.

    Args:
        mean: Per-channel mean used by the forward normalization.
        std: Per-channel std used by the forward normalization.
        resize_resolution: Target ``size`` for ``transforms.Resize``
            (antialiased), applied after un-normalizing.

    Returns:
        A ``Compose`` of the inverse normalization followed by the resize.

    Raises:
        ValueError: Propagated from ``inverse_normalize`` if any std is zero.
    """
    # Reuse inverse_normalize so the inverse-coefficient math lives in one place.
    resize_unnorm = transforms.Compose([
        inverse_normalize(mean=mean, std=std),
        transforms.Resize(size=resize_resolution, antialias=True),
    ])
    return resize_unnorm