# Copyright (c) IBM Corp. 2024. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # -------------------------------------------------------- # References: # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm # transformers: https://github.com/huggingface/transformers # -------------------------------------------------------- import warnings import logging import numpy as np import torch import torch.nn as nn from einops import rearrange from timm.layers import to_2tuple from timm.models.vision_transformer import Block logger = logging.getLogger(__name__) def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): """ Create 3D sin/cos positional embeddings. Args: embed_dim (int): Embedding dimension. grid_size (tuple[int, int, int] | list[int]): The grid depth, height and width. add_cls_token (bool, *optional*, defaults to False): Whether or not to add a classification (CLS) token. Returns: (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token) """ assert embed_dim % 16 == 0 t_size, h_size, w_size = grid_size w_embed_dim = embed_dim // 16 * 6 h_embed_dim = embed_dim // 16 * 6 t_embed_dim = embed_dim // 16 * 4 w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size)) h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size)) t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size)) w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1)) h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1)) t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0) pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1) if add_cls_token: pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) return pos_embed def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ if embed_dim % 2 != 0: raise ValueError("embed_dim must be even") omega = np.arange(embed_dim // 2, dtype=float) omega /= embed_dim / 2.0 omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor): """ Modified torch version of *get_1d_sincos_pos_embed_from_grid()*. embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) - must be float dtype! out: (M, D) """ assert embed_dim % 2 == 0 assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16] omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device) omega /= embed_dim / 2.0 omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = torch.sin(out) # (M, D/2) emb_cos = torch.cos(out) # (M, D/2) emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) return emb def _init_weights(module): """Initialize the weights""" if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) def _interpolate_pos_encoding( pos_embed: torch.Tensor, grid_size: tuple[int, int, int] | list[int], patch_size: tuple[int, int, int] | list[int], shape: tuple[int, int, int], embed_dim: int, ): """ Adapted from: - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding, - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194 """ t, h, w = shape t_patches = t // patch_size[0] h_patches = h // patch_size[1] w_patches = w // patch_size[2] if [t_patches, h_patches, w_patches] == grid_size: # No interpolation needed return pos_embed if t_patches != grid_size[0]: # Re-compute pos embedding to handle changed num_frames new_grid_size = (t_patches, *grid_size[1:]) new_pos_embed = get_3d_sincos_pos_embed(pos_embed.shape[-1], new_grid_size, add_cls_token=True) new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0) else: new_grid_size = grid_size new_pos_embed = pos_embed class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:] patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(h_patches, w_patches), mode='bicubic', align_corners=True, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim) return torch.cat((class_pos_embed, patch_pos_embed), dim=1) class PatchEmbed(nn.Module): """3D version of timm.models.vision_transformer.PatchEmbed""" def __init__( self, input_size: tuple[int, int, int] = (1, 224, 224), patch_size: tuple[int, int, int] = (1, 16, 16), in_chans: int = 3, embed_dim: int = 768, norm_layer: nn.Module | None = None, flatten: bool = True, bias: bool = True, ): super().__init__() self.input_size = input_size self.patch_size = patch_size self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)] assert self.grid_size >= [1, 1, 1], "Patch size is bigger than input size." self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] self.flatten = flatten self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): B, C, T, H, W = x.shape if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1: warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}." f"The border will be ignored, add backbone_padding for pixel-wise tasks.") x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C x = self.norm(x) return x class TemporalEncoder(nn.Module): def __init__(self, embed_dim: int, trainable_scale: bool = False): super().__init__() self.embed_dim = embed_dim self.year_embed_dim = embed_dim // 2 self.julian_day_embed_dim = embed_dim - self.year_embed_dim # If trainable, initialize scale with small number if trainable_scale: self.scale = nn.Parameter(torch.full((1,), 0.1)) else: self.register_buffer('scale', torch.ones(1)) def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None): """ temporal_coords: year and day-of-year info with shape (B, T, 2). tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim). """ shape = temporal_coords.shape[:2] + (-1,) # B, T, -1 year = _get_1d_sincos_embed_from_grid_torch( self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape) julian_day = _get_1d_sincos_embed_from_grid_torch( self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape) embedding = self.scale * torch.cat([year, julian_day], dim=-1) if tokens_per_frame is not None: embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1) return embedding # B, T*tokens_per_frame, embed_dim class LocationEncoder(nn.Module): def __init__(self, embed_dim: int, trainable_scale: bool = False): super().__init__() self.embed_dim = embed_dim self.lat_embed_dim = embed_dim // 2 self.lon_embed_dim = embed_dim - self.lat_embed_dim # If trainable, initialize scale with small number if trainable_scale: self.scale = nn.Parameter(torch.full((1,), 0.1)) else: self.register_buffer('scale', torch.ones(1)) def forward(self, location_coords: torch.Tensor): """ location_coords: lat and lon info with shape (B, 2). """ shape = location_coords.shape[:1] + (1, -1) # B, 1, -1 lat = _get_1d_sincos_embed_from_grid_torch( self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape) lon = _get_1d_sincos_embed_from_grid_torch( self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape) embedding = self.scale * torch.cat([lat, lon], dim=-1) return embedding # B, 1, embed_dim class PrithviViT(nn.Module): """ Prithvi ViT Encoder""" def __init__(self, img_size: int | tuple[int, int] = 224, patch_size: int | tuple[int, int, int] = (1, 16, 16), num_frames: int = 1, in_chans: int = 3, embed_dim: int = 1024, depth: int = 24, num_heads: int = 16, mlp_ratio: float = 4., norm_layer: nn.Module = nn.LayerNorm, coords_encoding: list[str] | None = None, coords_scale_learn: bool = False, drop_path: float = 0., ** kwargs, ): super().__init__() self.in_chans = in_chans self.num_frames = num_frames self.embed_dim = embed_dim self.img_size = to_2tuple(img_size) if isinstance(patch_size, int): patch_size = (1, patch_size, patch_size) # 3D patch embedding self.patch_embed = PatchEmbed( input_size=(num_frames,) + self.img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ) self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth # Optional temporal and location embedding coords_encoding = coords_encoding or [] self.temporal_encoding = 'time' in coords_encoding self.location_encoding = 'location' in coords_encoding if self.temporal_encoding: assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}" self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn) if self.location_encoding: self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)) # Transformer layers self.blocks = [] for i in range(depth): self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, drop_path=drop_path,)) self.blocks = nn.ModuleList(self.blocks) self.norm = norm_layer(embed_dim) self.initialize_weights() def initialize_weights(self): # initialize (and freeze) position embeddings by sin-cos embedding pos_embed = get_3d_sincos_pos_embed( self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True ) self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d) w = self.patch_embed.proj.weight.data torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) torch.nn.init.normal_(self.cls_token, std=0.02) self.apply(_init_weights) def random_masking(self, sequence, mask_ratio, noise=None): """ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. Args: sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`) mask_ratio (float): mask ratio to use. noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is mainly used for testing purposes to control randomness and maintain the reproducibility """ batch_size, seq_length, dim = sequence.shape len_keep = int(seq_length * (1 - mask_ratio)) if noise is None: noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1] # sort noise for each sample ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device) # keep the first subset ids_keep = ids_shuffle[:, :len_keep] sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)) # generate the binary mask: 0 is keep, 1 is remove mask = torch.ones([batch_size, seq_length], device=sequence.device) mask[:, :len_keep] = 0 # unshuffle to get the binary mask mask = torch.gather(mask, dim=1, index=ids_restore) return sequence_unmasked, mask, ids_restore def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]): pos_embed = _interpolate_pos_encoding( pos_embed=self.pos_embed, grid_size=self.patch_embed.grid_size, patch_size=self.patch_embed.patch_size, shape=sample_shape, embed_dim=self.embed_dim, ) return pos_embed def forward( self, x: torch.Tensor, temporal_coords: None | torch.Tensor = None, location_coords: None | torch.Tensor = None, mask_ratio=0.75 ): if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: # add time dim x = x.unsqueeze(2) sample_shape = x.shape[-3:] # embed patches x = self.patch_embed(x) pos_embed = self.interpolate_pos_encoding(sample_shape) # add pos embed w/o cls token x = x + pos_embed[:, 1:, :] if self.temporal_encoding and temporal_coords is not None: num_tokens_per_frame = x.shape[1] // self.num_frames temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame) x = x + temporal_encoding if self.location_encoding and location_coords is not None: location_encoding = self.location_embed_enc(location_coords) x = x + location_encoding # masking: length -> length * mask_ratio x, mask, ids_restore = self.random_masking(x, mask_ratio) # append cls token cls_token = self.cls_token + pos_embed[:, :1, :] cls_tokens = cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_tokens, x), dim=1) # apply Transformer blocks for block in self.blocks: x = block(x) x = self.norm(x) return x, mask, ids_restore def forward_features( self, x: torch.Tensor, temporal_coords: None | torch.Tensor = None, location_coords: None | torch.Tensor = None, ) -> list[torch.Tensor]: if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1: # add time dim x = x.unsqueeze(2) sample_shape = x.shape[-3:] # embed patches x = self.patch_embed(x) pos_embed = self.interpolate_pos_encoding(sample_shape) # add pos embed w/o cls token x = x + pos_embed[:, 1:, :] if self.temporal_encoding and temporal_coords is not None: num_tokens_per_frame = x.shape[1] // self.num_frames temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame) x = x + temporal_encoding if self.location_encoding and location_coords is not None: location_encoding = self.location_embed_enc(location_coords) x = x + location_encoding # append cls token cls_token = self.cls_token + pos_embed[:, :1, :] cls_tokens = cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_tokens, x), dim=1) # apply Transformer blocks out = [] for block in self.blocks: x = block(x) out.append(x.clone()) x = self.norm(x) out[-1] = x return out def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]: out = [] effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0] for x in features: x_no_token = x[:, 1:, :] number_of_tokens = x_no_token.shape[1] tokens_per_timestep = number_of_tokens // effective_time_dim h = int(np.sqrt(tokens_per_timestep)) encoded = rearrange( x_no_token, "batch (t h w) e -> batch (t e) h w", e=self.embed_dim, t=effective_time_dim, h=h, ) out.append(encoded) return out class MAEDecoder(nn.Module): """ Transformer Decoder used in the Prithvi MAE""" def __init__(self, patch_size: int | tuple[int, int, int] = (1, 16, 16), grid_size: list[int] | tuple[int, int, int] = (3, 14, 14), in_chans: int = 3, encoder_embed_dim: int = 1024, decoder_embed_dim: int = 512, depth: int = 8, num_heads: int = 16, mlp_ratio: float = 4., norm_layer: nn.Module = nn.LayerNorm, coords_encoding: list[str] | None = None, coords_scale_learn: bool = False, ): super().__init__() self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True) self.decoder_embed_dim = decoder_embed_dim self.grid_size = grid_size if isinstance(patch_size, int): patch_size = (1, patch_size, patch_size) self.patch_size = patch_size self.num_frames = self.grid_size[0] * patch_size[0] num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] # Optional temporal and location embedding coords_encoding = coords_encoding or [] self.temporal_encoding = 'time' in coords_encoding self.location_encoding = 'location' in coords_encoding if self.temporal_encoding: self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn) if self.location_encoding: self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn) self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim)) self.decoder_blocks = nn.ModuleList( [Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)] ) self.decoder_norm = norm_layer(decoder_embed_dim) self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size[0] * patch_size[1] * patch_size[2] * in_chans, bias=True) self.initialize_weights() def initialize_weights(self): # initialize (and freeze) position embeddings by sin-cos embedding decoder_pos_embed = get_3d_sincos_pos_embed( self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True ) self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) torch.nn.init.normal_(self.mask_token, std=0.02) self.apply(_init_weights) def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]): pos_embed = _interpolate_pos_encoding( pos_embed=self.decoder_pos_embed, grid_size=self.grid_size, patch_size=self.patch_size, shape=sample_shape, embed_dim=self.decoder_embed_dim, ) return pos_embed def forward( self, hidden_states: torch.Tensor, ids_restore: torch.Tensor, temporal_coords: None | torch.Tensor = None, location_coords: None | torch.Tensor = None, input_size: list[int] = None, ): # embed tokens x = self.decoder_embed(hidden_states) cls_token = x[:, :1, :] # append mask tokens to sequence mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) x = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token # unshuffle x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device)) # add pos embed decoder_pos_embed = self.interpolate_pos_encoding(input_size[-3:]) cls_token = cls_token + decoder_pos_embed[:, :1, :] x = x + decoder_pos_embed[:, 1:, :] if self.temporal_encoding and temporal_coords is not None: num_tokens_per_frame = x.shape[1] // self.num_frames temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame) # Add temporal encoding w/o cls token x = x + temporal_encoding if self.location_encoding and location_coords is not None: location_encoding = self.location_embed_dec(location_coords) # Add location encoding w/o cls token x = x + location_encoding # append cls token x = torch.cat([cls_token, x], dim=1) # apply Transformer layers (blocks) for block in self.decoder_blocks: x = block(x) x = self.decoder_norm(x) # predictor projection pred = self.decoder_pred(x) # remove cls token pred = pred[:, 1:, :] return pred class PrithviMAE(nn.Module): """ Prithvi Masked Autoencoder""" def __init__(self, img_size: int | tuple[int, int] = 224, patch_size: int | tuple[int, int, int] = (1, 16, 16), num_frames: int = 4, in_chans: int = 6, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, decoder_embed_dim: int = 512, decoder_depth: int = 8, decoder_num_heads: int = 16, mlp_ratio: float = 4., norm_layer: nn.Module = nn.LayerNorm, norm_pix_loss: bool = False, coords_encoding: list[str] | None = None, coords_scale_learn: bool = False, drop_path: float = 0., mask_ratio: float = 0.75, **kwargs, ): super().__init__() self.encoder = PrithviViT( img_size=img_size, num_frames=num_frames, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, coords_encoding=coords_encoding, coords_scale_learn=coords_scale_learn, drop_path=drop_path, ) self.decoder = MAEDecoder( patch_size=patch_size, grid_size=self.encoder.patch_embed.grid_size, in_chans=in_chans, encoder_embed_dim=embed_dim, decoder_embed_dim=decoder_embed_dim, depth=decoder_depth, num_heads=decoder_num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, coords_encoding=coords_encoding, coords_scale_learn=coords_scale_learn, ) self.mask_ratio = mask_ratio self.norm_pix_loss = norm_pix_loss self.out_channels = self.encoder.out_channels def patchify(self, pixel_values): """ Args: pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`): Pixel values. Returns: torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: Patchified pixel values. """ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size num_channels = self.encoder.in_chans # patchify patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)', c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w) return patchified_pixel_values def unpatchify(self, patchified_pixel_values, image_size: tuple[int, int] | None = None): """ Args: patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels))`: Patchified pixel values. image_size (`tuple[int, int]`, *optional*): Original image size. Returns: `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`: Pixel values. """ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size original_height, original_width = image_size num_patches_h = original_height // patch_size_h num_patches_w = original_width // patch_size_w num_channels = self.encoder.in_chans pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)', c=num_channels, h=num_patches_h, w=num_patches_w, s=patch_size_t, p=patch_size_h, q=patch_size_w) return pixel_values def forward_loss(self, pixel_values, pred, mask): """ Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`): Pixel values. pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: Predicted pixel values. mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Tensor indicating which patches are masked (1) and which are not (0). Returns: `torch.FloatTensor`: Pixel reconstruction loss. """ target = self.patchify(pixel_values) if self.norm_pix_loss: mean = target.mean(dim=-1, keepdim=True) var = target.var(dim=-1, keepdim=True) target = (target - mean) / (var + 1.0e-6) ** 0.5 loss = (pred - target) ** 2 loss = loss.mean(dim=-1) # [N, L], mean loss per patch loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches return loss def forward( self, pixel_values: torch.Tensor, temporal_coords: None | torch.Tensor = None, location_coords: None | torch.Tensor = None, mask_ratio: float = None, ): if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1: # add time dim pixel_values = pixel_values.unsqueeze(2) mask_ratio = mask_ratio or self.mask_ratio latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio) pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape) loss = self.forward_loss(pixel_values, pred, mask) return loss, pred, mask def forward_features( self, x: torch.Tensor, temporal_coords: None | torch.Tensor = None, location_coords: None | torch.Tensor = None, ) -> list[torch.Tensor]: return self.encoder.forward_features(x, temporal_coords, location_coords)