"""Vision Transformer (ViT) in PyTorch A PyTorch implement of Vision Transformers as described in: 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` - https://arxiv.org/abs/2106.10270 `FlexiViT: One Model for All Patch Sizes` - https://arxiv.org/abs/2212.08013 The official jax code is released and available at * https://github.com/google-research/vision_transformer * https://github.com/google-research/big_vision Acknowledgments: * The paper authors for releasing code and weights, thanks! * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT * Bert reference code checks against Huggingface Transformers and Tensorflow Bert Hacked together by / Copyright 2020, Ross Wightman """ import logging import math from collections import OrderedDict from functools import partial from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List try: from typing import Literal except ImportError: from typing_extensions import Literal import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint from torch.jit import Final from timm.data import ( IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ) from timm.layers import ( PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, get_act_layer, get_norm_layer, LayerType, ) from timm.models._builder import build_model_with_cfg from timm.models._features import feature_take_indices from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv from timm.models._registry import ( generate_default_cfgs, register_model, register_model_deprecations, ) __all__ = ["VisionTransformer"] # model_registry will add each entrypoint fn to this _logger = logging.getLogger(__name__) def init_1d_freqs(dim: int, end: int, theta: float = 10000.0): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 data type. Args: dim (int): Dimension of the frequency tensor. end (int): End index for precomputing frequencies. theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. Returns: torch.Tensor: Precomputed frequency tensor with complex exponentials. 
""" freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device) # type: ignore freqs = torch.outer(t, freqs).float() # type: ignore freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 return freqs_cis def init_2d_freqs(dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True): freqs_x = [] freqs_y = [] mag = 1 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) for i in range(num_heads): angles = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1) fx = torch.cat( [mag * torch.cos(angles), mag * torch.cos(torch.pi / 2 + angles)], dim=-1 ) fy = torch.cat( [mag * torch.sin(angles), mag * torch.sin(torch.pi / 2 + angles)], dim=-1 ) freqs_x.append(fx) freqs_y.append(fy) freqs_x = torch.stack(freqs_x, dim=0) freqs_y = torch.stack(freqs_y, dim=0) freqs = torch.stack([freqs_x, freqs_y], dim=0) return freqs def init_t_xy(end_x: int, end_y: int): t = torch.arange(end_x * end_y, dtype=torch.float32) t_x = (t % end_x).float() t_y = torch.div(t, end_x, rounding_mode="floor").float() return t_x, t_y def compute_mixed_cis( freqs: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor, num_heads: int ): N = t_x.shape[0] # No float 16 for this range with torch.cuda.amp.autocast(enabled=False): freqs_x = ( (t_x.unsqueeze(-1) @ freqs[0].unsqueeze(-2)) .view(N, num_heads, -1) .permute(1, 0, 2) ) freqs_y = ( (t_y.unsqueeze(-1) @ freqs[1].unsqueeze(-2)) .view(N, num_heads, -1) .permute(1, 0, 2) ) freqs_cis = torch.polar(torch.ones_like(freqs_x), freqs_x + freqs_y) return freqs_cis def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 100.0): freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) t_x, t_y = init_t_xy(end_x, end_y) freqs_x = torch.outer(t_x, freqs_x) freqs_y = torch.outer(t_y, freqs_y) freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): ndim = x.ndim assert 0 <= 1 < ndim if freqs_cis.shape == (x.shape[-2], x.shape[-1]): shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] elif freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1]): shape = [d if i >= ndim - 3 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(*shape) def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor): xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) freqs_cis = reshape_for_broadcast(freqs_cis, xq_) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) class Attention(nn.Module): fused_attn: Final[bool] def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = False, qk_norm: bool = False, attn_drop: float = 0.0, proj_drop: float = 0.0, norm_layer: nn.Module = nn.LayerNorm, **kwargs, ) -> None: super().__init__() assert dim % num_heads == 0, "dim should be divisible by num_heads" self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim**-0.5 self.fused_attn = use_fused_attn() self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor:
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                attn_mask=attn_mask,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                attn += attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class RoPEAttention(Attention):
    """Multi-head Attention block with rotary position embeddings."""

    def __init__(
        self,
        *args,
        num_prefix_tokens=1,
        num_latent_tokens=32,
        num_image_tokens=256,
        rope_theta=10.0,
        rope_mixed=True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.rope_mixed = rope_mixed
        self.num_prefix_tokens = num_prefix_tokens
        self.num_latent_tokens = num_latent_tokens
        self.num_image_tokens = num_image_tokens
        self.num_axis_tokens = int(num_image_tokens**0.5)

        if self.rope_mixed:
            self.compute_cis = partial(compute_mixed_cis, num_heads=self.num_heads)
            freqs = init_2d_freqs(
                dim=self.head_dim,
                num_heads=self.num_heads,
                theta=rope_theta,
                rotate=True,
            ).view(2, -1)
            self.freqs = nn.Parameter(freqs, requires_grad=True)

            t_x, t_y = init_t_xy(end_x=self.num_axis_tokens, end_y=self.num_axis_tokens)
            self.register_buffer("freqs_t_x", t_x)
            self.register_buffer("freqs_t_y", t_y)
        else:
            self.compute_cis = partial(
                compute_axial_cis, dim=self.head_dim, theta=rope_theta
            )
            freqs_cis = self.compute_cis(
                end_x=self.num_axis_tokens, end_y=self.num_axis_tokens
            )
            self.freqs_cis = freqs_cis

        # get pre-computed 1d rope
        freqs_1d = init_1d_freqs(dim=self.head_dim, end=self.num_latent_tokens)
        self.freqs_1d = nn.Parameter(freqs_1d, requires_grad=True)

    def forward(self, x, attn_mask=None):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        ###### Apply rotary position embedding
        w = h = math.sqrt(x.shape[1] - 1)
        if self.rope_mixed:
            t_x, t_y = self.freqs_t_x, self.freqs_t_y
            if (
                self.freqs_t_x.shape[0]
                != x.shape[1] - self.num_prefix_tokens - self.num_latent_tokens
            ):
                t_x, t_y = init_t_xy(end_x=w, end_y=h)
                t_x, t_y = t_x.to(x.device), t_y.to(x.device)
            freqs_cis = self.compute_cis(self.freqs, t_x, t_y)
        else:
            freqs_cis = self.freqs_cis
            if (
                self.freqs_cis.shape[0]
                != x.shape[1] - self.num_prefix_tokens - self.num_latent_tokens
            ):
                freqs_cis = self.compute_cis(end_x=w, end_y=h)
            freqs_cis = freqs_cis.to(x.device)

        # apply rotary position embedding to image tokens
        dtype = x.dtype
        with torch.cuda.amp.autocast(enabled=False):
            (
                q[:, :, self.num_prefix_tokens : -self.num_latent_tokens],
                k[:, :, self.num_prefix_tokens : -self.num_latent_tokens],
            ) = apply_rotary_emb(
                q[:, :, self.num_prefix_tokens : -self.num_latent_tokens],
                k[:, :, self.num_prefix_tokens : -self.num_latent_tokens],
                freqs_cis=freqs_cis,
            )
            q[:, :, -self.num_latent_tokens :], k[:, :, -self.num_latent_tokens :] = (
                apply_rotary_emb(
                    q[:, :, -self.num_latent_tokens :],
                    k[:, :, -self.num_latent_tokens :],
                    freqs_cis=self.freqs_1d,
                )
            )
        q, k = q.to(dtype), k.to(dtype)
        #########

        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn =
attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class LayerScale(nn.Module): def __init__( self, dim: int, init_values: float = 1e-5, inplace: bool = False, ) -> None: super().__init__() self.inplace = inplace self.gamma = nn.Parameter(init_values * torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma class Block(nn.Module): def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = False, qk_norm: bool = False, proj_drop: float = 0.0, attn_drop: float = 0.0, init_values: Optional[float] = None, drop_path: float = 0.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.LayerNorm, mlp_layer: nn.Module = Mlp, attn_layer: nn.Module = Attention, ) -> None: super().__init__() self.norm1 = norm_layer(dim) self.attn = attn_layer( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer, ) self.ls1 = ( LayerScale(dim, init_values=init_values) if init_values else nn.Identity() ) self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) self.mlp = mlp_layer( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, ) self.ls2 = ( LayerScale(dim, init_values=init_values) if init_values else nn.Identity() ) self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> torch.Tensor: x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x class ResPostBlock(nn.Module): def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = False, qk_norm: bool = False, proj_drop: float = 0.0, attn_drop: float = 0.0, init_values: Optional[float] = None, drop_path: float = 0.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.LayerNorm, mlp_layer: nn.Module = Mlp, ) -> None: super().__init__() self.init_values = init_values self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer, ) self.norm1 = norm_layer(dim) self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.mlp = mlp_layer( in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, ) self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.init_weights() def init_weights(self) -> None: # NOTE this init overrides that base model init with specific changes for the block type if self.init_values is not None: nn.init.constant_(self.norm1.weight, self.init_values) nn.init.constant_(self.norm2.weight, self.init_values) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.drop_path1(self.norm1(self.attn(x))) x = x + self.drop_path2(self.norm2(self.mlp(x))) return x class ParallelScalingBlock(nn.Module): """Parallel ViT block (MLP & Attention in parallel) Based on: 'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442 """ fused_attn: Final[bool] def __init__( self, dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = False, qk_norm: bool = False, proj_drop: float = 0.0, attn_drop: float = 0.0, init_values: Optional[float] = None, drop_path: float = 
0.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.LayerNorm, mlp_layer: Optional[nn.Module] = None, ) -> None: super().__init__() assert dim % num_heads == 0, "dim should be divisible by num_heads" self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim**-0.5 self.fused_attn = use_fused_attn() mlp_hidden_dim = int(mlp_ratio * dim) in_proj_out_dim = mlp_hidden_dim + 3 * dim self.in_norm = norm_layer(dim) self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias) self.in_split = [mlp_hidden_dim] + [dim] * 3 if qkv_bias: self.register_buffer("qkv_bias", None) self.register_parameter("mlp_bias", None) else: self.register_buffer("qkv_bias", torch.zeros(3 * dim), persistent=False) self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim)) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) self.attn_out_proj = nn.Linear(dim, dim) self.mlp_drop = nn.Dropout(proj_drop) self.mlp_act = act_layer() self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim) self.ls = ( LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity() ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: B, N, C = x.shape # Combined MLP fc1 & qkv projections y = self.in_norm(x) if self.mlp_bias is not None: # Concat constant zero-bias for qkv w/ trainable mlp_bias. # Appears faster than adding to x_mlp separately y = F.linear( y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)) ) else: y = self.in_proj(y) x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1) # Dot product attention w/ qk norm q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2) k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2) v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) if self.fused_attn: x_attn = F.scaled_dot_product_attention( q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, ) else: q = q * self.scale attn = q @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x_attn = attn @ v x_attn = x_attn.transpose(1, 2).reshape(B, N, C) x_attn = self.attn_out_proj(x_attn) # MLP activation, dropout, fc2 x_mlp = self.mlp_act(x_mlp) x_mlp = self.mlp_drop(x_mlp) x_mlp = self.mlp_out_proj(x_mlp) # Add residual w/ drop path & layer scale applied y = self.drop_path(self.ls(x_attn + x_mlp)) x = x + y return x class ParallelThingsBlock(nn.Module): """Parallel ViT block (N parallel attention followed by N parallel MLP) Based on: `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 """ def __init__( self, dim: int, num_heads: int, num_parallel: int = 2, mlp_ratio: float = 4.0, qkv_bias: bool = False, qk_norm: bool = False, init_values: Optional[float] = None, proj_drop: float = 0.0, attn_drop: float = 0.0, drop_path: float = 0.0, act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.LayerNorm, mlp_layer: nn.Module = Mlp, ) -> None: super().__init__() self.num_parallel = num_parallel self.attns = nn.ModuleList() self.ffns = nn.ModuleList() for _ in range(num_parallel): self.attns.append( nn.Sequential( OrderedDict( [ ("norm", norm_layer(dim)), ( "attn", Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer, ), ), ( "ls", ( LayerScale(dim, 
init_values=init_values) if init_values else nn.Identity() ), ), ( "drop_path", ( DropPath(drop_path) if drop_path > 0.0 else nn.Identity() ), ), ] ) ) ) self.ffns.append( nn.Sequential( OrderedDict( [ ("norm", norm_layer(dim)), ( "mlp", mlp_layer( dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop, ), ), ( "ls", ( LayerScale(dim, init_values=init_values) if init_values else nn.Identity() ), ), ( "drop_path", ( DropPath(drop_path) if drop_path > 0.0 else nn.Identity() ), ), ] ) ) ) def _forward_jit(self, x: torch.Tensor) -> torch.Tensor: x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) return x @torch.jit.ignore def _forward(self, x: torch.Tensor) -> torch.Tensor: x = x + sum(attn(x) for attn in self.attns) x = x + sum(ffn(x) for ffn in self.ffns) return x def forward(self, x: torch.Tensor) -> torch.Tensor: if torch.jit.is_scripting() or torch.jit.is_tracing(): return self._forward_jit(x) else: return self._forward(x) def global_pool_nlc( x: torch.Tensor, pool_type: str = "token", num_prefix_tokens: int = 1, reduce_include_prefix: bool = False, ): if not pool_type: return x if pool_type == "token": x = x[:, 0] # class token else: x = x if reduce_include_prefix else x[:, num_prefix_tokens:] if pool_type == "avg": x = x.mean(dim=1) elif pool_type == "avgmax": x = 0.5 * (x.amax(dim=1) + x.mean(dim=1)) elif pool_type == "max": x = x.amax(dim=1) else: assert not pool_type, f"Unknown pool type {pool_type}" return x class VisionTransformer(nn.Module): """Vision Transformer A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ dynamic_img_size: Final[bool] def __init__( self, img_size: Union[int, Tuple[int, int]] = 224, patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, num_classes: int = 1000, global_pool: Literal["", "avg", "avgmax", "max", "token", "map"] = "token", embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_norm: bool = False, init_values: Optional[float] = None, class_token: bool = True, pos_embed: str = "learn", no_embed_class: bool = False, reg_tokens: int = 0, pre_norm: bool = False, fc_norm: Optional[bool] = None, dynamic_img_size: bool = False, dynamic_img_pad: bool = False, drop_rate: float = 0.0, pos_drop_rate: float = 0.0, patch_drop_rate: float = 0.0, proj_drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", fix_init: bool = False, embed_layer: Callable = PatchEmbed, norm_layer: Optional[LayerType] = None, act_layer: Optional[LayerType] = None, block_fn: Type[nn.Module] = Block, mlp_layer: Type[nn.Module] = Mlp, attn_layer: Type[nn.Module] = Attention, num_latent_tokens: int = 32, ) -> None: """ Args: img_size: Input image size. patch_size: Patch size. in_chans: Number of image input channels. num_classes: Number of classes for classification head. global_pool: Type of global pooling for final sequence (default: 'token'). embed_dim: Transformer embedding dimension. depth: Depth of transformer. num_heads: Number of attention heads. mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: Enable bias for qkv projections if True. init_values: Layer-scale init values (layer-scale enabled if not None). class_token: Use class token. no_embed_class: Don't include position embeddings for class (or reg) tokens. 
reg_tokens: Number of register tokens. fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. drop_rate: Head dropout rate. pos_drop_rate: Position embedding dropout rate. attn_drop_rate: Attention dropout rate. drop_path_rate: Stochastic depth rate. weight_init: Weight initialization scheme. fix_init: Apply weight initialization fix (scaling w/ layer index). embed_layer: Patch embedding layer. norm_layer: Normalization layer. act_layer: MLP activation layer. block_fn: Transformer block layer. """ super().__init__() assert global_pool in ("", "avg", "avgmax", "max", "token", "map") assert class_token or global_pool != "token" assert pos_embed in ("", "none", "learn") use_fc_norm = ( global_pool in ("avg", "avgmax", "max") if fc_norm is None else fc_norm ) norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) act_layer = get_act_layer(act_layer) or nn.GELU self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.head_hidden_size = self.embed_dim = ( embed_dim # for consistency with other models ) self.num_prefix_tokens = 1 if class_token else 0 self.num_prefix_tokens += reg_tokens self.num_reg_tokens = reg_tokens self.has_class_token = class_token self.no_embed_class = ( no_embed_class # don't embed prefix positions (includes reg) ) self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False embed_args = {} if dynamic_img_size: # flatten deferred until after pos embed embed_args.update(dict(strict_img_size=False, output_fmt="NHWC")) self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) dynamic_img_pad=dynamic_img_pad, **embed_args, ) num_patches = self.patch_embed.num_patches reduction = ( self.patch_embed.feat_ratio() if hasattr(self.patch_embed, "feat_ratio") else patch_size ) self.cls_token = ( nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None ) self.reg_token = ( nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None ) embed_len = ( num_patches if no_embed_class else num_patches + self.num_prefix_tokens ) if not pos_embed or pos_embed == "none": self.pos_embed = None else: self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) self.pos_drop = nn.Dropout(p=pos_drop_rate) if patch_drop_rate > 0: self.patch_drop = PatchDropout( patch_drop_rate, num_prefix_tokens=self.num_prefix_tokens, ) else: self.patch_drop = nn.Identity() self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() dpr = [ x.item() for x in torch.linspace(0, drop_path_rate, depth) ] # stochastic depth decay rule self.blocks = nn.Sequential( *[ block_fn( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_norm=qk_norm, init_values=init_values, proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, mlp_layer=mlp_layer, attn_layer=partial( attn_layer, num_prefix_tokens=self.num_prefix_tokens, num_latent_tokens=num_latent_tokens, patch_size=patch_size, ), ) for i in range(depth) ] ) self.feature_info = [ dict(module=f"blocks.{i}", num_chs=embed_dim, reduction=reduction) for i in range(depth) ] self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() # Classifier Head if global_pool == "map": self.attn_pool = AttentionPoolLatent( self.embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, ) else: 
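            # No latent attention pooling; the sequence is pooled by global_pool_nlc() in pool().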
self.attn_pool = None self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() self.head_drop = nn.Dropout(drop_rate) self.head = ( nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() ) if weight_init != "skip": self.init_weights(weight_init) if fix_init: self.fix_init_weight() def fix_init_weight(self): def rescale(param, _layer_id): param.div_(math.sqrt(2.0 * _layer_id)) for layer_id, layer in enumerate(self.blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) rescale(layer.mlp.fc2.weight.data, layer_id + 1) def init_weights(self, mode: str = "") -> None: assert mode in ("jax", "jax_nlhb", "moco", "") head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 if self.pos_embed is not None: trunc_normal_(self.pos_embed, std=0.02) if self.cls_token is not None: nn.init.normal_(self.cls_token, std=1e-6) named_apply(get_init_weights_vit(mode, head_bias), self) def _init_weights(self, m: nn.Module) -> None: # this fn left here for compat with downstream users init_weights_vit_timm(m) @torch.jit.ignore() def load_pretrained(self, checkpoint_path: str, prefix: str = "") -> None: _load_weights(self, checkpoint_path, prefix) @torch.jit.ignore def no_weight_decay(self) -> Set: return {"pos_embed", "cls_token", "dist_token"} @torch.jit.ignore def group_matcher(self, coarse: bool = False) -> Dict: return dict( stem=r"^cls_token|pos_embed|patch_embed", # stem and embed blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], ) @torch.jit.ignore def set_grad_checkpointing(self, enable: bool = True) -> None: self.grad_checkpointing = enable if hasattr(self.patch_embed, "set_grad_checkpointing"): self.patch_embed.set_grad_checkpointing(enable) @torch.jit.ignore def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ("", "avg", "avgmax", "max", "token", "map") if global_pool == "map" and self.attn_pool is None: assert ( False ), "Cannot currently add attention pooling in reset_classifier()." 
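            # NOTE: selecting any non-'map' pooling below discards an existing attention-pool head.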
elif global_pool != "map " and self.attn_pool is not None: self.attn_pool = None # remove attention pooling self.global_pool = global_pool self.head = ( nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() ) def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: if self.pos_embed is None: return x.view(x.shape[0], -1, x.shape[-1]) if self.dynamic_img_size or len(x.shape) == 4: B, H, W, C = x.shape pos_embed = resample_abs_pos_embed( self.pos_embed, (H, W), num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, ) x = x.view(B, -1, C) else: pos_embed = self.pos_embed to_cat = [] if self.cls_token is not None: to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) if self.reg_token is not None: to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) if self.no_embed_class: # deit-3, updated JAX (big vision) # position embedding does not overlap with class token, add then concat x = x + pos_embed if to_cat: x = torch.cat(to_cat + [x], dim=1) else: # original timm, JAX, and deit vit impl # pos_embed has entry for class token, concat then add if to_cat: x = torch.cat(to_cat + [x], dim=1) x = x + pos_embed return self.pos_drop(x) def forward_intermediates( self, x: torch.Tensor, indices: Optional[Union[int, List[int], Tuple[int]]] = None, return_prefix_tokens: bool = False, norm: bool = False, stop_early: bool = False, output_fmt: str = "NCHW", intermediates_only: bool = False, ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: """Forward features that returns intermediates. Args: x: Input image tensor indices: Take last n blocks if int, all if None, select matching indices if sequence return_prefix_tokens: Return both prefix and spatial intermediate tokens norm: Apply norm layer to all intermediates stop_early: Stop iterating over blocks when last desired intermediate hit output_fmt: Shape of intermediate feature outputs intermediates_only: Only return intermediate features Returns: """ assert output_fmt in ( "NCHW", "NLC", ), "Output format must be one of NCHW or NLC." reshape = output_fmt == "NCHW" intermediates = [] take_indices, max_index = feature_take_indices(len(self.blocks), indices) # forward pass B, _, height, width = x.shape x = self.patch_embed(x) x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) if ( torch.jit.is_scripting() or not stop_early ): # can't slice blocks in torchscript blocks = self.blocks else: blocks = self.blocks[: max_index + 1] for i, blk in enumerate(blocks): x = blk(x) if i in take_indices: # normalize intermediates with final norm layer if enabled intermediates.append(self.norm(x) if norm else x) # process intermediates if self.num_prefix_tokens: # split prefix (e.g. 
class, distill) and spatial feature tokens prefix_tokens = [y[:, 0 : self.num_prefix_tokens] for y in intermediates] intermediates = [y[:, self.num_prefix_tokens :] for y in intermediates] if reshape: # reshape to BCHW output format H, W = self.patch_embed.dynamic_feat_size((height, width)) intermediates = [ y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates ] if not torch.jit.is_scripting() and return_prefix_tokens: # return_prefix not support in torchscript due to poor type handling intermediates = list(zip(intermediates, prefix_tokens)) if intermediates_only: return intermediates x = self.norm(x) return x, intermediates def prune_intermediate_layers( self, indices: Union[int, List[int], Tuple[int]] = 1, prune_norm: bool = False, prune_head: bool = True, ): """Prune layers not required for specified intermediates.""" take_indices, max_index = feature_take_indices(len(self.blocks), indices) self.blocks = self.blocks[: max_index + 1] # truncate blocks if prune_norm: self.norm = nn.Identity() if prune_head: self.fc_norm = nn.Identity() self.reset_classifier(0, "") return take_indices def get_intermediate_layers( self, x: torch.Tensor, n: Union[int, List[int], Tuple[int]] = 1, reshape: bool = False, return_prefix_tokens: bool = False, norm: bool = False, ) -> List[torch.Tensor]: """Intermediate layer accessor inspired by DINO / DINOv2 interface. NOTE: This API is for backwards compat, favour using forward_intermediates() directly. """ return self.forward_intermediates( x, n, return_prefix_tokens=return_prefix_tokens, norm=norm, output_fmt="NCHW" if reshape else "NLC", intermediates_only=True, ) def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) x = self.norm(x) return x def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor: if self.attn_pool is not None: x = self.attn_pool(x) return x pool_type = self.global_pool if pool_type is None else pool_type x = global_pool_nlc( x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens ) return x def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: x = self.pool(x) x = self.fc_norm(x) x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) x = self.forward_head(x) return x def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: """ViT weight initialization, original timm impl (for reproducibility)""" if isinstance(module, nn.Linear): trunc_normal_(module.weight, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, "init_weights"): module.init_weights() def init_weights_vit_jax( module: nn.Module, name: str = "", head_bias: float = 0.0 ) -> None: """ViT weight initialization, matching JAX (Flax) impl""" if isinstance(module, nn.Linear): if name.startswith("head"): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) else: nn.init.xavier_uniform_(module.weight) if module.bias is not None: ( nn.init.normal_(module.bias, std=1e-6) if "mlp" in name else nn.init.zeros_(module.bias) ) elif isinstance(module, nn.Conv2d): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, "init_weights"): module.init_weights() def init_weights_vit_moco(module: 
nn.Module, name: str = "") -> None: """ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed""" if isinstance(module, nn.Linear): if "qkv" in name: # treat the weights of Q, K, V separately val = math.sqrt( 6.0 / float(module.weight.shape[0] // 3 + module.weight.shape[1]) ) nn.init.uniform_(module.weight, -val, val) else: nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif hasattr(module, "init_weights"): module.init_weights() def get_init_weights_vit(mode: str = "jax", head_bias: float = 0.0) -> Callable: if "jax" in mode: return partial(init_weights_vit_jax, head_bias=head_bias) elif "moco" in mode: return init_weights_vit_moco else: return init_weights_vit_timm def resize_pos_embed( posemb: torch.Tensor, posemb_new: torch.Tensor, num_prefix_tokens: int = 1, gs_new: Tuple[int, int] = (), interpolation: str = "bicubic", antialias: bool = False, ) -> torch.Tensor: """Rescale the grid of position embeddings when loading from state_dict. *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed """ ntok_new = posemb_new.shape[1] - num_prefix_tokens ntok_old = posemb.shape[1] - num_prefix_tokens gs_old = [int(math.sqrt(ntok_old))] * 2 if not len(gs_new): # backwards compatibility gs_new = [int(math.sqrt(ntok_new))] * 2 return resample_abs_pos_embed( posemb, gs_new, gs_old, num_prefix_tokens=num_prefix_tokens, interpolation=interpolation, antialias=antialias, verbose=True, ) @torch.no_grad() def _load_weights( model: VisionTransformer, checkpoint_path: str, prefix: str = "" ) -> None: """Load weights from .npz checkpoints for official Google Brain Flax implementation""" import numpy as np def _n2p(w, t=True, idx=None): if idx is not None: w = w[idx] if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: w = w.flatten() if t: if w.ndim == 4: w = w.transpose([3, 2, 0, 1]) elif w.ndim == 3: w = w.transpose([2, 0, 1]) elif w.ndim == 2: w = w.transpose([1, 0]) return torch.from_numpy(w) w = np.load(checkpoint_path) interpolation = "bilinear" antialias = False big_vision = False if not prefix: if "opt/target/embedding/kernel" in w: prefix = "opt/target/" elif "params/embedding/kernel" in w: prefix = "params/" big_vision = True elif "params/img/embedding/kernel" in w: prefix = "params/img/" big_vision = True if hasattr(model.patch_embed, "backbone"): # hybrid backbone = model.patch_embed.backbone stem_only = not hasattr(backbone, "stem") stem = backbone if stem_only else backbone.stem stem.conv.weight.copy_( adapt_input_conv( stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"]) ) ) stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"])) stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"])) if not stem_only: for i, stage in enumerate(backbone.stages): for j, block in enumerate(stage.blocks): bp = f"{prefix}block{i + 1}/unit{j + 1}/" for r in range(3): getattr(block, f"conv{r + 1}").weight.copy_( _n2p(w[f"{bp}conv{r + 1}/kernel"]) ) getattr(block, f"norm{r + 1}").weight.copy_( _n2p(w[f"{bp}gn{r + 1}/scale"]) ) getattr(block, f"norm{r + 1}").bias.copy_( _n2p(w[f"{bp}gn{r + 1}/bias"]) ) if block.downsample is not None: block.downsample.conv.weight.copy_( _n2p(w[f"{bp}conv_proj/kernel"]) ) block.downsample.norm.weight.copy_( _n2p(w[f"{bp}gn_proj/scale"]) ) block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"])) embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"]) else: embed_conv_w = adapt_input_conv( model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"]) ) if 
embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]: embed_conv_w = resample_patch_embed( embed_conv_w, model.patch_embed.proj.weight.shape[-2:], interpolation=interpolation, antialias=antialias, verbose=True, ) model.patch_embed.proj.weight.copy_(embed_conv_w) model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"])) if model.cls_token is not None: model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False)) if big_vision: pos_embed_w = _n2p(w[f"{prefix}pos_embedding"], t=False) else: pos_embed_w = _n2p( w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False ) if pos_embed_w.shape != model.pos_embed.shape: old_shape = pos_embed_w.shape num_prefix_tokens = ( 0 if getattr(model, "no_embed_class", False) else getattr(model, "num_prefix_tokens", 1) ) pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights pos_embed_w, new_size=model.patch_embed.grid_size, num_prefix_tokens=num_prefix_tokens, interpolation=interpolation, antialias=antialias, verbose=True, ) model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"])) model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"])) if ( isinstance(model.head, nn.Linear) and f"{prefix}head/bias" in w and model.head.bias.shape[0] == w[f"{prefix}head/bias"].shape[-1] ): model.head.weight.copy_(_n2p(w[f"{prefix}head/kernel"])) model.head.bias.copy_(_n2p(w[f"{prefix}head/bias"])) # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) if model.attn_pool is not None: block_prefix = f"{prefix}MAPHead_0/" mha_prefix = block_prefix + f"MultiHeadDotProductAttention_0/" model.attn_pool.latent.copy_(_n2p(w[f"{block_prefix}probe"], t=False)) model.attn_pool.kv.weight.copy_( torch.cat( [ _n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T for n in ("key", "value") ] ) ) model.attn_pool.kv.bias.copy_( torch.cat( [ _n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1) for n in ("key", "value") ] ) ) model.attn_pool.q.weight.copy_( _n2p(w[f"{mha_prefix}query/kernel"], t=False).flatten(1).T ) model.attn_pool.q.bias.copy_( _n2p(w[f"{mha_prefix}query/bias"], t=False).reshape(-1) ) model.attn_pool.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1)) model.attn_pool.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"])) model.attn_pool.norm.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"])) model.attn_pool.norm.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"])) for r in range(2): getattr(model.attn_pool.mlp, f"fc{r + 1}").weight.copy_( _n2p(w[f"{block_prefix}MlpBlock_0/Dense_{r}/kernel"]) ) getattr(model.attn_pool.mlp, f"fc{r + 1}").bias.copy_( _n2p(w[f"{block_prefix}MlpBlock_0/Dense_{r}/bias"]) ) mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2) for i, block in enumerate(model.blocks.children()): if f"{prefix}Transformer/encoderblock/LayerNorm_0/scale" in w: block_prefix = f"{prefix}Transformer/encoderblock/" idx = i else: block_prefix = f"{prefix}Transformer/encoderblock_{i}/" idx = None mha_prefix = block_prefix + f"MultiHeadDotProductAttention_{mha_sub}/" block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"], idx=idx)) block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"], idx=idx)) 
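        # The Flax checkpoint stores query/key/value projections separately; their flattened,
        # transposed kernels and biases are concatenated below to match the fused qkv Linear.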
block.attn.qkv.weight.copy_( torch.cat( [ _n2p(w[f"{mha_prefix}{n}/kernel"], t=False, idx=idx).flatten(1).T for n in ("query", "key", "value") ] ) ) block.attn.qkv.bias.copy_( torch.cat( [ _n2p(w[f"{mha_prefix}{n}/bias"], t=False, idx=idx).reshape(-1) for n in ("query", "key", "value") ] ) ) block.attn.proj.weight.copy_( _n2p(w[f"{mha_prefix}out/kernel"], idx=idx).flatten(1) ) block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"], idx=idx)) block.norm2.weight.copy_( _n2p(w[f"{block_prefix}LayerNorm_{ln1_sub}/scale"], idx=idx) ) block.norm2.bias.copy_( _n2p(w[f"{block_prefix}LayerNorm_{ln1_sub}/bias"], idx=idx) ) for r in range(2): getattr(block.mlp, f"fc{r + 1}").weight.copy_( _n2p(w[f"{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel"], idx=idx) ) getattr(block.mlp, f"fc{r + 1}").bias.copy_( _n2p(w[f"{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias"], idx=idx) ) def _convert_openai_clip( state_dict: Dict[str, torch.Tensor], model: VisionTransformer, prefix: str = "visual.", ) -> Dict[str, torch.Tensor]: out_dict = {} swaps = [ ("conv1", "patch_embed.proj"), ("positional_embedding", "pos_embed"), ("transformer.resblocks.", "blocks."), ("ln_pre", "norm_pre"), ("ln_post", "norm"), ("ln_", "norm"), ("in_proj_", "qkv."), ("out_proj", "proj"), ("mlp.c_fc", "mlp.fc1"), ("mlp.c_proj", "mlp.fc2"), ] for k, v in state_dict.items(): if not k.startswith(prefix): continue k = k.replace(prefix, "") for sp in swaps: k = k.replace(sp[0], sp[1]) if k == "proj": k = "head.weight" v = v.transpose(0, 1) out_dict["head.bias"] = torch.zeros(v.shape[0]) elif k == "class_embedding": k = "cls_token" v = v.unsqueeze(0).unsqueeze(1) elif k == "pos_embed": v = v.unsqueeze(0) out_dict[k] = v return out_dict def _convert_dinov2( state_dict: Dict[str, torch.Tensor], model: VisionTransformer, ) -> Dict[str, torch.Tensor]: import re out_dict = {} state_dict.pop("mask_token", None) if "register_tokens" in state_dict: # convert dinov2 w/ registers to no_embed_class timm model (neither cls or reg tokens overlap pos embed) out_dict["reg_token"] = state_dict.pop("register_tokens") out_dict["cls_token"] = ( state_dict.pop("cls_token") + state_dict["pos_embed"][:, 0] ) out_dict["pos_embed"] = state_dict.pop("pos_embed")[:, 1:] for k, v in state_dict.items(): if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k): out_dict[k.replace("w12", "fc1")] = v continue elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k): out_dict[k.replace("w3", "fc2")] = v continue out_dict[k] = v return out_dict def checkpoint_filter_fn( state_dict: Dict[str, torch.Tensor], model: VisionTransformer, adapt_layer_scale: bool = False, interpolation: str = "bicubic", antialias: bool = True, ) -> Dict[str, torch.Tensor]: """convert patch embedding weight from manual patchify + linear proj to conv""" import re out_dict = {} state_dict = state_dict.get("model", state_dict) state_dict = state_dict.get("state_dict", state_dict) prefix = "" if "visual.class_embedding" in state_dict: state_dict = _convert_openai_clip(state_dict, model) elif "module.visual.class_embedding" in state_dict: state_dict = _convert_openai_clip(state_dict, model, prefix="module.visual.") elif "mask_token" in state_dict: state_dict = _convert_dinov2(state_dict, model) elif "encoder" in state_dict: # IJEPA, vit in an 'encoder' submodule state_dict = state_dict["encoder"] prefix = "module." elif ( "visual.trunk.pos_embed" in state_dict or "visual.trunk.blocks.0.norm1.weight" in state_dict ): # OpenCLIP model with timm vision encoder prefix = "visual.trunk." 
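        # OpenCLIP keeps the (bias-free) projection outside the trunk; a zero bias is created
        # below so it remains compatible with this model's nn.Linear classifier head.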
if "visual.head.proj.weight" in state_dict and isinstance( model.head, nn.Linear ): # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj) out_dict["head.weight"] = state_dict["visual.head.proj.weight"] out_dict["head.bias"] = torch.zeros( state_dict["visual.head.proj.weight"].shape[0] ) if prefix: # filter on & remove prefix string from keys state_dict = { k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix) } for k, v in state_dict.items(): if "patch_embed.proj.weight" in k: O, I, H, W = model.patch_embed.proj.weight.shape if len(v.shape) < 4: # For old models that I trained prior to conv based patchification O, I, H, W = model.patch_embed.proj.weight.shape v = v.reshape(O, -1, H, W) if v.shape[-1] != W or v.shape[-2] != H: v = resample_patch_embed( v, (H, W), interpolation=interpolation, antialias=antialias, verbose=True, ) elif k == "pos_embed" and v.shape[1] != model.pos_embed.shape[1]: # To resize pos embedding when using model at different size from pretrained weights num_prefix_tokens = ( 0 if getattr(model, "no_embed_class", False) else getattr(model, "num_prefix_tokens", 1) ) v = resample_abs_pos_embed( v, new_size=model.patch_embed.grid_size, num_prefix_tokens=num_prefix_tokens, interpolation=interpolation, antialias=antialias, verbose=True, ) elif adapt_layer_scale and "gamma_" in k: # remap layer-scale gamma into sub-module (deit3 models) k = re.sub(r"gamma_([0-9])", r"ls\1.gamma", k) elif "pre_logits" in k: # NOTE representation layer removed as not used in latest 21k/1k pretrained weights continue out_dict[k] = v return out_dict def _cfg(url: str = "", **kwargs) -> Dict[str, Any]: return { "url": url, "num_classes": 1000, "input_size": (3, 224, 224), "pool_size": None, "crop_pct": 0.9, "interpolation": "bicubic", "fixed_input_size": True, "mean": IMAGENET_INCEPTION_MEAN, "std": IMAGENET_INCEPTION_STD, "first_conv": "patch_embed.proj", "classifier": "head", **kwargs, } default_cfgs = { # re-finetuned augreg 21k FT on in1k weights "vit_base_patch16_224.augreg2_in21k_ft_in1k": _cfg(hf_hub_id="timm/"), "vit_base_patch16_384.augreg2_in21k_ft_in1k": _cfg(), "vit_base_patch8_224.augreg2_in21k_ft_in1k": _cfg(hf_hub_id="timm/"), # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k "vit_tiny_patch16_224.augreg_in21k_ft_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz", hf_hub_id="timm/", custom_load=True, ), "vit_tiny_patch16_384.augreg_in21k_ft_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz", hf_hub_id="timm/", custom_load=True, input_size=(3, 384, 384), crop_pct=1.0, ), "vit_small_patch32_224.augreg_in21k_ft_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz", hf_hub_id="timm/", custom_load=True, ), "vit_small_patch32_384.augreg_in21k_ft_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz", hf_hub_id="timm/", custom_load=True, input_size=(3, 384, 384), crop_pct=1.0, ), "vit_small_patch16_224.augreg_in21k_ft_in1k": _cfg( 
url="https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz", hf_hub_id="timm/", custom_load=True, ), "vit_small_patch16_384.augreg_in21k_ft_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz", hf_hub_id="timm/", custom_load=True, input_size=(3, 384, 384), crop_pct=1.0, ), "vit_base_patch32_224.augreg_in21k_ft_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz", hf_hub_id="timm/", custom_load=True, ), "vit_base_patch32_384.augreg_in21k_ft_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz", hf_hub_id="timm/", custom_load=True, input_size=(3, 384, 384), crop_pct=1.0, ), "vit_base_patch16_224.augreg_in21k_ft_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz", hf_hub_id="timm/", custom_load=True, ), "vit_base_patch16_384.augreg_in21k_ft_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz", hf_hub_id="timm/", custom_load=True, input_size=(3, 384, 384), crop_pct=1.0, ), "vit_base_patch8_224.augreg_in21k_ft_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz", hf_hub_id="timm/", custom_load=True, ), "vit_large_patch16_224.augreg_in21k_ft_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz", hf_hub_id="timm/", custom_load=True, ), "vit_large_patch16_384.augreg_in21k_ft_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz", hf_hub_id="timm/", custom_load=True, input_size=(3, 384, 384), crop_pct=1.0, ), # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k "vit_base_patch16_224.orig_in21k_ft_in1k": _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth", hf_hub_id="timm/", ), "vit_base_patch16_384.orig_in21k_ft_in1k": _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth", hf_hub_id="timm/", input_size=(3, 384, 384), crop_pct=1.0, ), "vit_large_patch32_384.orig_in21k_ft_in1k": _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth", hf_hub_id="timm/", input_size=(3, 384, 384), crop_pct=1.0, ), # How to train your ViT (augreg) weights trained on in1k only "vit_small_patch16_224.augreg_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz", hf_hub_id="timm/", custom_load=True, ), "vit_small_patch16_384.augreg_in1k": _cfg( 
url="https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz", hf_hub_id="timm/", custom_load=True, input_size=(3, 384, 384), crop_pct=1.0, ), "vit_base_patch32_224.augreg_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz", hf_hub_id="timm/", custom_load=True, ), "vit_base_patch32_384.augreg_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz", hf_hub_id="timm/", custom_load=True, input_size=(3, 384, 384), crop_pct=1.0, ), "vit_base_patch16_224.augreg_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz", hf_hub_id="timm/", custom_load=True, ), "vit_base_patch16_384.augreg_in1k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz", hf_hub_id="timm/", custom_load=True, input_size=(3, 384, 384), crop_pct=1.0, ), "vit_large_patch14_224.untrained": _cfg(url=""), "vit_huge_patch14_224.untrained": _cfg(url=""), "vit_giant_patch14_224.untrained": _cfg(url=""), "vit_gigantic_patch14_224.untrained": _cfg(url=""), # patch models, imagenet21k (weights from official Google JAX impl), classifier not valid "vit_base_patch32_224.orig_in21k": _cfg( # url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth', hf_hub_id="timm/", num_classes=0, ), "vit_base_patch16_224.orig_in21k": _cfg( # url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth', hf_hub_id="timm/", num_classes=0, ), "vit_large_patch32_224.orig_in21k": _cfg( # url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', hf_hub_id="timm/", num_classes=0, ), "vit_large_patch16_224.orig_in21k": _cfg( # url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth', hf_hub_id="timm/", num_classes=0, ), "vit_huge_patch14_224.orig_in21k": _cfg(hf_hub_id="timm/", num_classes=0), # How to train your ViT (augreg) weights, pretrained on in21k "vit_tiny_patch16_224.augreg_in21k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz", hf_hub_id="timm/", custom_load=True, num_classes=21843, ), "vit_small_patch32_224.augreg_in21k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz", hf_hub_id="timm/", custom_load=True, num_classes=21843, ), "vit_small_patch16_224.augreg_in21k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz", hf_hub_id="timm/", custom_load=True, num_classes=21843, ), "vit_base_patch32_224.augreg_in21k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz", hf_hub_id="timm/", custom_load=True, num_classes=21843, ), "vit_base_patch16_224.augreg_in21k": _cfg( 
url="https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz", hf_hub_id="timm/", custom_load=True, num_classes=21843, ), "vit_base_patch8_224.augreg_in21k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz", hf_hub_id="timm/", custom_load=True, num_classes=21843, ), "vit_large_patch16_224.augreg_in21k": _cfg( url="https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz", hf_hub_id="timm/", custom_load=True, num_classes=21843, ), # SAM trained models (https://arxiv.org/abs/2106.01548) "vit_base_patch32_224.sam_in1k": _cfg( url="https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz", custom_load=True, hf_hub_id="timm/", ), "vit_base_patch16_224.sam_in1k": _cfg( url="https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz", custom_load=True, hf_hub_id="timm/", ), # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) "vit_small_patch16_224.dino": _cfg( url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", hf_hub_id="timm/", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, ), "vit_small_patch8_224.dino": _cfg( url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth", hf_hub_id="timm/", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, ), "vit_base_patch16_224.dino": _cfg( url="https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth", hf_hub_id="timm/", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, ), "vit_base_patch8_224.dino": _cfg( url="https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth", hf_hub_id="timm/", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, ), # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only) "vit_small_patch14_dinov2.lvd142m": _cfg( url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth", hf_hub_id="timm/", license="apache-2.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518), crop_pct=1.0, ), "vit_base_patch14_dinov2.lvd142m": _cfg( url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth", hf_hub_id="timm/", license="apache-2.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518), crop_pct=1.0, ), "vit_large_patch14_dinov2.lvd142m": _cfg( url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", hf_hub_id="timm/", license="apache-2.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518), crop_pct=1.0, ), "vit_giant_patch14_dinov2.lvd142m": _cfg( url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth", hf_hub_id="timm/", license="apache-2.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518), crop_pct=1.0, ), # DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only) "vit_small_patch14_reg4_dinov2.lvd142m": _cfg( url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth", hf_hub_id="timm/", license="apache-2.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 
518, 518), crop_pct=1.0, ), "vit_base_patch14_reg4_dinov2.lvd142m": _cfg( url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth", hf_hub_id="timm/", license="apache-2.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518), crop_pct=1.0, ), "vit_large_patch14_reg4_dinov2.lvd142m": _cfg( url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth", hf_hub_id="timm/", license="apache-2.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518), crop_pct=1.0, ), "vit_giant_patch14_reg4_dinov2.lvd142m": _cfg( url="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth", hf_hub_id="timm/", license="apache-2.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518), crop_pct=1.0, ), # ViT ImageNet-21K-P pretraining by MILL "vit_base_patch16_224_miil.in21k": _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth", hf_hub_id="timm/", mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), crop_pct=0.875, interpolation="bilinear", num_classes=11221, ), "vit_base_patch16_224_miil.in21k_ft_in1k": _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth", hf_hub_id="timm/", mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), crop_pct=0.875, interpolation="bilinear", ), # Custom timm variants "vit_base_patch16_rpn_224.sw_in1k": _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth", hf_hub_id="timm/", ), "vit_medium_patch16_gap_240.sw_in12k": _cfg( hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821 ), "vit_medium_patch16_gap_256.sw_in12k_ft_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_medium_patch16_gap_384.sw_in12k_ft_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 384, 384), crop_pct=0.95, crop_mode="squash" ), "vit_base_patch16_gap_224": _cfg(), # CLIP pretrained image tower and related fine-tuned weights "vit_base_patch32_clip_224.laion2b_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD ), "vit_base_patch32_clip_384.laion2b_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), ), "vit_base_patch32_clip_448.laion2b_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448), ), "vit_base_patch16_clip_224.laion2b_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95 ), "vit_base_patch16_clip_384.laion2b_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), crop_mode="squash", ), "vit_large_patch14_clip_224.laion2b_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, ), "vit_large_patch14_clip_336.laion2b_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode="squash", ), "vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0 ), "vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k": _cfg( hf_hub_id="timm/", 
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode="squash", ), "vit_base_patch32_clip_224.openai_ft_in12k_in1k": _cfg( # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', # FIXME weight exists, need to push mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, ), "vit_base_patch32_clip_384.openai_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95, input_size=(3, 384, 384), crop_mode="squash", ), "vit_base_patch16_clip_224.openai_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95 ), "vit_base_patch16_clip_384.openai_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95, input_size=(3, 384, 384), crop_mode="squash", ), "vit_large_patch14_clip_224.openai_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0 ), "vit_large_patch14_clip_336.openai_ft_in12k_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode="squash", ), "vit_base_patch32_clip_224.laion2b_ft_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD ), "vit_base_patch16_clip_224.laion2b_ft_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0 ), "vit_base_patch16_clip_384.laion2b_ft_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), crop_mode="squash", ), "vit_large_patch14_clip_224.laion2b_ft_in1k": _cfg( hf_hub_id="timm/", mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, ), "vit_large_patch14_clip_336.laion2b_ft_in1k": _cfg( hf_hub_id="timm/", mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode="squash", ), "vit_huge_patch14_clip_224.laion2b_ft_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0 ), "vit_huge_patch14_clip_336.laion2b_ft_in1k": _cfg( hf_hub_id="", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode="squash", ), "vit_base_patch32_clip_224.openai_ft_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD ), "vit_base_patch16_clip_224.openai_ft_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD ), "vit_base_patch16_clip_384.openai_ft_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384), crop_mode="squash", ), "vit_large_patch14_clip_224.openai_ft_in1k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0 ), "vit_base_patch32_clip_224.laion2b_ft_in12k": _cfg( # hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k', # FIXME weight exists, need to push mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821, ), "vit_base_patch16_clip_224.laion2b_ft_in12k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821 ), "vit_large_patch14_clip_224.laion2b_ft_in12k": _cfg( hf_hub_id="timm/", mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821, ), "vit_huge_patch14_clip_224.laion2b_ft_in12k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821, ), "vit_base_patch32_clip_224.openai_ft_in12k": _cfg( # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', # FIXME weight exists, need to push 
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821, ), "vit_base_patch16_clip_224.openai_ft_in12k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821 ), "vit_large_patch14_clip_224.openai_ft_in12k": _cfg( hf_hub_id="timm/", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821, ), "vit_base_patch32_clip_224.laion2b": _cfg( hf_hub_id="laion/CLIP-ViT-B-32-laion2B-s34B-b79K", hf_hub_filename="open_clip_pytorch_model.bin", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512, ), "vit_base_patch16_clip_224.laion2b": _cfg( hf_hub_id="laion/CLIP-ViT-B-16-laion2B-s34B-b88K", hf_hub_filename="open_clip_pytorch_model.bin", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512, ), "vit_large_patch14_clip_224.laion2b": _cfg( hf_hub_id="laion/CLIP-ViT-L-14-laion2B-s32B-b82K", hf_hub_filename="open_clip_pytorch_model.bin", mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768, ), "vit_huge_patch14_clip_224.laion2b": _cfg( hf_hub_id="laion/CLIP-ViT-H-14-laion2B-s32B-b79K", hf_hub_filename="open_clip_pytorch_model.bin", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024, ), "vit_giant_patch14_clip_224.laion2b": _cfg( hf_hub_id="laion/CLIP-ViT-g-14-laion2B-s12B-b42K", hf_hub_filename="open_clip_pytorch_model.bin", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024, ), "vit_gigantic_patch14_clip_224.laion2b": _cfg( hf_hub_id="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", hf_hub_filename="open_clip_pytorch_model.bin", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280, ), "vit_base_patch32_clip_224.datacompxl": _cfg( hf_hub_id="laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K", hf_hub_filename="open_clip_pytorch_model.bin", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512, ), "vit_base_patch32_clip_256.datacompxl": _cfg( hf_hub_id="laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K", hf_hub_filename="open_clip_pytorch_model.bin", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 256, 256), num_classes=512, ), "vit_base_patch16_clip_224.datacompxl": _cfg( hf_hub_id="laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K", hf_hub_filename="open_clip_pytorch_model.bin", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512, ), "vit_large_patch14_clip_224.datacompxl": _cfg( hf_hub_id="laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K", hf_hub_filename="open_clip_pytorch_model.bin", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768, ), "vit_base_patch16_clip_224.dfn2b": _cfg( hf_hub_id="apple/DFN2B-CLIP-ViT-B-16", hf_hub_filename="open_clip_pytorch_model.bin", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512, ), "vit_large_patch14_clip_224.dfn2b": _cfg( hf_hub_id="apple/DFN2B-CLIP-ViT-L-14", hf_hub_filename="open_clip_pytorch_model.bin", notes=("natively QuickGELU, use quickgelu model variant for original results",), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768, ), "vit_huge_patch14_clip_224.dfn5b": _cfg( hf_hub_id="apple/DFN5B-CLIP-ViT-H-14", hf_hub_filename="open_clip_pytorch_model.bin", notes=("natively QuickGELU, use quickgelu model variant for original results",), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024, ), "vit_huge_patch14_clip_378.dfn5b": _cfg( hf_hub_id="apple/DFN5B-CLIP-ViT-H-14-378", hf_hub_filename="open_clip_pytorch_model.bin", mean=OPENAI_CLIP_MEAN, 
std=OPENAI_CLIP_STD, notes=("natively QuickGELU, use quickgelu model variant for original results",), crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024, ), "vit_base_patch32_clip_224.metaclip_2pt5b": _cfg( hf_hub_id="facebook/metaclip-b32-fullcc2.5b", hf_hub_filename="metaclip_b32_fullcc2.5b.bin", license="cc-by-nc-4.0", notes=("natively QuickGELU, use quickgelu model variant for original results",), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512, ), "vit_base_patch16_clip_224.metaclip_2pt5b": _cfg( hf_hub_id="facebook/metaclip-b16-fullcc2.5b", hf_hub_filename="metaclip_b16_fullcc2.5b.bin", license="cc-by-nc-4.0", notes=("natively QuickGELU, use quickgelu model variant for original results",), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512, ), "vit_large_patch14_clip_224.metaclip_2pt5b": _cfg( hf_hub_id="facebook/metaclip-l14-fullcc2.5b", hf_hub_filename="metaclip_l14_fullcc2.5b.bin", license="cc-by-nc-4.0", notes=("natively QuickGELU, use quickgelu model variant for original results",), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768, ), "vit_huge_patch14_clip_224.metaclip_2pt5b": _cfg( hf_hub_id="facebook/metaclip-h14-fullcc2.5b", hf_hub_filename="metaclip_h14_fullcc2.5b.bin", license="cc-by-nc-4.0", notes=("natively QuickGELU, use quickgelu model variant for original results",), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024, ), "vit_base_patch32_clip_224.openai": _cfg( hf_hub_id="timm/vit_base_patch32_clip_224.openai", notes=("natively QuickGELU, use quickgelu model variant for original results",), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512, ), "vit_base_patch16_clip_224.openai": _cfg( hf_hub_id="timm/vit_base_patch16_clip_224.openai", notes=("natively QuickGELU, use quickgelu model variant for original results",), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512, ), "vit_large_patch14_clip_224.openai": _cfg( hf_hub_id="timm/vit_large_patch14_clip_224.openai", notes=("natively QuickGELU, use quickgelu model variant for original results",), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768, ), "vit_large_patch14_clip_336.openai": _cfg( hf_hub_id="timm/vit_large_patch14_clip_336.openai", hf_hub_filename="open_clip_pytorch_model.bin", notes=("natively QuickGELU, use quickgelu model variant for original results",), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), num_classes=768, ), # experimental (may be removed) "vit_base_patch32_plus_256.untrained": _cfg( url="", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_base_patch16_plus_240.untrained": _cfg( url="", input_size=(3, 240, 240), crop_pct=0.95 ), "vit_small_patch16_36x1_224.untrained": _cfg(url=""), "vit_small_patch16_18x2_224.untrained": _cfg(url=""), "vit_base_patch16_18x2_224.untrained": _cfg(url=""), # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip "eva_large_patch14_196.in22k_ft_in22k_in1k": _cfg( # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', hf_hub_id="timm/", license="mit", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 196, 196), crop_pct=1.0, ), "eva_large_patch14_336.in22k_ft_in22k_in1k": _cfg( # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', hf_hub_id="timm/", license="mit", 
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 336, 336), crop_pct=1.0, crop_mode="squash", ), "eva_large_patch14_196.in22k_ft_in1k": _cfg( # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', hf_hub_id="timm/", license="mit", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 196, 196), crop_pct=1.0, ), "eva_large_patch14_336.in22k_ft_in1k": _cfg( # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', hf_hub_id="timm/", license="mit", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 336, 336), crop_pct=1.0, crop_mode="squash", ), "flexivit_small.1200ep_in1k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, ), "flexivit_small.600ep_in1k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, ), "flexivit_small.300ep_in1k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, ), "flexivit_base.1200ep_in1k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, ), "flexivit_base.600ep_in1k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, ), "flexivit_base.300ep_in1k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, ), "flexivit_base.1000ep_in21k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843, ), "flexivit_base.300ep_in21k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843, ), "flexivit_large.1200ep_in1k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, ), "flexivit_large.600ep_in1k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, ), "flexivit_large.300ep_in1k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, ), "flexivit_base.patch16_in21k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843, ), "flexivit_base.patch30_in21k": _cfg( url="https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz", custom_load=True, hf_hub_id="timm/", input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843, ), "vit_base_patch16_xp_224.untrained": _cfg(url=""), "vit_large_patch14_xp_224.untrained": _cfg(url=""), "vit_huge_patch14_xp_224.untrained": _cfg(url=""), "vit_base_patch16_224.mae": _cfg( url="https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth", hf_hub_id="timm/", 
license="cc-by-nc-4.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, ), "vit_large_patch16_224.mae": _cfg( url="https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth", hf_hub_id="timm/", license="cc-by-nc-4.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, ), "vit_huge_patch14_224.mae": _cfg( url="https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth", hf_hub_id="timm/", license="cc-by-nc-4.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, ), "vit_huge_patch14_gap_224.in1k_ijepa": _cfg( url="https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar", # hf_hub_id='timm/', license="cc-by-nc-4.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, ), "vit_huge_patch14_gap_224.in22k_ijepa": _cfg( url="https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar", # hf_hub_id='timm/', license="cc-by-nc-4.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, ), "vit_huge_patch16_gap_448.in1k_ijepa": _cfg( url="https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar", # hf_hub_id='timm/', license="cc-by-nc-4.0", input_size=(3, 448, 448), crop_pct=1.0, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, ), "vit_giant_patch16_gap_224.in22k_ijepa": _cfg( url="https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar", # hf_hub_id='timm/', license="cc-by-nc-4.0", mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, ), "vit_base_patch16_siglip_224.webli": _cfg( hf_hub_id="timm/ViT-B-16-SigLIP", hf_hub_filename="open_clip_pytorch_model.bin", num_classes=0, ), "vit_base_patch16_siglip_256.webli": _cfg( hf_hub_id="timm/ViT-B-16-SigLIP-256", hf_hub_filename="open_clip_pytorch_model.bin", input_size=(3, 256, 256), num_classes=0, ), "vit_base_patch16_siglip_384.webli": _cfg( hf_hub_id="timm/ViT-B-16-SigLIP-384", hf_hub_filename="open_clip_pytorch_model.bin", input_size=(3, 384, 384), num_classes=0, ), "vit_base_patch16_siglip_512.webli": _cfg( hf_hub_id="timm/ViT-B-16-SigLIP-512", hf_hub_filename="open_clip_pytorch_model.bin", input_size=(3, 512, 512), num_classes=0, ), "vit_large_patch16_siglip_256.webli": _cfg( hf_hub_id="timm/ViT-L-16-SigLIP-256", hf_hub_filename="open_clip_pytorch_model.bin", input_size=(3, 256, 256), num_classes=0, ), "vit_large_patch16_siglip_384.webli": _cfg( hf_hub_id="timm/ViT-L-16-SigLIP-384", hf_hub_filename="open_clip_pytorch_model.bin", input_size=(3, 384, 384), num_classes=0, ), "vit_so400m_patch14_siglip_224.webli": _cfg( hf_hub_id="timm/ViT-SO400M-14-SigLIP", hf_hub_filename="open_clip_pytorch_model.bin", num_classes=0, ), "vit_so400m_patch14_siglip_384.webli": _cfg( hf_hub_id="timm/ViT-SO400M-14-SigLIP-384", hf_hub_filename="open_clip_pytorch_model.bin", input_size=(3, 384, 384), num_classes=0, ), "vit_base_patch16_siglip_gap_224.webli": _cfg( hf_hub_id="timm/ViT-B-16-SigLIP", hf_hub_filename="open_clip_pytorch_model.bin", num_classes=0, ), "vit_base_patch16_siglip_gap_256.webli": _cfg( hf_hub_id="timm/ViT-B-16-SigLIP-256", hf_hub_filename="open_clip_pytorch_model.bin", input_size=(3, 256, 256), num_classes=0, ), "vit_base_patch16_siglip_gap_384.webli": _cfg( hf_hub_id="timm/ViT-B-16-SigLIP-384", hf_hub_filename="open_clip_pytorch_model.bin", input_size=(3, 384, 384), num_classes=0, ), "vit_base_patch16_siglip_gap_512.webli": _cfg( hf_hub_id="timm/ViT-B-16-SigLIP-512", hf_hub_filename="open_clip_pytorch_model.bin", input_size=(3, 512, 512), num_classes=0, ), 
"vit_large_patch16_siglip_gap_256.webli": _cfg( hf_hub_id="timm/ViT-L-16-SigLIP-256", hf_hub_filename="open_clip_pytorch_model.bin", input_size=(3, 256, 256), num_classes=0, ), "vit_large_patch16_siglip_gap_384.webli": _cfg( hf_hub_id="timm/ViT-L-16-SigLIP-384", hf_hub_filename="open_clip_pytorch_model.bin", input_size=(3, 384, 384), num_classes=0, ), "vit_so400m_patch14_siglip_gap_224.webli": _cfg( hf_hub_id="timm/ViT-SO400M-14-SigLIP", hf_hub_filename="open_clip_pytorch_model.bin", num_classes=0, ), "vit_so400m_patch14_siglip_gap_224.pali_mix": _cfg( hf_hub_id="google/paligemma-3b-mix-224-jax", hf_hub_filename="paligemma-3b-mix-224.npz", custom_load="hf", num_classes=0, ), "vit_so400m_patch14_siglip_gap_224.pali_pt": _cfg( hf_hub_id="google/paligemma-3b-pt-224-jax", hf_hub_filename="paligemma-3b-pt-224.npz", custom_load="hf", num_classes=0, ), "vit_so400m_patch14_siglip_gap_384.webli": _cfg( hf_hub_id="timm/ViT-SO400M-14-SigLIP-384", hf_hub_filename="open_clip_pytorch_model.bin", input_size=(3, 384, 384), crop_pct=1.0, num_classes=0, ), "vit_so400m_patch14_siglip_gap_448.pali_mix": _cfg( hf_hub_id="google/paligemma-3b-mix-448-jax", hf_hub_filename="paligemma-3b-mix-448.npz", custom_load="hf", input_size=(3, 448, 448), crop_pct=1.0, num_classes=0, ), "vit_so400m_patch14_siglip_gap_448.pali_pt": _cfg( hf_hub_id="google/paligemma-3b-pt-448-jax", hf_hub_filename="paligemma-3b-pt-448.npz", custom_load="hf", input_size=(3, 448, 448), crop_pct=1.0, num_classes=0, ), "vit_so400m_patch14_siglip_gap_896.pali_pt": _cfg( hf_hub_id="google/paligemma-3b-pt-896-jax", hf_hub_filename="paligemma-3b-pt-896.npz", custom_load="hf", input_size=(3, 896, 896), crop_pct=1.0, num_classes=0, ), "vit_xsmall_patch16_clip_224.tinyclip_yfcc15m": _cfg( hf_hub_id="timm/", hf_hub_filename="open_clip_pytorch_model.bin", license="mit", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512, ), "vit_medium_patch32_clip_224.tinyclip_laion400m": _cfg( hf_hub_id="timm/", hf_hub_filename="open_clip_pytorch_model.bin", license="mit", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512, ), "vit_medium_patch16_clip_224.tinyclip_yfcc15m": _cfg( hf_hub_id="timm/", hf_hub_filename="open_clip_pytorch_model.bin", license="mit", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512, ), "vit_betwixt_patch32_clip_224.tinyclip_laion400m": _cfg( hf_hub_id="timm/", hf_hub_filename="open_clip_pytorch_model.bin", license="mit", mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512, ), "vit_wee_patch16_reg1_gap_256.sbb_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_pwee_patch16_reg1_gap_256.sbb_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_little_patch16_reg1_gap_256.sbb_in12k": _cfg( hf_hub_id="timm/", num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95 ), "vit_little_patch16_reg4_gap_256.sbb_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_medium_patch16_reg1_gap_256.sbb_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_medium_patch16_reg4_gap_256.sbb_in12k_ft_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_medium_patch16_reg4_gap_256.sbb_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_medium_patch16_reg4_gap_256.sbb_in12k": _cfg( hf_hub_id="timm/", num_classes=11821, input_size=(3, 256, 256), 
crop_pct=0.95 ), "vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_mediumd_patch16_reg4_gap_256.sbb_in12k": _cfg( hf_hub_id="timm/", num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95 ), "vit_betwixt_patch16_reg1_gap_256.sbb_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_betwixt_patch16_reg4_gap_256.sbb_in1k": _cfg( hf_hub_id="timm/", input_size=(3, 256, 256), crop_pct=0.95 ), "vit_betwixt_patch16_reg4_gap_256.sbb_in12k": _cfg( hf_hub_id="timm/", num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95 ), "vit_base_patch16_reg4_gap_256": _cfg(input_size=(3, 256, 256)), "vit_so150m_patch16_reg4_gap_256": _cfg(input_size=(3, 256, 256)), "vit_so150m_patch16_reg4_map_256": _cfg(input_size=(3, 256, 256)), } _quick_gelu_cfgs = [ "vit_large_patch14_clip_224.dfn2b", "vit_huge_patch14_clip_224.dfn5b", "vit_huge_patch14_clip_378.dfn5b", "vit_base_patch32_clip_224.metaclip_2pt5b", "vit_base_patch16_clip_224.metaclip_2pt5b", "vit_large_patch14_clip_224.metaclip_2pt5b", "vit_huge_patch14_clip_224.metaclip_2pt5b", "vit_base_patch32_clip_224.openai", "vit_base_patch16_clip_224.openai", "vit_large_patch14_clip_224.openai", "vit_large_patch14_clip_336.openai", ] default_cfgs.update( {n.replace("_clip_", "_clip_quickgelu_"): default_cfgs[n] for n in _quick_gelu_cfgs} ) default_cfgs = generate_default_cfgs(default_cfgs) def _create_vision_transformer( variant: str, pretrained: bool = False, **kwargs ) -> VisionTransformer: out_indices = kwargs.pop("out_indices", 3) if "flexi" in variant: # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. _filter_fn = partial( checkpoint_filter_fn, interpolation="bilinear", antialias=False ) else: _filter_fn = checkpoint_filter_fn # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln? 
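    # Descriptive note: strict checkpoint loading is relaxed below when the instantiated model
    # cannot match the pretrained state dict 1:1, e.g. SigLIP weights created without their
    # attention-pool head (global_pool != 'map'), or when a custom attn_layer is passed in kwargs.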
strict = True if "siglip" in variant and kwargs.get("global_pool", None) != "map": strict = False if "attn_layer" in kwargs: strict = False return build_model_with_cfg( VisionTransformer, variant, pretrained, pretrained_filter_fn=_filter_fn, pretrained_strict=strict, feature_cfg=dict(out_indices=out_indices, feature_cls="getter"), **kwargs, ) @register_model def vit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Tiny (Vit-Ti/16)""" model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) model = _create_vision_transformer( "vit_tiny_patch16_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_tiny_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Tiny (Vit-Ti/16) @ 384x384.""" model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) model = _create_vision_transformer( "vit_tiny_patch16_384", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_small_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Small (ViT-S/32)""" model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) model = _create_vision_transformer( "vit_small_patch32_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_small_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Small (ViT-S/32) at 384x384.""" model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) model = _create_vision_transformer( "vit_small_patch32_384", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Small (ViT-S/16)""" model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) model = _create_vision_transformer( "vit_small_patch16_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Small (ViT-S/16)""" model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) model = _create_vision_transformer( "vit_small_patch16_384", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_small_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Small (ViT-S/8)""" model_args = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6) model = _create_vision_transformer( "vit_small_patch8_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. """ model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer( "vit_base_patch32_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
""" model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer( "vit_base_patch32_384", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer( "vit_base_patch16_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer( "vit_base_patch16_384", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_args = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12) model = _create_vision_transformer( "vit_base_patch8_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_large_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.""" model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) model = _create_vision_transformer( "vit_large_patch32_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_large_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) model = _create_vision_transformer( "vit_large_patch32_384", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) model = _create_vision_transformer( "vit_large_patch16_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_large_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
""" model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) model = _create_vision_transformer( "vit_large_patch16_384", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Large model (ViT-L/14)""" model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16) model = _create_vision_transformer( "vit_large_patch14_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).""" model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16) model = _create_vision_transformer( "vit_huge_patch14_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560""" model_args = dict( patch_size=14, embed_dim=1408, mlp_ratio=48 / 11, depth=40, num_heads=16 ) model = _create_vision_transformer( "vit_giant_patch14_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_gigantic_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560""" model_args = dict( patch_size=14, embed_dim=1664, mlp_ratio=64 / 13, depth=48, num_heads=16 ) model = _create_vision_transformer( "vit_gigantic_patch14_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch16_224_miil(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K """ model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False ) model = _create_vision_transformer( "vit_base_patch16_224_miil", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_medium_patch16_gap_240(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240""" model_args = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, global_pool="avg", qkv_bias=False, init_values=1e-6, fc_norm=False, ) model = _create_vision_transformer( "vit_medium_patch16_gap_240", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_medium_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256""" model_args = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, global_pool="avg", qkv_bias=False, init_values=1e-6, fc_norm=False, ) model = _create_vision_transformer( "vit_medium_patch16_gap_256", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_medium_patch16_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384""" model_args = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, global_pool="avg", qkv_bias=False, init_values=1e-6, fc_norm=False, ) model = _create_vision_transformer( "vit_medium_patch16_gap_384", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_betwixt_patch16_gap_256( pretrained: bool = False, **kwargs ) -> VisionTransformer: """ViT-Betwixt (ViT-b/16) w/o class token, w/ avg-pool @ 256x256""" model_args = dict( patch_size=16, embed_dim=640, depth=12, num_heads=10, class_token=False, global_pool="avg", qkv_bias=False, init_values=1e-6, fc_norm=False, ) model = _create_vision_transformer( "vit_medium_patch16_gap_256", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_base_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224""" model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool="avg", fc_norm=False, ) model = _create_vision_transformer( "vit_base_patch16_gap_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_huge_patch14_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Huge model (ViT-H/14) w/ no class token, avg pool""" model_args = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool="avg", fc_norm=False, ) model = _create_vision_transformer( "vit_huge_patch14_gap_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_huge_patch16_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448""" model_args = dict( patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool="avg", fc_norm=False, ) model = _create_vision_transformer( "vit_huge_patch16_gap_448", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: 
"""ViT-Giant (little-gg) model (ViT-g/16) w/ no class token, avg pool""" model_args = dict( patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48 / 11, class_token=False, global_pool="avg", fc_norm=False, ) model = _create_vision_transformer( "vit_giant_patch16_gap_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_xsmall_patch16_clip_224( pretrained: bool = False, **kwargs ) -> VisionTransformer: # TinyCLIP 8M model_args = dict( embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm ) model = _create_vision_transformer( "vit_xsmall_patch16_clip_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_medium_patch32_clip_224( pretrained: bool = False, **kwargs ) -> VisionTransformer: # TinyCLIP 40M model_args = dict( patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_medium_patch32_clip_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_medium_patch16_clip_224( pretrained: bool = False, **kwargs ) -> VisionTransformer: # TinyCLIP 39M model_args = dict( embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm ) model = _create_vision_transformer( "vit_medium_patch16_clip_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_betwixt_patch32_clip_224( pretrained: bool = False, **kwargs ) -> VisionTransformer: # TinyCLIP 61M model_args = dict( patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_betwixt_patch32_clip_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-B/32 CLIP image tower @ 224x224""" model_args = dict( patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_base_patch32_clip_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch32_clip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-B/32 CLIP image tower @ 256x256""" model_args = dict( patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_base_patch32_clip_256", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch32_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-B/32 CLIP image tower @ 384x384""" model_args = dict( patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_base_patch32_clip_384", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch32_clip_448(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-B/32 CLIP image tower @ 448x448""" model_args = dict( patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_base_patch32_clip_448", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-B/16 CLIP image tower""" model_args = dict( patch_size=16, 
embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_base_patch16_clip_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-B/16 CLIP image tower @ 384x384""" model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_base_patch16_clip_384", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Large model (ViT-L/14) CLIP image tower""" model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_large_patch14_clip_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Large model (ViT-L/14) CLIP image tower @ 336x336""" model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_large_patch14_clip_336", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_huge_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Huge model (ViT-H/14) CLIP image tower.""" model_args = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_huge_patch14_clip_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_huge_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336""" model_args = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_huge_patch14_clip_336", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_huge_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378""" model_args = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_huge_patch14_clip_378", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_giant_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 Pretrained weights from CLIP image tower. """ model_args = dict( patch_size=14, embed_dim=1408, mlp_ratio=48 / 11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_giant_patch14_clip_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_gigantic_patch14_clip_224( pretrained: bool = False, **kwargs ) -> VisionTransformer: """ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 Pretrained weights from CLIP image tower. 
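
    Illustrative use as an image-embedding tower (a sketch; passing `num_classes=0` removes the
    classifier head so the forward pass returns pooled features of width `embed_dim=1664`):

        >>> import timm, torch
        >>> tower = timm.create_model('vit_gigantic_patch14_clip_224', pretrained=False, num_classes=0)
        >>> tower(torch.randn(1, 3, 224, 224)).shape
        torch.Size([1, 1664])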
""" model_args = dict( patch_size=14, embed_dim=1664, mlp_ratio=64 / 13, depth=48, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, ) model = _create_vision_transformer( "vit_gigantic_patch14_clip_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_base_patch32_clip_quickgelu_224( pretrained: bool = False, **kwargs ) -> VisionTransformer: """ViT-B/32 CLIP image tower @ 224x224""" model_args = dict( patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, act_layer="quick_gelu", ) model = _create_vision_transformer( "vit_base_patch32_clip_quickgelu_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_base_patch16_clip_quickgelu_224( pretrained: bool = False, **kwargs ) -> VisionTransformer: """ViT-B/16 CLIP image tower w/ QuickGELU act""" model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, act_layer="quick_gelu", ) model = _create_vision_transformer( "vit_base_patch16_clip_quickgelu_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_large_patch14_clip_quickgelu_224( pretrained: bool = False, **kwargs ) -> VisionTransformer: """ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act""" model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, act_layer="quick_gelu", ) model = _create_vision_transformer( "vit_large_patch14_clip_quickgelu_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_large_patch14_clip_quickgelu_336( pretrained: bool = False, **kwargs ) -> VisionTransformer: """ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act""" model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, act_layer="quick_gelu", ) model = _create_vision_transformer( "vit_large_patch14_clip_quickgelu_336", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_huge_patch14_clip_quickgelu_224( pretrained: bool = False, **kwargs ) -> VisionTransformer: """ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act.""" model_args = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, act_layer="quick_gelu", ) model = _create_vision_transformer( "vit_huge_patch14_clip_quickgelu_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_huge_patch14_clip_quickgelu_378( pretrained: bool = False, **kwargs ) -> VisionTransformer: """ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act""" model_args = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, act_layer="quick_gelu", ) model = _create_vision_transformer( "vit_huge_patch14_clip_quickgelu_378", pretrained=pretrained, **dict(model_args, **kwargs), ) return model # Experimental models below @register_model def vit_base_patch32_plus_256(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base (ViT-B/32+)""" model_args = dict( patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5 ) model = _create_vision_transformer( "vit_base_patch32_plus_256", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch16_plus_240(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base 
(ViT-B/16+)""" model_args = dict( patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5 ) model = _create_vision_transformer( "vit_base_patch16_plus_240", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch16_rpn_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base (ViT-B/16) w/ residual post-norm""" model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False, block_fn=ResPostBlock, global_pool="avg", ) model = _create_vision_transformer( "vit_base_patch16_rpn_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_small_patch16_36x1_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. """ model_args = dict( patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5 ) model = _create_vision_transformer( "vit_small_patch16_36x1_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_small_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. """ model_args = dict( patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelThingsBlock, ) model = _create_vision_transformer( "vit_small_patch16_18x2_224", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_base_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. 
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 """ model_args = dict( patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelThingsBlock, ) model = _create_vision_transformer( "vit_base_patch16_18x2_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def eva_large_patch14_196(pretrained: bool = False, **kwargs) -> VisionTransformer: """EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool="avg" ) model = _create_vision_transformer( "eva_large_patch14_196", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def eva_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: """EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool="avg" ) model = _create_vision_transformer( "eva_large_patch14_336", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def flexivit_small(pretrained: bool = False, **kwargs) -> VisionTransformer: """FlexiViT-Small""" model_args = dict( patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True ) model = _create_vision_transformer( "flexivit_small", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def flexivit_base(pretrained: bool = False, **kwargs) -> VisionTransformer: """FlexiViT-Base""" model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True ) model = _create_vision_transformer( "flexivit_base", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def flexivit_large(pretrained: bool = False, **kwargs) -> VisionTransformer: """FlexiViT-Large""" model_args = dict( patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True ) model = _create_vision_transformer( "flexivit_large", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch16_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Base model (ViT-B/16) w/ parallel blocks and qk norm enabled.""" model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, no_embed_class=True, norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True, ) model = _create_vision_transformer( "vit_base_patch16_xp_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_large_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.""" model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, no_embed_class=True, norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True, ) model = _create_vision_transformer( "vit_large_patch14_xp_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled.""" model_args = dict( patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, no_embed_class=True, norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True, ) model = _create_vision_transformer(
"vit_huge_patch14_xp_224", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-S/14 for DINOv2""" model_args = dict( patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5 ) model = _create_vision_transformer( "vit_small_patch14_dinov2", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-B/14 for DINOv2""" model_args = dict( patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5 ) model = _create_vision_transformer( "vit_base_patch14_dinov2", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-L/14 for DINOv2""" model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5 ) model = _create_vision_transformer( "vit_large_patch14_dinov2", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ViT-G/14 for DINOv2""" # The hidden_features of SwiGLU is calculated by: # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 # When embed_dim=1536, hidden_features=4096 # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192 model_args = dict( patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, act_layer=nn.SiLU, ) model = _create_vision_transformer( "vit_giant_patch14_dinov2", pretrained=pretrained, **dict(model_args, **kwargs) ) return model @register_model def vit_small_patch14_reg4_dinov2( pretrained: bool = False, **kwargs ) -> VisionTransformer: """ViT-S/14 for DINOv2 w/ 4 registers""" model_args = dict( patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, reg_tokens=4, no_embed_class=True, ) model = _create_vision_transformer( "vit_small_patch14_reg4_dinov2", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_base_patch14_reg4_dinov2( pretrained: bool = False, **kwargs ) -> VisionTransformer: """ViT-B/14 for DINOv2 w/ 4 registers""" model_args = dict( patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, reg_tokens=4, no_embed_class=True, ) model = _create_vision_transformer( "vit_base_patch14_reg4_dinov2", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_large_patch14_reg4_dinov2( pretrained: bool = False, **kwargs ) -> VisionTransformer: """ViT-L/14 for DINOv2 w/ 4 registers""" model_args = dict( patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, reg_tokens=4, no_embed_class=True, ) model = _create_vision_transformer( "vit_large_patch14_reg4_dinov2", pretrained=pretrained, **dict(model_args, **kwargs), ) return model @register_model def vit_giant_patch14_reg4_dinov2( pretrained: bool = False, **kwargs ) -> VisionTransformer: """ViT-G/14 for DINOv2""" # The hidden_features of SwiGLU is calculated by: # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 # When embed_dim=1536, hidden_features=4096 # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192 model_args = dict( patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, 
        act_layer=nn.SiLU, reg_tokens=4, no_embed_class=True)
    model = _create_vision_transformer(
        "vit_giant_patch14_reg4_dinov2", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch16_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with attention pooling (MAP)."""
    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool="map")
    model = _create_vision_transformer(
        "vit_base_patch16_siglip_224", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with attention pooling (MAP)."""
    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool="map")
    model = _create_vision_transformer(
        "vit_base_patch16_siglip_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with attention pooling (MAP)."""
    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool="map")
    model = _create_vision_transformer(
        "vit_base_patch16_siglip_384", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with attention pooling (MAP)."""
    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool="map")
    model = _create_vision_transformer(
        "vit_base_patch16_siglip_512", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_large_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with attention pooling (MAP)."""
    model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool="map")
    model = _create_vision_transformer(
        "vit_large_patch16_siglip_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_large_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with attention pooling (MAP)."""
    model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool="map")
    model = _create_vision_transformer(
        "vit_large_patch16_siglip_384", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so400m_patch14_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with attention pooling (MAP)."""
    model_args = dict(
        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
        class_token=False, global_pool="map")
    model = _create_vision_transformer(
        "vit_so400m_patch14_siglip_224", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with attention pooling (MAP)."""
    model_args = dict(
        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
        class_token=False, global_pool="map")
    model = _create_vision_transformer(
        "vit_so400m_patch14_siglip_384", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch16_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        class_token=False, global_pool="avg", fc_norm=False)
    model = _create_vision_transformer(
        "vit_base_patch16_siglip_gap_224", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        class_token=False, global_pool="avg", fc_norm=False)
    model = _create_vision_transformer(
        "vit_base_patch16_siglip_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        class_token=False, global_pool="avg", fc_norm=False)
    model = _create_vision_transformer(
        "vit_base_patch16_siglip_gap_384", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        class_token=False, global_pool="avg", fc_norm=False)
    model = _create_vision_transformer(
        "vit_base_patch16_siglip_gap_512", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_large_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        class_token=False, global_pool="avg", fc_norm=False)
    model = _create_vision_transformer(
        "vit_large_patch16_siglip_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_large_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        class_token=False, global_pool="avg", fc_norm=False)
    model = _create_vision_transformer(
        "vit_large_patch16_siglip_gap_384", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
    model_args = dict(
        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
        class_token=False, global_pool="avg", fc_norm=False)
    model = _create_vision_transformer(
        "vit_so400m_patch14_siglip_gap_224", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
    model_args = dict(
        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
        class_token=False, global_pool="avg", fc_norm=False)
    model = _create_vision_transformer(
        "vit_so400m_patch14_siglip_gap_384", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
    model_args = dict(
        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
        class_token=False, global_pool="avg", fc_norm=False)
    model = _create_vision_transformer(
        "vit_so400m_patch14_siglip_gap_448", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
    model_args = dict(
        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
        class_token=False, global_pool="avg", fc_norm=False)
    model = _create_vision_transformer(
        "vit_so400m_patch14_siglip_gap_896", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-Wee w/ 1 register token, no class token, global average pooling (GAP)."""
    model_args = dict(
        patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
        class_token=False, no_embed_class=True, reg_tokens=1, global_pool="avg")
    model = _create_vision_transformer(
        "vit_wee_patch16_reg1_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-Pwee (parallel scaling blocks) w/ 1 register token, no class token, global average pooling (GAP)."""
    model_args = dict(
        patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5,
        class_token=False, no_embed_class=True, reg_tokens=1, global_pool="avg",
        block_fn=ParallelScalingBlock)
    model = _create_vision_transformer(
        "vit_pwee_patch16_reg1_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_little_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-Little w/ 1 register token, no class token, global average pooling (GAP)."""
    model_args = dict(
        patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6,
        class_token=False, no_embed_class=True, reg_tokens=1, global_pool="avg")
    model = _create_vision_transformer(
        "vit_little_patch16_reg1_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-Little w/ 4 register tokens, no class token, global average pooling (GAP)."""
    model_args = dict(
        patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6,
        class_token=False, no_embed_class=True, reg_tokens=4, global_pool="avg")
    model = _create_vision_transformer(
        "vit_little_patch16_reg4_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_medium_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-Medium w/ 1 register token, no class token, global average pooling (GAP)."""
    model_args = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5,
        class_token=False, no_embed_class=True, reg_tokens=1, global_pool="avg")
    model = _create_vision_transformer(
        "vit_medium_patch16_reg1_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-Medium w/ 4 register tokens, no class token, global average pooling (GAP)."""
    model_args = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5,
        class_token=False, no_embed_class=True, reg_tokens=4, global_pool="avg")
    model = _create_vision_transformer(
        "vit_medium_patch16_reg4_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-Medium (deep) w/ 4 register tokens, no class token, global average pooling (GAP)."""
    model_args = dict(
        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
        class_token=False, no_embed_class=True, reg_tokens=4, global_pool="avg")
    model = _create_vision_transformer(
        "vit_mediumd_patch16_reg4_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-Betwixt w/ 1 register token, no class token, global average pooling (GAP)."""
    model_args = dict(
        patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5,
        class_token=False, no_embed_class=True, reg_tokens=1, global_pool="avg")
    model = _create_vision_transformer(
        "vit_betwixt_patch16_reg1_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_betwixt_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-Betwixt w/ 4 register tokens, no class token, global average pooling (GAP)."""
    model_args = dict(
        patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5,
        class_token=False, no_embed_class=True, reg_tokens=4, global_pool="avg")
    model = _create_vision_transformer(
        "vit_betwixt_patch16_reg4_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-Base w/ 4 register tokens, no class token, global average pooling (GAP)."""
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        class_token=False, no_embed_class=True, global_pool="avg", reg_tokens=4)
    model = _create_vision_transformer(
        "vit_base_patch16_reg4_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-SO150M w/ 4 register tokens and attention pooling (MAP)."""
    model_args = dict(
        patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
        class_token=False, reg_tokens=4, global_pool="map")
    model = _create_vision_transformer(
        "vit_so150m_patch16_reg4_map_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
    """ViT-SO150M w/ 4 register tokens and global average pooling (GAP)."""
    model_args = dict(
        patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
        class_token=False, reg_tokens=4, global_pool="avg", fc_norm=False)
    model = _create_vision_transformer(
        "vit_so150m_patch16_reg4_gap_256", pretrained=pretrained, **dict(model_args, **kwargs))
    return model


register_model_deprecations(
    __name__,
    {
        "vit_tiny_patch16_224_in21k": "vit_tiny_patch16_224.augreg_in21k",
        "vit_small_patch32_224_in21k": "vit_small_patch32_224.augreg_in21k",
        "vit_small_patch16_224_in21k": "vit_small_patch16_224.augreg_in21k",
        "vit_base_patch32_224_in21k": "vit_base_patch32_224.augreg_in21k",
        "vit_base_patch16_224_in21k": "vit_base_patch16_224.augreg_in21k",
        "vit_base_patch8_224_in21k": "vit_base_patch8_224.augreg_in21k",
        "vit_large_patch32_224_in21k": "vit_large_patch32_224.orig_in21k",
        "vit_large_patch16_224_in21k": "vit_large_patch16_224.augreg_in21k",
        "vit_huge_patch14_224_in21k": "vit_huge_patch14_224.orig_in21k",
        "vit_base_patch32_224_sam": "vit_base_patch32_224.sam",
        "vit_base_patch16_224_sam": "vit_base_patch16_224.sam",
        "vit_small_patch16_224_dino": "vit_small_patch16_224.dino",
        "vit_small_patch8_224_dino": "vit_small_patch8_224.dino",
        "vit_base_patch16_224_dino": "vit_base_patch16_224.dino",
        "vit_base_patch8_224_dino": "vit_base_patch8_224.dino",
        "vit_base_patch16_224_miil_in21k": "vit_base_patch16_224_miil.in21k",
        "vit_base_patch32_224_clip_laion2b": "vit_base_patch32_clip_224.laion2b",
"vit_large_patch14_clip_224.laion2b", "vit_huge_patch14_224_clip_laion2b": "vit_huge_patch14_clip_224.laion2b", "vit_giant_patch14_224_clip_laion2b": "vit_giant_patch14_clip_224.laion2b", }, )