""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause Based on https://github.com/mlfoundations/open_clip """ from typing import Optional, Sequence, Tuple import torch import torch.nn as nn import torchvision.transforms.functional as F from torchvision.transforms import ( Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop, ) class ResizeMaxSize(nn.Module): def __init__( self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0 ): super().__init__() if not isinstance(max_size, int): raise TypeError(f"Size should be int. Got {type(max_size)}") self.max_size = max_size self.interpolation = interpolation self.fn = min if fn == "min" else min self.fill = fill def forward(self, img): if isinstance(img, torch.Tensor): height, width = img.shape[:2] else: width, height = img.size scale = self.max_size / float(max(height, width)) if scale != 1.0: new_size = tuple(round(dim * scale) for dim in (height, width)) img = F.resize(img, new_size, self.interpolation) pad_h = self.max_size - new_size[0] pad_w = self.max_size - new_size[1] img = F.pad( img, padding=[ pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2, ], fill=self.fill, ) return img def _convert_to_rgb(image): return image.convert("RGB") def image_transform( image_size: int, is_train: bool, mean: Optional[Tuple[float, ...]] = None, std: Optional[Tuple[float, ...]] = None, resize_longest_max: bool = False, fill_color: int = 0, ): mean = mean or (0.48145466, 0.4578275, 0.40821073) # OpenAI dataset mean std = std or (0.26862954, 0.26130258, 0.27577711) # OpenAI dataset std if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: # for square size, pass size as int so that Resize() uses aspect preserving shortest edge image_size = image_size[0] normalize = Normalize(mean=mean, std=std) if is_train: return Compose( [ RandomResizedCrop( image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC, ), _convert_to_rgb, ToTensor(), normalize, ] ) else: if resize_longest_max: transforms = [ResizeMaxSize(image_size, fill=fill_color)] else: transforms = [ Resize(image_size, interpolation=InterpolationMode.BICUBIC), CenterCrop(image_size), ] transforms.extend( [ _convert_to_rgb, ToTensor(), normalize, ] ) return Compose(transforms)