# MotionLCM: mld/models/operator/embeddings.py
import math
from typing import Optional
import torch
import torch.nn as nn
from .utils import get_activation_fn


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """Create sinusoidal timestep embeddings.

    Args:
        timesteps: 1-D tensor of N timestep indices, one per batch element.
        embedding_dim: dimension of the output embedding.
        flip_sin_to_cos: if True, order the output as [cos | sin] instead of [sin | cos].
        downscale_freq_shift: shift applied to the frequency denominator.
        scale: multiplier applied to the raw angles before sin/cos.
        max_period: controls the lowest frequency of the embedding.

    Returns:
        Tensor of shape (N, embedding_dim).
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad when embedding_dim is odd
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
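
# Usage sketch (added for illustration; not part of the original module):
# for a batch of 4 timesteps and embedding_dim=8, the output is a (4, 8)
# tensor whose first half holds sines and second half cosines (unless
# flip_sin_to_cos swaps the two halves).
#
#   t = torch.tensor([0, 10, 100, 999])
#   emb = get_timestep_embedding(t, embedding_dim=8)
#   emb.shape  # torch.Size([4, 8])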


class TimestepEmbedding(nn.Module):
    """Small MLP that maps a (projected) timestep embedding to the model's conditioning width."""

    def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str,
                 out_dim: Optional[int] = None, post_act_fn: Optional[str] = None,
                 cond_proj_dim: Optional[int] = None, zero_init_cond: bool = True) -> None:
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if cond_proj_dim is not None:
            # Optional projection for an extra conditioning signal (e.g. a guidance-scale
            # embedding), zero-initialized by default so it starts as a no-op.
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
            if zero_init_cond:
                self.cond_proj.weight.data.fill_(0.0)
        else:
            self.cond_proj = None

        self.act = get_activation_fn(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation_fn(post_act_fn)

    def forward(self, sample: torch.Tensor, timestep_cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        if timestep_cond is not None:
            sample = sample + self.cond_proj(timestep_cond)
        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
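
# Usage sketch (added for illustration; not part of the original module):
# project a 256-dim sinusoidal embedding up to a 1024-dim conditioning vector.
# The activation name "silu" is an assumption about what get_activation_fn accepts.
#
#   te = TimestepEmbedding(in_channels=256, time_embed_dim=1024, act_fn="silu")
#   te(torch.randn(4, 256)).shape  # torch.Size([4, 1024])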


class Timesteps(nn.Module):
    """Module wrapper around get_timestep_embedding, holding the projection settings."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool,
                 downscale_freq_shift: float) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift)
        return t_emb
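

if __name__ == "__main__":
    # Minimal smoke test (added sketch, not part of the original module):
    # chain Timesteps -> TimestepEmbedding the way a diffusion denoiser
    # typically consumes timesteps. The activation name "silu" is an
    # assumption about what get_activation_fn accepts.
    time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True,
                          downscale_freq_shift=0.0)
    time_embedding = TimestepEmbedding(in_channels=256, time_embed_dim=1024,
                                       act_fn="silu")
    t = torch.randint(0, 1000, (4,))
    t_emb = time_embedding(time_proj(t))
    print(t_emb.shape)  # expected: torch.Size([4, 1024])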