| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| Based on https://github.com/mlfoundations/open_clip | |
| """ | |
| """ timm model adapter | |
| Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. | |
| """ | |
| import math | |
| import warnings | |
| from collections import OrderedDict | |
| from typing import List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from torch import nn as nn | |
| try: | |
| import timm | |
| from timm.models.layers import Mlp, to_2tuple | |
| # from timm.models.layers.attention_pool2d import RotAttentionPool2d | |
| # from timm.models.layers.attention_pool2d import ( | |
| # AttentionPool2d as AbsAttentionPool2d, | |
| # ) | |
| except ImportError as e: | |
| timm = None | |
| from lavis.models.clip_models.utils import freeze_batch_norm_2d | |


class TimmModel(nn.Module):
    """timm model adapter
    # FIXME this adapter is a work in progress, may change in ways that break weight compat
    """

    def __init__(
        self,
        model_name,
        embed_dim,
        image_size=224,
        pool="avg",
        proj="linear",
        drop=0.0,
        pretrained=False,
    ):
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")

        self.image_size = to_2tuple(image_size)
        self.trunk = timm.create_model(model_name, pretrained=pretrained)
        feat_size = self.trunk.default_cfg.get("pool_size", None)
        feature_ndim = 1 if not feat_size else 2
        if pool in ("abs_attn", "rot_attn"):
            assert feature_ndim == 2
            # if attn pooling used, remove both classifier and default pool
            self.trunk.reset_classifier(0, global_pool="")
        else:
            # reset global pool if pool config set, otherwise leave as network default
            reset_kwargs = dict(global_pool=pool) if pool else {}
            self.trunk.reset_classifier(0, **reset_kwargs)
        prev_chs = self.trunk.num_features

        head_layers = OrderedDict()
        if pool == "abs_attn":
            head_layers["pool"] = AttentionPool2d(
                prev_chs, feat_size=feat_size, out_features=embed_dim
            )
            prev_chs = embed_dim
        elif pool == "rot_attn":
            head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim
        else:
            assert proj, "projection layer needed if non-attention pooling is used."

        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
        if proj == "linear":
            head_layers["drop"] = nn.Dropout(drop)
            head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
        elif proj == "mlp":
            head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)

        self.head = nn.Sequential(head_layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """lock modules
        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
        """
        if not unlocked_groups:
            # lock full model
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
        else:
            # NOTE: partial freeze requires latest timm (master) branch and is subject to change
            try:
                # FIXME import here until API stable and in an official release
                from timm.models.helpers import group_modules, group_parameters
            except ImportError:
                raise RuntimeError(
                    "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
                )
            matcher = self.trunk.group_matcher()
            gparams = group_parameters(self.trunk, matcher)
            max_layer_id = max(gparams.keys())
            max_layer_id = max_layer_id - unlocked_groups
            for group_idx in range(max_layer_id + 1):
                group = gparams[group_idx]
                for param in group:
                    self.trunk.get_parameter(param).requires_grad = False
            if freeze_bn_stats:
                gmodules = group_modules(self.trunk, matcher, reverse=True)
                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
                freeze_batch_norm_2d(self.trunk, gmodules)

    def forward(self, x):
        x = self.trunk(x)
        x = self.head(x)
        return x
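
# NOTE: a minimal usage sketch (illustrative only; the backbone name and sizes
# below are assumptions, not values prescribed by this file). Any timm model with
# a compatible `default_cfg` should work. The trunk maps an image batch to
# features and the head projects them to the CLIP embedding dimension:
#
#     vision_tower = TimmModel(
#         "resnet50",         # hypothetical choice of timm backbone
#         embed_dim=512,      # CLIP joint embedding dimension
#         image_size=224,
#         pool="avg",         # or "abs_attn" / "rot_attn" for attention pooling
#         proj="linear",
#         pretrained=False,
#     )
#     feats = vision_tower(torch.randn(2, 3, 224, 224))  # -> shape (2, 512)
#     vision_tower.lock(unlocked_groups=0, freeze_bn_stats=True)  # freeze the trunk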


class RotAttentionPool2d(nn.Module):
    """Attention based 2D feature pooling w/ rotary (relative) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
    train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
    """

    def __init__(
        self,
        in_features: int,
        out_features: int = None,
        embed_dim: int = None,
        num_heads: int = 4,
        qkv_bias: bool = True,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5
        self.pos_embed = RotaryEmbedding(self.head_dim)

        trunc_normal_(self.qkv.weight, std=in_features**-0.5)
        nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        x = x.reshape(B, -1, N).permute(0, 2, 1)

        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)

        x = (
            self.qkv(x)
            .reshape(B, N + 1, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = x[0], x[1], x[2]

        qc, q = q[:, :, :1], q[:, :, 1:]
        sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
        q = apply_rot_embed(q, sin_emb, cos_emb)
        q = torch.cat([qc, q], dim=2)

        kc, k = k[:, :, :1], k[:, :, 1:]
        k = apply_rot_embed(k, sin_emb, cos_emb)
        k = torch.cat([kc, k], dim=2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        return x[:, 0]
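
# NOTE: a minimal usage sketch (illustrative only; sizes here are assumed).
# The pool consumes a channel-first feature map and returns one pooled vector per
# sample, taken from the mean-token query position:
#
#     pool = RotAttentionPool2d(in_features=2048, out_features=512, num_heads=4)
#     pooled = pool(torch.randn(2, 2048, 7, 7))  # -> shape (2, 512)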


class AttentionPool2d(nn.Module):
    """Attention based 2D feature pooling w/ learned (absolute) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    It was based on impl in CLIP by OpenAI
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
    """

    def __init__(
        self,
        in_features: int,
        feat_size: Union[int, Tuple[int, int]],
        out_features: int = None,
        embed_dim: int = None,
        num_heads: int = 4,
        qkv_bias: bool = True,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.feat_size = to_2tuple(feat_size)
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim**-0.5

        spatial_dim = self.feat_size[0] * self.feat_size[1]
        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
        trunc_normal_(self.pos_embed, std=in_features**-0.5)
        trunc_normal_(self.qkv.weight, std=in_features**-0.5)
        nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        assert self.feat_size[0] == H
        assert self.feat_size[1] == W
        x = x.reshape(B, -1, N).permute(0, 2, 1)
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        x = (
            self.qkv(x)
            .reshape(B, N + 1, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = x[0], x[1], x[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        return x[:, 0]
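
# NOTE: a minimal usage sketch (illustrative only; sizes here are assumed).
# Unlike RotAttentionPool2d, the spatial size is fixed at construction, so the
# input feature map must match `feat_size` exactly:
#
#     pool = AttentionPool2d(in_features=2048, feat_size=7, out_features=512)
#     pooled = pool(torch.randn(2, 2048, 7, 7))  # -> shape (2, 512)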


def pixel_freq_bands(
    num_bands: int,
    max_freq: float = 224.0,
    linear_bands: bool = True,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
):
    if linear_bands:
        bands = torch.linspace(
            1.0, max_freq / 2, num_bands, dtype=dtype, device=device
        )
    else:
        bands = 2 ** torch.linspace(
            0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device
        )
    return bands * torch.pi


def inv_freq_bands(
    num_bands: int,
    temperature: float = 100000.0,
    step: int = 2,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    inv_freq = 1.0 / (
        temperature
        ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands)
    )
    return inv_freq
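
# NOTE: an illustrative comparison of the two band helpers (arguments here are
# assumed, not prescribed). `pixel_freq_bands` spaces frequencies for a
# pixel-normalized [-1, 1] grid, while `inv_freq_bands` follows the classic
# transformer 1 / temperature**(i / num_bands) schedule:
#
#     pixel_freq_bands(4, max_freq=224.0)             # -> 4 bands, scaled by pi
#     inv_freq_bands(8, temperature=10000.0, step=2)  # -> 4 inverse-frequency bands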


def build_sincos2d_pos_embed(
    feat_shape: List[int],
    dim: int = 64,
    temperature: float = 10000.0,
    reverse_coord: bool = False,
    interleave_sin_cos: bool = False,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Args:
        feat_shape: spatial shape of the feature map, e.g. (H, W)
        dim: embedding dimension, must be divisible by 4
        temperature: frequency temperature for the sin/cos bands
        reverse_coord: stack grid order W, H instead of H, W
        interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
        dtype: output dtype
        device: output device

    Returns:
        Tensor of shape (H * W, dim) with sin-cos position embeddings.
    """
    assert (
        dim % 4 == 0
    ), "Embed dimension must be divisible by 4 for sin-cos 2D position embedding"
    pos_dim = dim // 4
    bands = inv_freq_bands(
        pos_dim, temperature=temperature, step=1, dtype=dtype, device=device
    )

    if reverse_coord:
        feat_shape = feat_shape[::-1]  # stack W, H instead of H, W
    grid = (
        torch.stack(
            torch.meshgrid(
                [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]
            )
        )
        .flatten(1)
        .transpose(0, 1)
    )
    pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
    # FIXME add support for unflattened spatial dim?
    stack_dim = (
        2 if interleave_sin_cos else 1
    )  # stack sin, cos, sin, cos instead of sin sin cos cos
    pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
    return pos_emb
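
# NOTE: an illustrative shape check (feature-map size and dim here are assumed):
#
#     emb = build_sincos2d_pos_embed([7, 7], dim=64)  # -> shape (49, 64)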


def build_fourier_pos_embed(
    feat_shape: List[int],
    bands: Optional[torch.Tensor] = None,
    num_bands: int = 64,
    max_res: int = 224,
    linear_bands: bool = False,
    include_grid: bool = False,
    concat_out: bool = True,
    in_pixels: bool = True,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    if bands is None:
        if in_pixels:
            bands = pixel_freq_bands(
                num_bands,
                float(max_res),
                linear_bands=linear_bands,
                dtype=dtype,
                device=device,
            )
        else:
            bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
    else:
        if device is None:
            device = bands.device
        if dtype is None:
            dtype = bands.dtype

    if in_pixels:
        grid = torch.stack(
            torch.meshgrid(
                [
                    torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=dtype)
                    for s in feat_shape
                ]
            ),
            dim=-1,
        )
    else:
        grid = torch.stack(
            torch.meshgrid(
                [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]
            ),
            dim=-1,
        )
    grid = grid.unsqueeze(-1)
    pos = grid * bands

    pos_sin, pos_cos = pos.sin(), pos.cos()
    out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
    # FIXME torchscript doesn't like multiple return types, probably need to always cat?
    if concat_out:
        out = torch.cat(out, dim=-1)
    return out
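
# NOTE: an illustrative shape check (feature-map size and band count are assumed):
#
#     out = build_fourier_pos_embed([7, 7], num_bands=64)  # concat_out=True
#     # -> shape (7, 7, 2, 128): sin and cos of 64 bands per grid coordinate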


class FourierEmbed(nn.Module):
    def __init__(
        self,
        max_res: int = 224,
        num_bands: int = 64,
        concat_grid=True,
        keep_spatial=False,
    ):
        super().__init__()
        self.max_res = max_res
        self.num_bands = num_bands
        self.concat_grid = concat_grid
        self.keep_spatial = keep_spatial
        # frequency bands for the pixel-normalized grid: pixel_freq_bands(num_bands, max_freq)
        self.register_buffer(
            "bands", pixel_freq_bands(num_bands, float(max_res)), persistent=False
        )

    def forward(self, x):
        B, C = x.shape[:2]
        feat_shape = x.shape[2:]
        emb = build_fourier_pos_embed(
            feat_shape,
            self.bands,
            include_grid=self.concat_grid,
            dtype=x.dtype,
            device=x.device,
        )
        emb = emb.transpose(-1, -2).flatten(len(feat_shape))
        batch_expand = (B,) + (-1,) * (x.ndim - 1)

        # FIXME support nD
        if self.keep_spatial:
            x = torch.cat(
                [x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1
            )
        else:
            x = torch.cat(
                [x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1
            )
            x = x.reshape(B, feat_shape.numel(), -1)

        return x
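
# NOTE: a minimal usage sketch (illustrative only; sizes here are assumed).
# FourierEmbed concatenates Fourier position features onto a 2D feature map; with
# keep_spatial=False the spatial dims are flattened into a token axis:
#
#     fe = FourierEmbed(max_res=224, num_bands=64)
#     out = fe(torch.randn(2, 16, 7, 7))
#     # -> shape (2, 49, 274): 16 input channels + 2 * (2 * 64 + 1) position features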


def rot(x):
    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)


def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
    return x * cos_emb + rot(x) * sin_emb


def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
    if isinstance(x, torch.Tensor):
        x = [x]
    return [t * cos_emb + rot(t) * sin_emb for t in x]


def apply_rot_embed_split(x: torch.Tensor, emb):
    split = emb.shape[-1] // 2
    return x * emb[:, :split] + rot(x) * emb[:, split:]


def build_rotary_pos_embed(
    feat_shape: List[int],
    bands: Optional[torch.Tensor] = None,
    dim: int = 64,
    max_freq: float = 224,
    linear_bands: bool = False,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
):
    """
    NOTE: shape arg should include spatial dim only
    """
    feat_shape = torch.Size(feat_shape)

    sin_emb, cos_emb = build_fourier_pos_embed(
        feat_shape,
        bands=bands,
        num_bands=dim // 4,
        max_res=max_freq,
        linear_bands=linear_bands,
        concat_out=False,
        device=device,
        dtype=dtype,
    )
    N = feat_shape.numel()
    sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
    cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
    return sin_emb, cos_emb
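
# NOTE: an illustrative shape check (grid size and dim are assumed). For an
# (H, W) grid and embedding dim `dim`, the sin/cos tables cover H * W tokens:
#
#     sin_emb, cos_emb = build_rotary_pos_embed([7, 7], dim=64)  # each -> shape (49, 64)
#     q_rot = apply_rot_embed(torch.randn(2, 4, 49, 64), sin_emb, cos_emb)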


class RotaryEmbedding(nn.Module):
    """Rotary position embedding
    NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
    been well tested, and will likely change. It will be moved to its own file.
    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(self, dim, max_res=224, linear_bands: bool = False):
        super().__init__()
        self.dim = dim
        self.register_buffer(
            "bands",
            pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands),
            persistent=False,
        )

    def get_embed(self, shape: List[int]):
        return build_rotary_pos_embed(shape, self.bands)

    def forward(self, x):
        # assuming channel-first tensor where spatial dim are >= 2
        sin_emb, cos_emb = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, sin_emb, cos_emb)
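
# NOTE: a minimal usage sketch (illustrative only; sizes here are assumed). In this
# file the module is used through `get_embed`, which builds sin/cos tables for a
# given spatial shape (see RotAttentionPool2d.forward above):
#
#     rope = RotaryEmbedding(dim=64)
#     sin_emb, cos_emb = rope.get_embed((7, 7))  # each -> shape (49, 64)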


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)