blumenstiel committed
Commit eb07d17 · 1 Parent(s): 40d37ff

Update model code

Files changed (2)
  1. inference.py +1 -1
  2. prithvi_mae.py +146 -116
inference.py CHANGED
@@ -358,7 +358,7 @@ def main(
 
     model.to(device)
 
-    state_dict = torch.load(checkpoint, map_location=device)
+    state_dict = torch.load(checkpoint, map_location=device, weights_only=True)
     # discard fixed pos_embedding weight
     for k in list(state_dict.keys()):
        if 'pos_embed' in k:
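
The only change here is that torch.load now opts in to weights_only=True, which restricts unpickling to tensors and basic containers and is the safer way to read checkpoints. A minimal sketch of the surrounding loading pattern; the non-strict load at the end is an assumption rather than something shown in this diff:

import torch

def load_prithvi_weights(model, checkpoint, device="cpu"):
    # weights_only=True limits unpickling to tensors and plain containers (safer checkpoint loading)
    state_dict = torch.load(checkpoint, map_location=device, weights_only=True)

    # discard the fixed pos_embed weights, as in the diff above; they are rebuilt from the input size
    for k in list(state_dict.keys()):
        if 'pos_embed' in k:
            del state_dict[k]

    # assumption: non-strict load so the dropped pos_embed entries do not raise
    model.load_state_dict(state_dict, strict=False)
    return model
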
prithvi_mae.py CHANGED
@@ -17,9 +17,7 @@
 # transformers: https://github.com/huggingface/transformers
 # --------------------------------------------------------
 
-from functools import partial
-from typing import List, Tuple
-
+import warnings
 import logging
 import numpy as np
 import torch
@@ -28,6 +26,8 @@ from einops import rearrange
 from timm.layers import to_2tuple
 from timm.models.vision_transformer import Block
 
+logger = logging.getLogger(__name__)
+
 
 def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
     """
@@ -91,11 +91,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
 
 def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
-    """ This is the torch version of *get_1d_sincos_pos_embed_from_grid()*. However,
-    it was modified to cast omega values to pos.dtype which must be float (and not int as in
-    regular positional embeddings). This was required in order to allow for native FSDP mixed
-    precision support: modify omega to appropriate dtype (pos carries the correct float dtype),
-    instead of manually forcing float32.
+    """ Modified torch version of *get_1d_sincos_pos_embed_from_grid()*.
 
     embed_dim: output dimension for each position
     pos: a list of positions to be encoded: size (M,) - must be float dtype!
@@ -130,12 +126,56 @@ def _init_weights(module):
         module.weight.data.fill_(1.0)
 
 
+def _interpolate_pos_encoding(
+    pos_embed: torch.Tensor,
+    grid_size: tuple[int, int, int] | list[int],
+    patch_size: tuple[int, int, int] | list[int],
+    shape: tuple[int, int, int],
+    embed_dim: int,
+):
+    """
+    Adapted from:
+    - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
+    - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
+    """
+    t, h, w = shape
+    t_patches = t // patch_size[0]
+    h_patches = h // patch_size[1]
+    w_patches = w // patch_size[2]
+
+    if [t_patches, h_patches, w_patches] == grid_size:
+        # No interpolation needed
+        return pos_embed
+    if t_patches != grid_size[0]:
+        # Re-compute pos embedding to handle changed num_frames
+        new_grid_size = (t_patches, *grid_size[1:])
+        new_pos_embed = get_3d_sincos_pos_embed(pos_embed.shape[-1], new_grid_size, add_cls_token=True)
+        new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0)
+    else:
+        new_grid_size = grid_size
+        new_pos_embed = pos_embed
+
+    class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:]
+
+    patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(0, 3, 1, 2)
+
+    patch_pos_embed = nn.functional.interpolate(
+        patch_pos_embed,
+        size=(h_patches, w_patches),
+        mode='bicubic',
+        align_corners=True,
+    )
+    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)
+
+    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+
 class PatchEmbed(nn.Module):
     """3D version of timm.models.vision_transformer.PatchEmbed"""
     def __init__(
         self,
-        input_size: Tuple[int, int, int] = (1, 224, 224),
-        patch_size: Tuple[int, int, int] = (1, 16, 16),
+        input_size: tuple[int, int, int] = (1, 224, 224),
+        patch_size: tuple[int, int, int] = (1, 16, 16),
         in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: nn.Module | None = None,
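
The new module-level _interpolate_pos_encoding helper resizes the fixed 3D sin-cos table whenever a sample's patch grid differs from the pretraining grid: a changed temporal axis triggers a recomputed table, while the spatial axes are interpolated bicubically. A minimal sketch of calling it directly, assuming prithvi_mae.py is importable under that name:

import torch
from prithvi_mae import get_3d_sincos_pos_embed, _interpolate_pos_encoding

grid_size = [1, 14, 14]       # pretraining grid: 224x224 input, (1, 16, 16) patches, 1 frame
embed_dim = 768
pos_embed = torch.from_numpy(
    get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=True)
).float().unsqueeze(0)        # (1, 1 + 1*14*14, 768)

# A 256x256 input gives a 16x16 spatial grid, so the 14x14 table is interpolated bicubically.
resized = _interpolate_pos_encoding(
    pos_embed=pos_embed,
    grid_size=grid_size,
    patch_size=(1, 16, 16),
    shape=(1, 256, 256),
    embed_dim=embed_dim,
)
print(resized.shape)          # torch.Size([1, 257, 768]): cls token + 16*16 patch positions
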
@@ -146,6 +186,7 @@ class PatchEmbed(nn.Module):
         self.input_size = input_size
         self.patch_size = patch_size
         self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
+        assert self.grid_size >= [1, 1, 1], "Patch size is bigger than input size."
         self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
         self.flatten = flatten
 
@@ -156,8 +197,8 @@
         B, C, T, H, W = x.shape
 
         if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
-            logging.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
-                            f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
+            warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
+                          f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
 
         x = self.proj(x)
         if self.flatten:
@@ -232,24 +273,22 @@ class LocationEncoder(nn.Module):
 class PrithviViT(nn.Module):
     """ Prithvi ViT Encoder"""
     def __init__(self,
-                 img_size: int | Tuple[int, int] = 224,
-                 patch_size: int | Tuple[int, int, int] = (1, 16, 16),
+                 img_size: int | tuple[int, int] = 224,
+                 patch_size: int | tuple[int, int, int] = (1, 16, 16),
                  num_frames: int = 1,
                  in_chans: int = 3,
                  embed_dim: int = 1024,
                  depth: int = 24,
                  num_heads: int = 16,
                  mlp_ratio: float = 4.,
-                 norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
-                 coords_encoding: List[str] | None = None,
+                 norm_layer: nn.Module = nn.LayerNorm,
+                 coords_encoding: list[str] | None = None,
                  coords_scale_learn: bool = False,
-                 encoder_only: bool = True, # needed for timm
+                 drop_path: float = 0.,
                  ** kwargs,
                  ):
         super().__init__()
 
-        self.feature_info = []
-        self.encoder_only = encoder_only
         self.in_chans = in_chans
         self.num_frames = num_frames
         self.embed_dim = embed_dim
@@ -264,6 +303,7 @@ class PrithviViT(nn.Module):
             in_chans=in_chans,
             embed_dim=embed_dim,
         )
+        self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth
 
         # Optional temporal and location embedding
         coords_encoding = coords_encoding or []
@@ -281,10 +321,8 @@
         # Transformer layers
         self.blocks = []
         for i in range(depth):
-            self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
-            self.feature_info.append(
-                {"num_chs": embed_dim * self.patch_embed.patch_size[0], "reduction": 1, "module": f"blocks.{i}"}
-            )
+            self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
+                                     drop_path=drop_path,))
         self.blocks = nn.ModuleList(self.blocks)
 
         self.norm = norm_layer(embed_dim)
@@ -339,45 +377,40 @@ class PrithviViT(nn.Module):
 
         return sequence_unmasked, mask, ids_restore
 
-    def _get_pos_embed(self, x):
-        t, h, w = x.shape[-3:]
-
-        pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(
-            self.embed_dim,
-            (
-                t // self.patch_embed.patch_size[0],
-                h // self.patch_embed.patch_size[1],
-                w // self.patch_embed.patch_size[2],
-            ),
-            add_cls_token=True,
-        )).float().unsqueeze(0).to(x)
+    def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):
 
+        pos_embed = _interpolate_pos_encoding(
+            pos_embed=self.pos_embed,
+            grid_size=self.patch_embed.grid_size,
+            patch_size=self.patch_embed.patch_size,
+            shape=sample_shape,
+            embed_dim=self.embed_dim,
+        )
         return pos_embed
 
-
     def forward(
         self, x: torch.Tensor,
         temporal_coords: None | torch.Tensor = None,
         location_coords: None | torch.Tensor = None,
         mask_ratio=0.75
     ):
-        if x.shape[-3:] != self.patch_embed.input_size:
-            # changed input size
-            pos_embed = self._get_pos_embed(x)
-        else:
-            pos_embed = self.pos_embed
+        if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
+            # add time dim
+            x = x.unsqueeze(2)
+        sample_shape = x.shape[-3:]
 
         # embed patches
         x = self.patch_embed(x)
 
+        pos_embed = self.interpolate_pos_encoding(sample_shape)
         # add pos embed w/o cls token
         x = x + pos_embed[:, 1:, :]
 
-        if self.temporal_encoding:
+        if self.temporal_encoding and temporal_coords is not None:
            num_tokens_per_frame = x.shape[1] // self.num_frames
            temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
            x = x + temporal_encoding
-        if self.location_encoding:
+        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_enc(location_coords)
            x = x + location_encoding
 
@@ -405,23 +438,20 @@ class PrithviViT(nn.Module):
         if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
             # add time dim
             x = x.unsqueeze(2)
-
-        if x.shape[-3:] != self.patch_embed.input_size:
-            pos_embed = self._get_pos_embed(x)
-        else:
-            pos_embed = self.pos_embed
+        sample_shape = x.shape[-3:]
 
         # embed patches
         x = self.patch_embed(x)
 
+        pos_embed = self.interpolate_pos_encoding(sample_shape)
         # add pos embed w/o cls token
         x = x + pos_embed[:, 1:, :]
 
-        if self.temporal_encoding:
-            num_tokens_per_frame = x.shape[1] // self.patch_embed.num_frames
+        if self.temporal_encoding and temporal_coords is not None:
+            num_tokens_per_frame = x.shape[1] // self.num_frames
             temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
             x = x + temporal_encoding
-        if self.location_encoding:
+        if self.location_encoding and location_coords is not None:
             location_encoding = self.location_embed_enc(location_coords)
             x = x + location_encoding
 
@@ -462,8 +492,8 @@ class PrithviViT(nn.Module):
 class MAEDecoder(nn.Module):
     """ Transformer Decoder used in the Prithvi MAE"""
     def __init__(self,
-                 patch_size: int | Tuple[int, int, int] = (1, 16, 16),
-                 grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14),
+                 patch_size: int | tuple[int, int, int] = (1, 16, 16),
+                 grid_size: list[int] | tuple[int, int, int] = (3, 14, 14),
                  in_chans: int = 3,
                  encoder_embed_dim: int = 1024,
                  decoder_embed_dim: int = 512,
@@ -471,7 +501,7 @@
                  num_heads: int = 16,
                  mlp_ratio: float = 4.,
                  norm_layer: nn.Module = nn.LayerNorm,
-                 coords_encoding: List[str] | None = None,
+                 coords_encoding: list[str] | None = None,
                  coords_scale_learn: bool = False,
                  ):
         super().__init__()
@@ -520,6 +550,18 @@ class MAEDecoder(nn.Module):
         torch.nn.init.normal_(self.mask_token, std=0.02)
         self.apply(_init_weights)
 
+    def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):
+
+        pos_embed = _interpolate_pos_encoding(
+            pos_embed=self.decoder_pos_embed,
+            grid_size=self.grid_size,
+            patch_size=self.patch_size,
+            shape=sample_shape,
+            embed_dim=self.decoder_embed_dim,
+        )
+
+        return pos_embed
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -530,44 +572,31 @@
     ):
         # embed tokens
         x = self.decoder_embed(hidden_states)
-
-        t, h, w = input_size[-3:]
-        decoder_pos_embed = torch.from_numpy(
-            get_3d_sincos_pos_embed(
-                self.decoder_embed_dim,
-                (
-                    t // self.patch_size[0],
-                    h // self.patch_size[1],
-                    w // self.patch_size[2],
-                ),
-                add_cls_token=True,
-            )
-        ).to(x)
+        cls_token = x[:, :1, :]
 
         # append mask tokens to sequence
         mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
-        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
+        x = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
         # unshuffle
-        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
-        x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
-        # add pos embed
-        x = x + decoder_pos_embed
+        x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device))
 
-        # remove cls token
-        x_ = x[:, 1:, :]
+        # add pos embed
+        decoder_pos_embed = self.interpolate_pos_encoding(input_size[-3:])
+        cls_token = cls_token + decoder_pos_embed[:, :1, :]
+        x = x + decoder_pos_embed[:, 1:, :]
 
-        if self.temporal_encoding:
-            num_tokens_per_frame = x_.shape[1] // self.num_frames
+        if self.temporal_encoding and temporal_coords is not None:
+            num_tokens_per_frame = x.shape[1] // self.num_frames
             temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
             # Add temporal encoding w/o cls token
-            x_ = x_ + temporal_encoding
-        if self.location_encoding:
+            x = x + temporal_encoding
+        if self.location_encoding and location_coords is not None:
            location_encoding = self.location_embed_dec(location_coords)
            # Add location encoding w/o cls token
-            x_ = x_ + location_encoding
+            x = x + location_encoding
 
         # append cls token
-        x = torch.cat([x[:, :1, :], x_], dim=1)
+        x = torch.cat([cls_token, x], dim=1)
 
         # apply Transformer layers (blocks)
        for block in self.decoder_blocks:
@@ -587,22 +616,23 @@ class PrithviMAE(nn.Module):
     """ Prithvi Masked Autoencoder"""
 
     def __init__(self,
-                 img_size: int | Tuple[int, int] = 224,
-                 patch_size: int | Tuple[int, int, int] = (1, 16, 16),
-                 num_frames: int = 3,
-                 in_chans: int = 3,
-                 embed_dim: int = 1024,
-                 depth: int = 24,
-                 num_heads: int = 16,
+                 img_size: int | tuple[int, int] = 224,
+                 patch_size: int | tuple[int, int, int] = (1, 16, 16),
+                 num_frames: int = 4,
+                 in_chans: int = 6,
+                 embed_dim: int = 768,
+                 depth: int = 12,
+                 num_heads: int = 12,
                  decoder_embed_dim: int = 512,
                  decoder_depth: int = 8,
                  decoder_num_heads: int = 16,
                  mlp_ratio: float = 4.,
-                 norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
+                 norm_layer: nn.Module = nn.LayerNorm,
                  norm_pix_loss: bool = False,
-                 coords_encoding: List[str] | None = None,
+                 coords_encoding: list[str] | None = None,
                  coords_scale_learn: bool = False,
-                 encoder_only: bool = False,
+                 drop_path: float = 0.,
+                 mask_ratio: float = 0.75,
                  **kwargs,
                  ):
         super().__init__()
@@ -619,28 +649,26 @@ class PrithviMAE(nn.Module):
             norm_layer=norm_layer,
             coords_encoding=coords_encoding,
             coords_scale_learn=coords_scale_learn,
+            drop_path=drop_path,
         )
 
-        self.encoder_only = encoder_only
-
-        if not encoder_only:
-            self.decoder = MAEDecoder(
-                patch_size=patch_size,
-                grid_size=self.encoder.patch_embed.grid_size,
-                in_chans=in_chans,
-                encoder_embed_dim=embed_dim,
-                decoder_embed_dim=decoder_embed_dim,
-                depth=decoder_depth,
-                num_heads=decoder_num_heads,
-                mlp_ratio=mlp_ratio,
-                norm_layer=norm_layer,
-                coords_encoding=coords_encoding,
-                coords_scale_learn=coords_scale_learn,
-            )
-        else:
-            self.decoder = nn.Identity()
+        self.decoder = MAEDecoder(
+            patch_size=patch_size,
+            grid_size=self.encoder.patch_embed.grid_size,
+            in_chans=in_chans,
+            encoder_embed_dim=embed_dim,
+            decoder_embed_dim=decoder_embed_dim,
+            depth=decoder_depth,
+            num_heads=decoder_num_heads,
+            mlp_ratio=mlp_ratio,
+            norm_layer=norm_layer,
+            coords_encoding=coords_encoding,
+            coords_scale_learn=coords_scale_learn,
+        )
 
+        self.mask_ratio = mask_ratio
         self.norm_pix_loss = norm_pix_loss
+        self.out_channels = self.encoder.out_channels
 
     def patchify(self, pixel_values):
        """
@@ -649,7 +677,8 @@ class PrithviMAE(nn.Module):
                 Pixel values.
 
         Returns:
-            torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
+            torch.FloatTensor of shape
+            `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
                 Patchified pixel values.
         """
         patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
@@ -659,16 +688,15 @@ class PrithviMAE(nn.Module):
         patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
                                             c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)
 
-
         return patchified_pixel_values
 
-    def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None):
+    def unpatchify(self, patchified_pixel_values, image_size: tuple[int, int] | None = None):
         """
         Args:
             patchified_pixel_values (`torch.FloatTensor` of shape
-                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
+                `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels))`:
                 Patchified pixel values.
-            image_size (`Tuple[int, int]`, *optional*):
+            image_size (`tuple[int, int]`, *optional*):
                 Original image size.
 
         Returns:
@@ -692,7 +720,8 @@ class PrithviMAE(nn.Module):
         Args:
             pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
                 Pixel values.
-            pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
+            pred (`torch.FloatTensor` of shape
+            `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
                 Predicted pixel values.
             mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                 Tensor indicating which patches are masked (1) and which are not (0).
@@ -716,12 +745,13 @@ class PrithviMAE(nn.Module):
         pixel_values: torch.Tensor,
         temporal_coords: None | torch.Tensor = None,
         location_coords: None | torch.Tensor = None,
-        mask_ratio: float = 0.75
+        mask_ratio: float = None,
     ):
         if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
             # add time dim
             pixel_values = pixel_values.unsqueeze(2)
 
+        mask_ratio = mask_ratio or self.mask_ratio
         latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
         pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
        loss = self.forward_loss(pixel_values, pred, mask)
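
Note that the fallback is written as mask_ratio = mask_ratio or self.mask_ratio, so an explicit mask_ratio=0.0 also falls back to the stored default; only a non-zero value actually overrides it. A small self-contained sketch of that behaviour (assuming prithvi_mae.py is importable):

import torch
from prithvi_mae import PrithviMAE

model = PrithviMAE()
video = torch.randn(1, 6, 4, 224, 224)        # (batch, bands, time, height, width)

out_default = model(video)                    # falls back to model.mask_ratio (0.75)
out_half = model(video, mask_ratio=0.5)       # an explicit non-zero ratio is respected
out_zero = model(video, mask_ratio=0.0)       # 0.0 is falsy, so this also falls back to 0.75
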
@@ -732,5 +762,5 @@ class PrithviMAE(nn.Module):
         x: torch.Tensor,
         temporal_coords: None | torch.Tensor = None,
         location_coords: None | torch.Tensor = None,
-    ) -> List[torch.Tensor]:
+    ) -> list[torch.Tensor]:
        return self.encoder.forward_features(x, temporal_coords, location_coords)
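
End-to-end, the encoder now infers the sample shape itself, so both the default 224x224 four-frame input and other resolutions run through the same path. A hedged usage sketch (assuming prithvi_mae.py is importable; the exact structure of forward()'s return value is not shown in this diff):

import torch
from prithvi_mae import PrithviMAE

model = PrithviMAE()                           # defaults: in_chans=6, num_frames=4, embed_dim=768
video = torch.randn(2, 6, 4, 224, 224)         # (batch, bands, time, height, width)

outputs = model(video)                         # masking uses model.mask_ratio unless overridden
features = model.forward_features(video)       # list[torch.Tensor] of intermediate encoder features

larger = torch.randn(2, 6, 4, 336, 336)        # non-default resolution: the pos. encoding is interpolated
outputs = model(larger)
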