|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
from einops import rearrange, repeat |
|
from torch import nn |
|
|
|
from cosmos1.models.diffusion.module.attention import normalize |
|
from cosmos1.models.diffusion.module.timm import trunc_normal_ |
|
|
|
|
|
class VideoPositionEmb(nn.Module):
    """Base class for video position embeddings.

    ``forward`` only reads the shape of its input and delegates the actual
    embedding computation to ``generate_embeddings``, which subclasses must
    implement.
    """

    def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        It delegates the embedding generation to generate_embeddings function.

        Args:
            x_B_T_H_W_C (torch.Tensor): Input tensor; only its ``.shape`` is used. By its
                name it is expected to be laid out as (B, T, H, W, C).
            fps (Optional[torch.Tensor]): Frames per second, forwarded unchanged to
                ``generate_embeddings``. Defaults to None.

        Returns:
            torch.Tensor: Whatever the subclass's ``generate_embeddings`` returns.
        """
        # NOTE: the default must be ``None``; writing ``fps=Optional[torch.Tensor]``
        # would make the typing object itself the default value.
        return self.generate_embeddings(x_B_T_H_W_C.shape, fps=fps)

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None):
        """Compute embeddings for an input of size ``B_T_H_W_C``; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
|
class VideoRopePosition3DEmb(VideoPositionEmb):
    """3D (time, height, width) rotary position embedding with NTK-style base scaling.

    The head dimension is split into a height part and a width part
    (``head_dim // 6 * 2`` channels each) and a temporal part (the remainder).
    Each axis gets its own RoPE frequencies, whose base (10000) is multiplied by a
    per-axis factor derived from that axis's extrapolation ratio.
    """

    def __init__(
        self,
        *,
        head_dim: int,
        len_h: int,
        len_w: int,
        len_t: int,
        base_fps: int = 24,
        h_extrapolation_ratio: float = 1.0,
        w_extrapolation_ratio: float = 1.0,
        t_extrapolation_ratio: float = 1.0,
        **kwargs,
    ):
        """
        Args:
            head_dim (int): Channels per attention head, split across the t/h/w axes.
            len_h (int): Maximum supported height (checked in ``generate_embeddings``).
            len_w (int): Maximum supported width (checked in ``generate_embeddings``).
            len_t (int): Temporal length; together with len_h/len_w it sizes the position table.
            base_fps (int): Reference frame rate; frame indices are rescaled by ``base_fps / fps``.
            h_extrapolation_ratio (float): Height ratio turned into ``h_ntk_factor``.
            w_extrapolation_ratio (float): Width ratio turned into ``w_ntk_factor``.
            t_extrapolation_ratio (float): Temporal ratio turned into ``t_ntk_factor``.
            **kwargs: Ignored.
        """
        del kwargs
        super().__init__()
        # Integer positions 0 .. max(len_h, len_w, len_t) - 1, shared by all three axes.
        self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float))
        self.base_fps = base_fps
        self.max_h = len_h
        self.max_w = len_w

        dim = head_dim
        dim_h = dim // 6 * 2
        dim_w = dim_h
        dim_t = dim - 2 * dim_h
        assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"

        # Exponents i / d for i = 0, 2, 4, ... used to build the RoPE frequencies.
        # Not pinned to CUDA: like `seq`, these buffers follow the module through
        # `.to()` / `.cuda()`, so construction also works on CPU-only hosts.
        self.register_buffer(
            "dim_spatial_range",
            torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h,
            persistent=False,
        )
        self.register_buffer(
            "dim_temporal_range",
            torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t,
            persistent=False,
        )

        # Per-axis base scaling: theta = 10000 * ratio ** (d / (d - 2)).
        self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2))
        self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2))
        self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2))

    def generate_embeddings(
        self,
        B_T_H_W_C: torch.Size,
        fps: Optional[torch.Tensor] = None,
        h_ntk_factor: Optional[float] = None,
        w_ntk_factor: Optional[float] = None,
        t_ntk_factor: Optional[float] = None,
    ):
        """
        Generate rotary embedding angles for the given input size.

        Args:
            B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None, which
                requires T == 1. Non-uniform values are only allowed when B == 1 or T == 1;
                only the first value is used.
            h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor.
            w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor.
            t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor.

        Returns:
            torch.Tensor: float32 tensor of shape (T*H*W, 1, 1, D) holding position * frequency
            values. The last axis is the concatenation [t, h, w, t, h, w] of the per-axis
            half embeddings; D equals head_dim when head_dim is even.

        Raises:
            AssertionError: On non-uniform fps with B > 1 and T > 1, on H/W above the
                configured maxima, or when fps is None and T != 1.
        """
        h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor
        w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor
        t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor

        # Per-axis RoPE base, scaled by the NTK factor.
        h_theta = 10000.0 * h_ntk_factor
        w_theta = 10000.0 * w_ntk_factor
        t_theta = 10000.0 * t_ntk_factor

        h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range)
        w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range)
        temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range)

        B, T, H, W, _ = B_T_H_W_C
        uniform_fps = (fps is None) or (fps.min() == fps.max())
        assert (
            uniform_fps or B == 1 or T == 1
        ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1"
        assert (
            H <= self.max_h and W <= self.max_w
        ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})"
        half_emb_h = torch.outer(self.seq[:H], h_spatial_freqs)
        half_emb_w = torch.outer(self.seq[:W], w_spatial_freqs)

        if fps is None:
            assert T == 1, "T should be 1 for image batch."
            half_emb_t = torch.outer(self.seq[:T], temporal_freqs)
        else:
            # Rescale frame indices onto the base_fps time grid. Only the first fps value
            # is used; the uniform-fps assertion above guards the batched case.
            half_emb_t = torch.outer(self.seq[:T] / fps[:1] * self.base_fps, temporal_freqs)

        # Broadcast each axis's half embedding over the other two axes, then
        # duplicate the [t, h, w] group along the channel axis.
        em_T_H_W_D = torch.cat(
            [
                repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
                repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
                repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
            ]
            * 2,
            dim=-1,
        )

        return rearrange(em_T_H_W_D, "t h w d -> (t h w) 1 1 d").float()
