# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for CLIP."""
from typing import List, Optional, Union

import numpy as np
from PIL import Image

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...file_utils import TensorType
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import logging


logger = logging.get_logger(__name__)
class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
    r"""
    Constructs a CLIP feature extractor.

    This feature extractor inherits from :class:`~transformers.FeatureExtractionMixin` which contains most of the main
    methods. Users should refer to this superclass for more information regarding those methods.

    Args:
        do_resize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to resize the input to a certain :obj:`size`.
        size (:obj:`int`, `optional`, defaults to 224):
            Resize the input to the given size. Only has an effect if :obj:`do_resize` is set to :obj:`True`.
        resample (:obj:`int`, `optional`, defaults to :obj:`PIL.Image.BICUBIC`):
            An optional resampling filter. This can be one of :obj:`PIL.Image.NEAREST`, :obj:`PIL.Image.BOX`,
            :obj:`PIL.Image.BILINEAR`, :obj:`PIL.Image.HAMMING`, :obj:`PIL.Image.BICUBIC` or :obj:`PIL.Image.LANCZOS`.
            Only has an effect if :obj:`do_resize` is set to :obj:`True`.
        do_center_crop (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to crop the input at the center. If the input size is smaller than :obj:`crop_size` along any
            edge, the image is padded with 0's and then center cropped.
        crop_size (:obj:`int`, `optional`, defaults to 224):
            Desired output size when applying center-cropping. Only has an effect if :obj:`do_center_crop` is set to
            :obj:`True`.
        do_normalize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to normalize the input with :obj:`image_mean` and :obj:`image_std`.
        image_mean (:obj:`List[float]`, `optional`, defaults to :obj:`[0.48145466, 0.4578275, 0.40821073]`):
            The sequence of means for each channel, to be used when normalizing images.
        image_std (:obj:`List[float]`, `optional`, defaults to :obj:`[0.26862954, 0.26130258, 0.27577711]`):
            The sequence of standard deviations for each channel, to be used when normalizing images.
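
    Example (a minimal usage sketch; the all-zero image is only a stand-in for real data, and the expected shape
    assumes the default :obj:`size` and :obj:`crop_size` of 224)::

        >>> import numpy as np
        >>> from PIL import Image

        >>> feature_extractor = CLIPFeatureExtractor()
        >>> image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
        >>> inputs = feature_extractor(image, return_tensors="np")
        >>> inputs["pixel_values"].shape
        (1, 3, 224, 224)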
""" | |
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize=True,
        size=224,
        resample=Image.BICUBIC,
        do_center_crop=True,
        crop_size=224,
        do_normalize=True,
        image_mean=None,
        image_std=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
        self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
    def __call__(
        self,
        images: Union[
            Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]  # noqa
        ],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several image(s).

        .. warning::

           NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it is most efficient to
           pass PIL images.

        Args:
            images (:obj:`PIL.Image.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[PIL.Image.Image]`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is
                the number of channels and H and W are the image height and width.
            return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
                If set, will return tensors of a particular framework. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return NumPy :obj:`np.ndarray` objects.
                * :obj:`'jax'`: Return JAX :obj:`jnp.ndarray` objects.

        Returns:
            :class:`~transformers.BatchFeature`: A :class:`~transformers.BatchFeature` with the following fields:

            - **pixel_values** -- Pixel values to be fed to a model.
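
        Example (a minimal sketch of batching two dummy images; the all-zero arrays are only stand-ins for real
        data, and the expected shape assumes the default :obj:`crop_size` of 224)::

            >>> import numpy as np
            >>> from PIL import Image

            >>> feature_extractor = CLIPFeatureExtractor()
            >>> images = [Image.fromarray(np.zeros((300, 400, 3), dtype=np.uint8)) for _ in range(2)]
            >>> batch = feature_extractor(images, return_tensors="np")
            >>> batch["pixel_values"].shape
            (2, 3, 224, 224)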
""" | |
        # Input type checking for clearer error
        valid_images = False

        # Check that images has a valid type
        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
            valid_images = True
        elif isinstance(images, (list, tuple)):
            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
                valid_images = True

        if not valid_images:
            raise ValueError(
                "Images must be of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
            )

        is_batched = bool(
            isinstance(images, (list, tuple))
            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
        )

        if not is_batched:
            images = [images]

        # transformations (resizing + center cropping + normalization)
        if self.do_resize and self.size is not None and self.resample is not None:
            images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
        if self.do_center_crop and self.crop_size is not None:
            images = [self.center_crop(image, self.crop_size) for image in images]
        if self.do_normalize:
            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]

        # return as BatchFeature
        data = {"pixel_values": images}
        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)

        return encoded_inputs
    def center_crop(self, image, size):
        """
        Crops :obj:`image` to the given size using a center crop. Note that if the image is too small to be cropped
        to the given size, it will be padded (so the returned result has the size asked).

        Args:
            image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
                The image to crop.
            size (:obj:`int` or :obj:`Tuple[int, int]`):
                The size to which to crop the image.
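
        Example (a minimal sketch; the blank image is only a stand-in, and the crop coordinates follow the
        computation below)::

            >>> from PIL import Image

            >>> feature_extractor = CLIPFeatureExtractor()
            >>> image = Image.new("RGB", (640, 480))
            >>> feature_extractor.center_crop(image, 224).size
            (224, 224)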
""" | |
        self._ensure_format_supported(image)

        if not isinstance(size, tuple):
            size = (size, size)

        if not isinstance(image, Image.Image):
            image = self.to_pil_image(image)

        image_width, image_height = image.size
        crop_height, crop_width = size

        crop_top = int((image_height - crop_height + 1) * 0.5)
        crop_left = int((image_width - crop_width + 1) * 0.5)

        return image.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
    def resize(self, image, size, resample=Image.BICUBIC):
        """
        Resizes :obj:`image`. Note that this will trigger a conversion of :obj:`image` to a PIL Image.

        Args:
            image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
                The image to resize.
            size (:obj:`int` or :obj:`Tuple[int, int]`):
                The size to use for resizing the image. If :obj:`int`, the shorter side of the image will be resized
                to match it, keeping the aspect ratio.
            resample (:obj:`int`, `optional`, defaults to :obj:`PIL.Image.BICUBIC`):
                The filter to use for resampling.
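
        Example (a minimal sketch; with an :obj:`int` size the shorter side is matched and the aspect ratio is
        preserved, so a 640x480 image becomes 298x224 per the arithmetic below)::

            >>> from PIL import Image

            >>> feature_extractor = CLIPFeatureExtractor()
            >>> image = Image.new("RGB", (640, 480))
            >>> feature_extractor.resize(image, 224).size
            (298, 224)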
""" | |
        self._ensure_format_supported(image)

        if not isinstance(image, Image.Image):
            image = self.to_pil_image(image)

        if isinstance(size, tuple):
            new_w, new_h = size
        else:
            width, height = image.size
            short, long = (width, height) if width <= height else (height, width)

            if short == size:
                return image

            new_short, new_long = size, int(size * long / short)
            new_w, new_h = (new_short, new_long) if width <= height else (new_long, new_short)

        return image.resize((new_w, new_h), resample)