# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
base_vision.py

Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
functions, and initialization logic.

We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
Transformer model for feature extraction.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union

import timm
import torch
import torch.nn as nn
import torchvision.transforms.functional as TVF
from PIL.Image import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize


# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, (tuple, list)) else result

    return wrapper


def return_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    # Identity wrapper: forwards all arguments and returns `fn`'s result unchanged.
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result

    return wrapper
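

# Quick sketch of `unpack_tuple` behavior (hypothetical shapes):
#
#     fn = unpack_tuple(lambda: (torch.zeros(1, 256, 1024),))
#     fn().shape  # torch.Size([1, 256, 1024]) -- the singleton tuple (as returned by
#                 # `get_intermediate_layers`) is stripped down to the bare tensor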


# === Interface for an Image Transform ===
class ImageTransform(Protocol):
    def __call__(self, img: Image, **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ...


# === Custom Torchvision Image Transforms ===
@dataclass
class LetterboxPad:
    padding_fill_value: Tuple[int, int, int]

    def __call__(self, image: Image) -> Image:
        """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
        (w, h), max_wh = image.size, max(image.size)
        horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
        padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
        return TVF.pad(image, padding, fill=self.padding_fill_value, padding_mode="constant")
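

# Worked example for `LetterboxPad` (hypothetical sizes): a 640x480 input has max_wh = 640, so
# vertical_pad = (640 - 480) // 2 = 80; padding 80px on top and bottom yields a square 640x640
# image, with the new border filled by `padding_fill_value` (the rescaled normalization mean).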


# === Abstract Base Class for Arbitrary Vision Backbones ===
class VisionBackbone(nn.Module, ABC):
    def __init__(self, vision_backbone_id: str, image_resize_strategy: str, default_image_size: int = 224) -> None:
        super().__init__()
        self.identifier: str = vision_backbone_id
        self.image_resize_strategy: str = image_resize_strategy
        self.default_image_size: int = default_image_size

        # Instance attributes for a Vision Backbone
        self.featurizer: nn.Module = None
        self.image_transform: ImageTransform = None

    def get_image_transform(self) -> ImageTransform:
        return self.image_transform

    @abstractmethod
    def get_fsdp_wrapping_policy(self) -> Callable: ...

    @abstractmethod
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
        raise NotImplementedError

    @property
    @abstractmethod
    def default_image_resolution(self) -> Tuple[int, int, int]: ...

    @property
    @abstractmethod
    def embed_dim(self) -> int: ...

    @property
    @abstractmethod
    def num_patches(self) -> int: ...

    @property
    @abstractmethod
    def half_precision_dtype(self) -> torch.dtype: ...


# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
class TimmViTBackbone(VisionBackbone, ABC):
    def __init__(
        self,
        vision_backbone_id: str,
        timm_path_or_url: str,
        image_resize_strategy: str,
        default_image_size: int = 224,
        override_act_layer: Optional[str] = None,
    ) -> None:
        super().__init__(vision_backbone_id, image_resize_strategy, default_image_size=default_image_size)
        self.timm_path_or_url = timm_path_or_url
        self.override_act_layer = override_act_layer
        self.dtype = torch.bfloat16

        # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
        if self.override_act_layer is None:
            self.featurizer: VisionTransformer = timm.create_model(
                self.timm_path_or_url, pretrained=True, num_classes=0, img_size=self.default_image_size
            )
        else:
            self.featurizer: VisionTransformer = timm.create_model(
                self.timm_path_or_url,
                pretrained=True,
                num_classes=0,
                img_size=self.default_image_size,
                act_layer=self.override_act_layer,
            )
        self.featurizer.eval()

        # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
        #   => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
        #   => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2})
        )
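        # Worked example (hypothetical depth): for a 24-block ViT-L, `n = {22}` selects the output
        # of block index 22 (0-indexed), i.e., the patch features of the second-to-last layer.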

        # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
        assert isinstance(self.featurizer, VisionTransformer), (
            "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
            "file an issue or implement the requisite logic (see `cobra/models/backbones/vision/base_vision.py`)!"
        )

        # Get Config =>> Note :: Override default image size to ensure correct image transform
        self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
        self.data_cfg["input_size"] = (3, self.default_image_size, self.default_image_size)

        # Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
        default_image_transform = timm.data.create_transform(**self.data_cfg, is_training=False)

        # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
        if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
            assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
            assert isinstance(resize_transform := default_image_transform.transforms[0], Resize)
            default_image_transform = Compose(
                [
                    Resize(self.default_image_size, interpolation=resize_transform.interpolation),
                    *default_image_transform.transforms[1:],
                ]
            )
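        # e.g. (typical TIMM configs; the exact numbers are an assumption): an IN1K transform may
        # resize the short side to 256 before center-cropping to 224; forcing Resize(224) above
        # minimizes how much image content the subsequent crop discards.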

        # Switch on `image_resize_strategy`
        if self.image_resize_strategy == "resize-naive":
            assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
            assert isinstance(resize_transform := default_image_transform.transforms[0], Resize)

            target_size = (self.default_image_size, self.default_image_size)
            self.image_transform = Compose(
                [
                    Resize(target_size, interpolation=resize_transform.interpolation),
                    *default_image_transform.transforms[1:],
                ]
            )

        elif self.image_resize_strategy == "resize-crop":
            self.image_transform = default_image_transform

        elif self.image_resize_strategy == "letterbox":
            assert isinstance(default_image_transform, Compose), "Unexpected `default_image_transform`!"
            assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"

            # Compute Padding Fill Value (rescaled normalization mean if applicable)
            fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])

            # Build New Transform
            self.image_transform = Compose([LetterboxPad(fill), *default_image_transform.transforms])

        else:
            raise ValueError(f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!")
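
        # Net effect of the three strategies on a non-square input (summary of the logic above):
        #   "resize-naive" => squash to (default_image_size, default_image_size), distorting aspect ratio
        #   "resize-crop"  => TIMM default: aspect-preserving resize, then center crop (may discard content)
        #   "letterbox"    => pad to square with the normalization mean, then apply the default transform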

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
        vit_wrap_policy = partial(_module_wrap_policy, module_classes={VisionTransformer})
        transformer_block_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
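        # Note: `_or_policy` wraps a module when *either* sub-policy matches, so each ViT `Block`
        # becomes its own FSDP unit, nested inside the unit wrapping the full `VisionTransformer`.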
        return partial(_or_policy, policies=[vit_wrap_policy, transformer_block_policy])

    def forward(self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
        return self.featurizer(pixel_values)

    @property
    def default_image_resolution(self) -> Tuple[int, int, int]:
        return self.data_cfg["input_size"]

    @property
    def embed_dim(self) -> int:
        return self.featurizer.embed_dim

    @property
    def num_patches(self) -> int:
        return self.featurizer.patch_embed.num_patches

    @property
    def half_precision_dtype(self) -> torch.dtype:
        return self.dtype