import torch |
|
import torch.nn as nn |
|
from torch.utils.checkpoint import checkpoint |
|
import numpy as np |
|
import math |
|
from itertools import repeat |
|
import collections.abc |
|
from .attention import MemEffCrossAttention, MemEffAttention |
|
from .utils import TimestepEmbedder, Mlp, modulate


class DiTBlock(nn.Module):
|
""" |
|
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. |
|
""" |
|
def __init__(self, hidden_size, cross_attn_cond_dim, num_heads, mlp_ratio=4.0, proj_bias=False, gradient_checkpointing=False, **block_kwargs): |
|
super().__init__() |
|
self.gradient_checkpointing = gradient_checkpointing |
|
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
self.crossattn = MemEffCrossAttention(dim=hidden_size, dim_q=hidden_size, dim_k=cross_attn_cond_dim, dim_v=cross_attn_cond_dim, num_heads=num_heads, qkv_bias=True, proj_bias=proj_bias, attn_drop=0.0, proj_drop=0.0, gradient_checkpointing=gradient_checkpointing, **block_kwargs) |
|
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
self.attn = MemEffAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=True, proj_bias=proj_bias, attn_drop=0.0, proj_drop=0.0, gradient_checkpointing=gradient_checkpointing, **block_kwargs) |
|
self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
mlp_hidden_dim = int(hidden_size * mlp_ratio) |
|
approx_gelu = lambda: nn.GELU(approximate="tanh") |
|
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) |
|
self.adaLN_modulation = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, 9 * hidden_size, bias=True) |
|
) |
|
|
|
def forward(self, x, cross_attn_cond, mod_cond): |
|
if self.training and self.gradient_checkpointing: |
|
return checkpoint(self._forward, x, cross_attn_cond, mod_cond, use_reentrant=False) |
|
else: |
|
return self._forward(x, cross_attn_cond, mod_cond) |
|
|
|
    def _forward(self, x, cross_attn_cond, mod_cond):
        # adaLN-Zero: a single SiLU + Linear head produces shift, scale, and gate terms
        # for the cross-attention, self-attention, and MLP branches (9 chunks in total).
shift_mca, scale_mca, gate_mca, shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod_cond).chunk(9, dim=1) |
|
x = x + gate_mca.unsqueeze(1) * self.crossattn(modulate(self.norm1(x), shift_mca, scale_mca), cross_attn_cond, cross_attn_cond) |
|
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm2(x), shift_msa, scale_msa)) |
|
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm3(x), shift_mlp, scale_mlp)) |
|
return x |
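
# Shape reference for DiTBlock (a minimal sketch; the sizes below are illustrative
# assumptions, not requirements of this module):
#   block = DiTBlock(hidden_size=1152, cross_attn_cond_dim=512, num_heads=16)
#   x     = torch.randn(2, 256, 1152)   # (N, T, hidden_size) tokens
#   cond  = torch.randn(2, 77, 512)     # (N, L, cross_attn_cond_dim) cross-attention context
#   t_emb = torch.randn(2, 1152)        # (N, hidden_size) modulation signal (e.g. timestep embedding)
#   out   = block(x, cross_attn_cond=cond, mod_cond=t_emb)   # (N, T, hidden_size)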
|
|
|
|
|
class FinalLayer(nn.Module): |
|
""" |
|
The final layer of DiT. |
|
""" |
|
def __init__(self, hidden_size, seq_length, out_channels): |
|
super().__init__() |
|
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) |
|
self.linear = nn.Linear(hidden_size, out_channels, bias=True) |
|
self.adaLN_modulation = nn.Sequential( |
|
nn.SiLU(), |
|
nn.Linear(hidden_size, 2 * hidden_size, bias=True) |
|
) |
|
|
|
def forward(self, x, c): |
|
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) |
|
x = modulate(self.norm_final(x), shift, scale) |
|
x = self.linear(x) |
|
return x |
|
|
|
class PointEmbed(nn.Module): |
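    """
    Fourier-feature embedding of 3D point coordinates.

    Each xyz coordinate is projected onto log-spaced frequencies (powers of two times pi),
    the sin/cos responses are concatenated with the raw coordinates, and a linear layer
    maps the result to `dim` channels. Input: (B, N, 3) points; output: (B, N, dim).
    """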
|
def __init__(self, hidden_dim=48, dim=128): |
|
super().__init__() |
|
|
|
assert hidden_dim % 6 == 0 |
|
|
|
self.embedding_dim = hidden_dim |
|
e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi |
|
e = torch.stack([ |
|
torch.cat([e, torch.zeros(self.embedding_dim // 6), |
|
torch.zeros(self.embedding_dim // 6)]), |
|
torch.cat([torch.zeros(self.embedding_dim // 6), e, |
|
torch.zeros(self.embedding_dim // 6)]), |
|
torch.cat([torch.zeros(self.embedding_dim // 6), |
|
torch.zeros(self.embedding_dim // 6), e]), |
|
]) |
|
self.register_buffer('basis', e) |
|
|
|
self.mlp = nn.Linear(self.embedding_dim+3, dim) |
|
|
|
@staticmethod |
|
def embed(input, basis): |
|
projections = torch.einsum('bnd,de->bne', input, basis) |
|
embeddings = torch.cat([projections.sin(), projections.cos()], dim=2) |
|
return embeddings |
|
|
|
def forward(self, input): |
|
|
|
embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) |
|
return embed |
|
|
|
class DiT(nn.Module): |
|
""" |
|
Diffusion model with a Transformer backbone. |
|
""" |
|
def __init__( |
|
self, |
|
seq_length=2, |
|
in_channels=4, |
|
condition_channels=512, |
|
hidden_size=1152, |
|
depth=28, |
|
num_heads=16, |
|
mlp_ratio=4.0, |
|
cond_drop_prob=0.0, |
|
attn_proj_bias=False, |
|
learn_sigma=True, |
|
gradient_checkpointing=False, |
|
): |
|
super().__init__() |
|
self.gradient_checkpointing = gradient_checkpointing |
|
self.learn_sigma = learn_sigma |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels * 2 if learn_sigma else in_channels |
|
self.seq_length = seq_length |
|
self.num_heads = num_heads |
|
self.cond_drop_prob = cond_drop_prob |
|
if self.cond_drop_prob > 0: |
|
self.null_cond_embedding = nn.Parameter(torch.randn(condition_channels)) |
|
|
|
|
|
self.x_embedder = nn.Linear(in_channels, hidden_size) |
|
|
|
self.t_embedder = TimestepEmbedder(hidden_size) |
|
|
|
self.blocks = nn.ModuleList([ |
|
DiTBlock(hidden_size, condition_channels, num_heads, mlp_ratio=mlp_ratio, proj_bias=attn_proj_bias, gradient_checkpointing=gradient_checkpointing) for _ in range(depth) |
|
]) |
|
self.final_layer = FinalLayer(hidden_size, seq_length, self.out_channels) |
|
self.initialize_weights() |
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
torch.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
        self.apply(_basic_init)

        # Initialize the input token embedder:
w = self.x_embedder.weight.data |
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
        nn.init.constant_(self.x_embedder.bias, 0)

        # Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) |
|
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
for block in self.blocks: |
|
nn.init.constant_(block.adaLN_modulation[-1].weight, 0) |
|
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out the final (output) layer:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) |
|
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) |
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
def forward(self, x, t, y, precision_dtype=torch.float32, enable_amp=False): |
|
""" |
|
Forward pass of DiT. |
|
        x: (N, T, D) tensor of noised input tokens
        t: (N,) tensor of diffusion timesteps
        y: (N, L, C) tensor of conditioning features used for cross-attention
|
""" |
|
x = self.x_embedder(x) |
|
t = self.t_embedder(t) |
|
if self.cond_drop_prob > 0 and self.training: |
|
drop_mask = torch.rand(y.shape[0], device=y.device) < self.cond_drop_prob |
|
null_cond_embed = self.null_cond_embedding[None, None, :] |
|
y = torch.where(drop_mask[:, None, None], null_cond_embed, y) |
|
with torch.autocast(device_type='cuda', dtype=precision_dtype, enabled=enable_amp): |
|
for block in self.blocks: |
|
x = block(x=x, cross_attn_cond=y, mod_cond=t) |
|
|
|
x = self.final_layer(x, t) |
|
return x |
|
|
|
def forward_with_cfg(self, x, t, y, cfg_scale=0.0, precision_dtype=torch.float32, enable_amp=False): |
|
        # Classifier-free guidance: run the conditional and null-conditioned batches in a
        # single forward pass, then mix the two predictions. Requires the model to have been
        # built with cond_drop_prob > 0 so that null_cond_embedding exists.
        combined = torch.cat([x, x], dim=0)
|
combined_t = torch.cat([t, t], dim=0) |
|
y_null = self.null_cond_embedding.expand_as(y) |
|
combined_y = torch.cat([y, y_null], dim=0) |
|
model_out = self.forward(combined, combined_t, combined_y, precision_dtype, enable_amp) |
|
        # Note: guidance is applied to the full model output (including learned-sigma channels, if any).
        eps = model_out
|
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) |
|
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) |
|
return half_eps |
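
# A minimal end-to-end sketch (batch size, sequence length, and conditioning length are
# assumptions for illustration only):
#   model = DiT(seq_length=256, in_channels=4, condition_channels=512,
#               hidden_size=512, depth=12, num_heads=8, cond_drop_prob=0.1)
#   x = torch.randn(2, 256, 4)            # (N, T, D) noised tokens
#   t = torch.randint(0, 1000, (2,))      # (N,) diffusion timesteps
#   y = torch.randn(2, 77, 512)           # (N, L, C) cross-attention conditioning
#   out = model(x, t, y)                  # (N, T, 8) here, since learn_sigma doubles the channels
#   guided = model.forward_with_cfg(x, t, y, cfg_scale=3.0)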
|
|
|
class DiTAdditivePosEmb(nn.Module): |
|
""" |
|
    Diffusion model with a Transformer backbone, with per-token 3D point coordinates
    embedded by PointEmbed and added to the input token embedding.
|
""" |
|
def __init__( |
|
self, |
|
seq_length=2, |
|
in_channels=4, |
|
condition_channels=512, |
|
hidden_size=1152, |
|
depth=28, |
|
num_heads=16, |
|
mlp_ratio=4.0, |
|
attn_proj_bias=False, |
|
learn_sigma=True, |
|
gradient_checkpointing=False, |
|
): |
|
super().__init__() |
|
self.gradient_checkpointing = gradient_checkpointing |
|
self.learn_sigma = learn_sigma |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels * 2 if learn_sigma else in_channels |
|
self.seq_length = seq_length |
|
self.num_heads = num_heads |
|
|
|
|
|
self.point_emb = PointEmbed(hidden_dim=48, dim=hidden_size) |
|
self.x_embedder = nn.Linear(in_channels, hidden_size) |
|
|
|
self.t_embedder = TimestepEmbedder(hidden_size) |
|
|
|
self.blocks = nn.ModuleList([ |
|
DiTBlock(hidden_size, condition_channels, num_heads, mlp_ratio=mlp_ratio, proj_bias=attn_proj_bias, gradient_checkpointing=gradient_checkpointing) for _ in range(depth) |
|
]) |
|
self.final_layer = FinalLayer(hidden_size, seq_length, self.out_channels) |
|
self.initialize_weights() |
|
|
|
def initialize_weights(self): |
|
|
|
def _basic_init(module): |
|
if isinstance(module, nn.Linear): |
|
torch.nn.init.xavier_uniform_(module.weight) |
|
if module.bias is not None: |
|
nn.init.constant_(module.bias, 0) |
|
        self.apply(_basic_init)

        # Initialize the input token embedder:
w = self.x_embedder.weight.data |
|
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) |
|
        nn.init.constant_(self.x_embedder.bias, 0)

        # Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) |
|
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
for block in self.blocks: |
|
nn.init.constant_(block.adaLN_modulation[-1].weight, 0) |
|
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out the final (output) layer:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) |
|
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) |
|
nn.init.constant_(self.final_layer.linear.weight, 0) |
|
nn.init.constant_(self.final_layer.linear.bias, 0) |
|
|
|
def forward(self, x, t, y, precision_dtype=torch.float32, enable_amp=False): |
|
""" |
|
        Forward pass of DiTAdditivePosEmb.
        x: (N, T, D) tensor of noised input tokens (channels 1:4 hold xyz coordinates)
        t: (N,) tensor of diffusion timesteps
        y: (N, L, C) tensor of conditioning features used for cross-attention
|
""" |
|
        point = x[:, :, 1:4]  # channels 1:4 are taken as xyz coordinates for the additive positional embedding
|
point_emb = self.point_emb(point) |
|
x = self.x_embedder(x) + point_emb |
|
t = self.t_embedder(t) |
|
with torch.autocast(device_type='cuda', dtype=precision_dtype, enabled=enable_amp): |
|
for block in self.blocks: |
|
x = block(x=x, cross_attn_cond=y, mod_cond=t) |
|
|
|
x = self.final_layer(x, t) |
|
return x |
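
# Usage note (a sketch; sizes are illustrative assumptions): DiTAdditivePosEmb expects
# channels 1:4 of each input token to hold xyz coordinates, which PointEmbed turns into an
# additive positional embedding. This variant has no conditioning dropout or forward_with_cfg.
#   model = DiTAdditivePosEmb(seq_length=256, in_channels=8, condition_channels=512,
#                             hidden_size=512, depth=12, num_heads=8)
#   x = torch.randn(2, 256, 8)            # channels [:, :, 1:4] interpreted as xyz
#   t = torch.randint(0, 1000, (2,))
#   y = torch.randn(2, 77, 512)
#   out = model(x, t, y)                  # (N, T, out_channels)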