# © Recursion Pharmaceuticals 2024
import timm.models.vision_transformer as vit
import torch


def generate_2d_sincos_pos_embeddings(
    embedding_dim: int,
    length: int,
    scale: float = 10000.0,
    use_class_token: bool = True,
    num_modality: int = 1,
) -> torch.nn.Parameter:
""" | |
Generate 2Dimensional sin/cosine positional embeddings | |
Parameters | |
---------- | |
embedding_dim : int | |
embedding dimension used in vit | |
length : int | |
number of tokens along height or width of image after patching (assuming square) | |
scale : float | |
scale for sin/cos functions | |
use_class_token : bool | |
True - add zero vector to be added to class_token, False - no vector added | |
num_modality: number of modalities. If 0, a single modality is assumed. | |
Otherwise one-hot modality encoding is added and sincos encoding size is appropriately reduced. | |
Returns | |
------- | |
positional_encoding : torch.Tensor | |
positional encoding to add to vit patch encodings | |
[num_modality*length*length, embedding_dim] or [1+num_modality*length*length, embedding_dim] | |
(w/ or w/o cls_token) | |
""" | |
    linear_positions = torch.arange(length, dtype=torch.float32)
    height_mesh, width_mesh = torch.meshgrid(
        linear_positions, linear_positions, indexing="ij"
    )
    positional_dim = embedding_dim // 4  # accommodate h and w x cos and sin embeddings
    positional_weights = (
        torch.arange(positional_dim, dtype=torch.float32) / positional_dim
    )
    positional_weights = 1.0 / (scale**positional_weights)
    height_weights = torch.outer(height_mesh.flatten(), positional_weights)
    width_weights = torch.outer(width_mesh.flatten(), positional_weights)
    positional_encoding = torch.cat(
        [
            torch.sin(height_weights),
            torch.cos(height_weights),
            torch.sin(width_weights),
            torch.cos(width_weights),
        ],
        dim=1,
    )[None, :, :]
    # repeat positional encoding for multiple channel modalities
    positional_encoding = positional_encoding.repeat(1, num_modality, 1)
    if use_class_token:
        class_token = torch.zeros([1, 1, embedding_dim], dtype=torch.float32)
        positional_encoding = torch.cat([class_token, positional_encoding], dim=1)
    positional_encoding = torch.nn.Parameter(positional_encoding, requires_grad=False)
    return positional_encoding
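# Usage sketch (assumes a ViT-S/16-sized backbone at 256 px with 6 channel
# modalities, i.e. embedding_dim=384 and a 16x16 patch grid):
#
#   >>> pe = generate_2d_sincos_pos_embeddings(embedding_dim=384, length=16, num_modality=6)
#   >>> pe.shape     # 1 cls token + 6 * 16 * 16 patch tokens
#   torch.Size([1, 1537, 384])
#   >>> pe.requires_grad
#   False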


class ChannelAgnosticPatchEmbed(vit.PatchEmbed):  # type: ignore[misc]
    def __init__(
        self,
        img_size: int,
        patch_size: int,
        embed_dim: int,
        bias: bool = True,
    ) -> None:
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=1,  # in_chans is used by self.proj, which we override anyway
            embed_dim=embed_dim,
            norm_layer=None,
            flatten=False,
            bias=bias,
        )
        # channel-agnostic MAE has a single projection for all chans
        self.proj = torch.nn.Conv2d(
            1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_chans = x.shape[1]
        x = torch.stack(
            [self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2
        )  # single projection applied to each channel independently
        x = x.flatten(2).transpose(1, 2)  # BCMHW -> BNC
        return x
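# Usage sketch (assumes img_size=256, patch_size=16, embed_dim=384; every input
# channel is patchified with the same projection, so the token count scales
# with the number of channels):
#
#   >>> embed = ChannelAgnosticPatchEmbed(img_size=256, patch_size=16, embed_dim=384)
#   >>> tokens = embed(torch.randn(2, 6, 256, 256))
#   >>> tokens.shape     # 6 channels * 16 * 16 patches per channel
#   torch.Size([2, 1536, 384])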


class ChannelAgnosticViT(vit.VisionTransformer):  # type: ignore[misc]
    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        # rewrite of https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L586
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))

        # TODO: upgrade timm to get access to register tokens
        # if self.vit_backbone.reg_token is not None:
        #     to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        # MAIN DIFFERENCE with timm - we DYNAMICALLY ADD POS EMBEDDINGS based on the shape of the inputs.
        # This supports having CA-MAEs actually be channel-agnostic at inference time.
        if self.no_embed_class:
            x = x + self.pos_embed[:, : x.shape[1]]
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + self.pos_embed[:, : x.shape[1]]
        return self.pos_drop(x)  # type: ignore[no-any-return]
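# Sketch of the dynamic slicing above (assumes max_in_chans=6, a 16x16 patch
# grid, and no_embed_class=False): an input with only 3 channels produces
# 3 * 16 * 16 = 768 patch tokens, so after the cls token is concatenated only
# pos_embed[:, :769] is added, i.e. the cls row plus the first 768 of the
# 6 * 256 available patch rows.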


def channel_agnostic_vit(
    vit_backbone: vit.VisionTransformer, max_in_chans: int
) -> vit.VisionTransformer:
    # replace patch embedding with channel-agnostic version
    vit_backbone.patch_embed = ChannelAgnosticPatchEmbed(
        img_size=vit_backbone.patch_embed.img_size[0],
        patch_size=vit_backbone.patch_embed.patch_size[0],
        embed_dim=vit_backbone.embed_dim,
    )

    # replace positional embedding with channel-agnostic version
    vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
        embedding_dim=vit_backbone.embed_dim,
        length=vit_backbone.patch_embed.grid_size[0],
        use_class_token=vit_backbone.cls_token is not None,
        num_modality=max_in_chans,
    )

    # change the class to be ChannelAgnostic so that it actually uses the new _pos_embed
    vit_backbone.__class__ = ChannelAgnosticViT
    return vit_backbone
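# Usage sketch (assumes the vit_small_patch16_256 factory defined below; after
# conversion the encoder accepts any channel count up to max_in_chans):
#
#   >>> backbone = channel_agnostic_vit(vit_small_patch16_256(), max_in_chans=6)
#   >>> feats = backbone.forward_features(torch.randn(1, 4, 256, 256))
#   >>> feats.shape     # 1 cls token + 4 * 16 * 16 patch tokens
#   torch.Size([1, 1025, 384])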


def sincos_positional_encoding_vit(
    vit_backbone: vit.VisionTransformer, scale: float = 10000.0
) -> vit.VisionTransformer:
    """Attaches no-grad sin-cos positional embeddings to a pre-constructed ViT backbone model.

    Parameters
    ----------
    vit_backbone : timm.models.vision_transformer.VisionTransformer
        the constructed vision transformer from timm
    scale : float (default 10000.0)
        hyperparameter for sincos positional embeddings, recommend keeping at 10,000

    Returns
    -------
    timm.models.vision_transformer.VisionTransformer
        the same ViT but with fixed no-grad positional encodings to add to vit patch encodings
    """
    # length: number of tokens along height or width of image after patching (assuming square)
    length = (
        vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0]
    )
    pos_embeddings = generate_2d_sincos_pos_embeddings(
        vit_backbone.embed_dim,
        length=length,
        scale=scale,
        use_class_token=vit_backbone.cls_token is not None,
    )
    # note, if the model had weight_init == 'skip', this might get overwritten
    vit_backbone.pos_embed = pos_embeddings
    return vit_backbone
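# Usage sketch (assumes the vit_small_patch16_256 factory defined below, i.e.
# embed_dim=384 and a 16x16 patch grid):
#
#   >>> backbone = sincos_positional_encoding_vit(vit_small_patch16_256())
#   >>> backbone.pos_embed.shape     # 1 cls token + 16 * 16 patch tokens
#   torch.Size([1, 257, 384])
#   >>> backbone.pos_embed.requires_grad
#   False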


def vit_small_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_small_patch16_224(**default_kwargs)
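# The factories below all follow the same pattern: a dict of defaults that any
# keyword argument can override before the corresponding timm entrypoint (or
# VisionTransformer itself) is called. A minimal override sketch, with
# illustrative alternative settings:
#
#   >>> model = vit_small_patch16_256(in_chans=3, drop_path_rate=0.0)
#   >>> model.embed_dim
#   384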


def vit_small_patch32_512(**kwargs):
    default_kwargs = dict(
        img_size=512,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_small_patch32_384(**default_kwargs)


def vit_base_patch8_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch8_224(**default_kwargs)


def vit_base_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch16_224(**default_kwargs)


def vit_base_patch32_512(**kwargs):
    default_kwargs = dict(
        img_size=512,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch32_384(**default_kwargs)


def vit_large_patch8_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        patch_size=8,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        drop_path_rate=0.3,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.VisionTransformer(**default_kwargs)


def vit_large_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.3,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_large_patch16_384(**default_kwargs)
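# Putting the pieces together (a sketch of one possible combination, using the
# ViT-B/16 factory above): a channel-agnostic encoder with fixed sin/cos
# positional embeddings for up to 6 channel modalities.
#
#   >>> encoder = channel_agnostic_vit(vit_base_patch16_256(), max_in_chans=6)
#   >>> tokens = encoder.forward_features(torch.randn(1, 6, 256, 256))
#   >>> tokens.shape     # 1 cls token + 6 * 16 * 16 patch tokens
#   torch.Size([1, 1537, 768])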