|
|
|
|
|
class LearnablePosEmbAxis(VideoPositionEmb):
    """Learnable, axis-factorized absolute position embedding.

    Separate learnable tables for the t, h and w axes are cropped to the input
    size, broadcast over the other axes, summed, and passed through
    ``normalize`` along the last dimension.
    """

    def __init__(
        self,
        *,
        interpolation: str,
        model_channels: int,
        len_h: int,
        len_w: int,
        len_t: int,
        **kwargs,
    ):
        """
        Args:
            interpolation (str): we currently only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet.
            model_channels (int): Embedding width of every table.
            len_h (int): Number of rows in the height table.
            len_w (int): Number of rows in the width table.
            len_t (int): Number of rows in the temporal table.
            **kwargs: Ignored.
        """
        del kwargs
        super().__init__()
        self.interpolation = interpolation
        assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"

        self.pos_emb_h = nn.Parameter(torch.zeros(len_h, model_channels))
        self.pos_emb_w = nn.Parameter(torch.zeros(len_w, model_channels))
        self.pos_emb_t = nn.Parameter(torch.zeros(len_t, model_channels))

        trunc_normal_(self.pos_emb_h, std=0.02)
        trunc_normal_(self.pos_emb_w, std=0.02)
        trunc_normal_(self.pos_emb_t, std=0.02)

    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Build the (B, T, H, W, model_channels) embedding for an input of size ``B_T_H_W_C``.

        Args:
            B_T_H_W_C (torch.Size): Input size (Batch, Time, Height, Width, Channels).
            fps (Optional[torch.Tensor]): Unused here; accepted for interface compatibility.
                Defaults to None.

        Returns:
            torch.Tensor: Sum of the cropped t/h/w tables, passed through ``normalize``
            over the last dim (presumably L2 normalization — confirm in the attention module).

        Raises:
            ValueError: If ``self.interpolation`` is not "crop".
        """
        B, T, H, W, _ = B_T_H_W_C
        if self.interpolation == "crop":
            # Crop each table to the current extent, then broadcast over the other axes.
            emb_h_H = self.pos_emb_h[:H]
            emb_w_W = self.pos_emb_w[:W]
            emb_t_T = self.pos_emb_t[:T]
            emb = (
                repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
                + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
                + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H)
            )
            assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}"
        else:
            raise ValueError(f"Unknown interpolation method {self.interpolation}")

        return normalize(emb, dim=-1, eps=1e-6)
|
|