# John6666's picture
# Upload 351 files
# e84842d verified
# raw
# history blame
# 3.29 kB
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
Based on https://github.com/mlfoundations/open_clip
"""
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.transforms import (
Normalize,
Compose,
RandomResizedCrop,
InterpolationMode,
ToTensor,
Resize,
CenterCrop,
)
class ResizeMaxSize(nn.Module):
    """Resize an image to fit ``max_size`` and pad it out to a square.

    The image is scaled (aspect ratio preserved) so that the side picked by
    ``fn`` equals ``max_size``; any side that is still shorter than
    ``max_size`` is then padded evenly on both ends with ``fill``.

    Args:
        max_size: Target side length in pixels. Must be an ``int``.
        interpolation: Interpolation mode forwarded to ``F.resize``.
        fn: ``"min"`` scales the shortest side to ``max_size``; any other
            value (default ``"max"``) scales the longest side.
        fill: Fill value forwarded to ``F.pad``.

    Raises:
        TypeError: If ``max_size`` is not an ``int``.
    """

    def __init__(
        self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0
    ):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        # Previously both branches evaluated to ``min`` and the attribute was
        # never consulted; the default must resolve to ``max``.
        self.fn = min if fn == "min" else max
        self.fill = fill

    def forward(self, img):
        """Return ``img`` resized and padded as described on the class.

        Args:
            img: A ``torch.Tensor`` laid out as ``[..., H, W]`` or an object
                exposing ``size`` as ``(width, height)`` (e.g. a PIL image).

        Returns:
            The resized and, where needed, padded image.
        """
        if isinstance(img, torch.Tensor):
            # F.resize / F.pad treat tensors as [..., H, W], so the spatial
            # dims are the trailing two (the leading ones are channel/batch).
            height, width = img.shape[-2:]
        else:
            width, height = img.size

        scale = self.max_size / float(self.fn(height, width))
        new_size = tuple(round(dim * scale) for dim in (height, width))
        if scale != 1.0:
            img = F.resize(img, new_size, self.interpolation)

        # Pad even when no resize was needed: a non-square image whose
        # reference side already equals max_size must still come out square.
        # Clamp at zero so fn="min" (other side exceeds max_size) never
        # requests negative padding.
        pad_h = max(self.max_size - new_size[0], 0)
        pad_w = max(self.max_size - new_size[1], 0)
        if pad_h or pad_w:
            img = F.pad(
                img,
                # Order is [left, top, right, bottom]; odd remainders go to
                # the right/bottom edge.
                padding=[
                    pad_w // 2,
                    pad_h // 2,
                    pad_w - pad_w // 2,
                    pad_h - pad_h // 2,
                ],
                fill=self.fill,
            )
        return img
def _convert_to_rgb(image):
    """Return the result of ``image.convert("RGB")``.

    Exists as a named function so it can be placed directly in a transform
    pipeline. Presumably ``image`` is a PIL image; only its ``convert``
    method is used here.
    """
    rgb_image = image.convert("RGB")
    return rgb_image
def image_transform(
    image_size: int,
    is_train: bool,
    mean: Optional[Tuple[float, ...]] = None,
    std: Optional[Tuple[float, ...]] = None,
    resize_longest_max: bool = False,
    fill_color: int = 0,
):
    """Build the preprocessing pipeline for training or evaluation images.

    Every pipeline ends with RGB conversion, ``ToTensor`` and ``Normalize``;
    only the geometric front end differs:

    * training: ``RandomResizedCrop`` with scale ``(0.9, 1.0)``;
    * evaluation with ``resize_longest_max``: ``ResizeMaxSize``;
    * evaluation otherwise: ``Resize`` followed by ``CenterCrop``.

    Args:
        image_size: Target size. A two-element list/tuple with equal entries
            is collapsed to a single int; other values are passed through.
        is_train: Select the training pipeline when true.
        mean: Per-channel normalization mean; the OpenAI dataset mean is used
            when this is ``None`` or otherwise falsy.
        std: Per-channel normalization std; the OpenAI dataset std is used
            when this is ``None`` or otherwise falsy.
        resize_longest_max: Evaluation only; ignored when ``is_train`` is set.
        fill_color: Fill value handed to ``ResizeMaxSize``; only used on the
            ``resize_longest_max`` path.

    Returns:
        A ``Compose`` wrapping the selected transforms.
    """
    if not mean:
        mean = (0.48145466, 0.4578275, 0.40821073)  # OpenAI dataset mean
    if not std:
        std = (0.26862954, 0.26130258, 0.27577711)  # OpenAI dataset std

    is_square_pair = (
        isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]
    )
    if is_square_pair:
        # for square size, pass size as int so that Resize() uses aspect
        # preserving shortest edge
        image_size = image_size[0]

    normalize = Normalize(mean=mean, std=std)

    # Geometric front end: the only part that depends on the mode.
    if is_train:
        head = [
            RandomResizedCrop(
                image_size,
                scale=(0.9, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
        ]
    elif resize_longest_max:
        head = [ResizeMaxSize(image_size, fill=fill_color)]
    else:
        head = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]

    # Shared tail: colour mode, tensor conversion, normalization.
    tail = [_convert_to_rgb, ToTensor(), normalize]
    return Compose(head + tail)