Spaces:
Sleeping
Sleeping
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
Based on https://github.com/mlfoundations/open_clip
"""
from typing import Optional, Sequence, Tuple | |
import torch | |
import torch.nn as nn | |
import torchvision.transforms.functional as F | |
from torchvision.transforms import ( | |
Normalize, | |
Compose, | |
RandomResizedCrop, | |
InterpolationMode, | |
ToTensor, | |
Resize, | |
CenterCrop, | |
) | |
class ResizeMaxSize(nn.Module):
    """Resize an image so that one side equals ``max_size``, then pad to a square.

    The reference side is chosen by ``fn``: ``"max"`` (the default) scales the
    longest side to ``max_size``; ``"min"`` scales the shortest side instead.
    After resizing, any side shorter than ``max_size`` is padded symmetrically
    with ``fill`` (the odd pixel, if any, goes to the right/bottom), so with the
    default ``fn="max"`` the result is always ``max_size`` x ``max_size``.

    Args:
        max_size: Target side length in pixels. Must be an ``int``.
        interpolation: Interpolation mode forwarded to ``F.resize``.
        fn: ``"min"`` to scale by the shortest side; any other value scales
            by the longest side.
        fill: Fill value forwarded to ``F.pad``.

    Raises:
        TypeError: If ``max_size`` is not an ``int``.
    """

    def __init__(
        self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0
    ):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        # Previously both branches selected ``min``, which made ``fn`` a no-op.
        self.fn = min if fn == "min" else max
        self.fill = fill

    def forward(self, img):
        """Return ``img`` resized and padded as described on the class.

        Args:
            img: A ``torch.Tensor`` or an object exposing ``.size`` as
                ``(width, height)`` (e.g. a PIL image).
        """
        if isinstance(img, torch.Tensor):
            # torchvision's functional ops treat tensors as [..., H, W], so the
            # spatial dims are the last two (the first two would be (C, H)).
            height, width = img.shape[-2:]
        else:
            width, height = img.size

        scale = self.max_size / float(self.fn(height, width))
        new_size = tuple(round(dim * scale) for dim in (height, width))
        if scale != 1.0:
            img = F.resize(img, new_size, self.interpolation)

        # Pad even when no resize was needed: a non-square image whose reference
        # side already equals max_size must still come out square. Clamp at 0 so
        # that fn="min" (where the other side exceeds max_size) never requests
        # negative padding.
        pad_h = max(self.max_size - new_size[0], 0)
        pad_w = max(self.max_size - new_size[1], 0)
        if pad_h or pad_w:
            img = F.pad(
                img,
                padding=[
                    pad_w // 2,
                    pad_h // 2,
                    pad_w - pad_w // 2,
                    pad_h - pad_h // 2,
                ],
                fill=self.fill,
            )
        return img
def _convert_to_rgb(image):
    """Return the result of ``image.convert("RGB")``.

    Kept as a named module-level function so it can be placed directly in a
    transform pipeline as a callable step.
    """
    rgb_image = image.convert("RGB")
    return rgb_image
def image_transform(
    image_size: int,
    is_train: bool,
    mean: Optional[Tuple[float, ...]] = None,
    std: Optional[Tuple[float, ...]] = None,
    resize_longest_max: bool = False,
    fill_color: int = 0,
):
    """Build the torchvision preprocessing pipeline for training or evaluation.

    Args:
        image_size: Target size. A two-element list/tuple with equal entries is
            collapsed to a single int before use.
        is_train: If true, return the training pipeline (random resized crop);
            otherwise return the evaluation pipeline.
        mean: Per-channel normalization mean; falls back to the OpenAI dataset
            mean when falsy.
        std: Per-channel normalization std; falls back to the OpenAI dataset
            std when falsy.
        resize_longest_max: Evaluation only. Use ``ResizeMaxSize`` instead of
            ``Resize`` followed by ``CenterCrop``.
        fill_color: Fill value passed to ``ResizeMaxSize``.

    Returns:
        A ``Compose`` whose last three steps are RGB conversion, ``ToTensor``
        and ``Normalize``.
    """
    if not mean:
        mean = (0.48145466, 0.4578275, 0.40821073)  # OpenAI dataset mean
    if not std:
        std = (0.26862954, 0.26130258, 0.27577711)  # OpenAI dataset std

    is_square_pair = (
        isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]
    )
    if is_square_pair:
        # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
        image_size = image_size[0]

    normalize = Normalize(mean=mean, std=std)

    if is_train:
        random_crop = RandomResizedCrop(
            image_size,
            scale=(0.9, 1.0),
            interpolation=InterpolationMode.BICUBIC,
        )
        return Compose([random_crop, _convert_to_rgb, ToTensor(), normalize])

    # Evaluation: deterministic geometry step(s), then the shared tail.
    if resize_longest_max:
        geometry = [ResizeMaxSize(image_size, fill=fill_color)]
    else:
        geometry = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]
    tail = [_convert_to_rgb, ToTensor(), normalize]
    return Compose(geometry + tail)