import math
from functools import partial
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint  # explicit import for torch.utils.checkpoint.checkpoint used in forward
from huggingface_hub import PyTorchModelHubMixin
import dist
from models.basic_var import AdaLNBeforeHead, AdaLNSelfCrossAttn
from models.clip import FrozenCLIPEmbedder
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
from models.rope import compute_axial_cis
from models.vqvae import VQVAE, VectorQuantizer2
class SharedAdaLin(nn.Linear):
def forward(self, cond_BD):
C = self.weight.shape[0] // 6
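        # Split the shared AdaLN projection into 6 modulation vectors of width C, broadcast over
        # the sequence dimension (consumed by the transformer blocks when shared_aln=True).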
return super().forward(cond_BD).view(-1, 1, 6, C) # B16C
class VAR(nn.Module):
def __init__(
self,
rope=False,
rope_theta=100,
rope_size=None,
depth=16,
embed_dim=1024,
num_heads=16,
mlp_ratio=4.0,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
norm_eps=1e-6,
shared_aln=False,
attn_l2_norm=False,
patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
fused_if_available=True,
use_swiglu_ffn=False,
Cvae=32,
V=4096
):
super().__init__()
# 0. hyperparameters
assert embed_dim % num_heads == 0
self.depth, self.C, self.D, self.num_heads = (
depth,
embed_dim,
embed_dim,
num_heads,
)
self.Cvae, self.V = Cvae, V
self.prog_si = -1 # progressive training
        self.patch_nums: Tuple[int, ...] = patch_nums
self.L = sum(pn**2 for pn in self.patch_nums)
self.first_l = self.patch_nums[0] ** 2
self.rope = rope
self.num_stages_minus_1 = len(self.patch_nums) - 1
self.rng = torch.Generator(device=dist.get_device())
# 1. input (word) embedding
self.word_embed = nn.Linear(self.Cvae, self.C)
# 2. text embedding
self.pooled_embed_size = 1280
context_dim = 1280 + 768
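        # 1280 is the width of the pooled ViT-bigG/14 embedding; the per-token context concatenates
        # ViT-bigG/14 (1280) and ViT-L/14 (768) hidden states (see parse_batch).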
self.text_pooler = nn.Linear(self.pooled_embed_size, self.D)
init_std = math.sqrt(1 / self.C / 3)
self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
# 3. position embedding
if not self.rope:
# absolute position embedding
pos_1LC = []
for i, pn in enumerate(self.patch_nums):
pe = torch.empty(1, pn * pn, self.C)
nn.init.trunc_normal_(pe, mean=0, std=init_std)
pos_1LC.append(pe)
pos_1LC = torch.cat(pos_1LC, dim=1) # 1, L, C
assert tuple(pos_1LC.shape) == (1, self.L, self.C)
self.pos_1LC = nn.Parameter(pos_1LC)
self.freqs_cis = None
else:
# RoPE position embedding
assert (
self.C // self.num_heads
) % 4 == 0, "2d rope needs head dim to be divisible by 4"
patch_nums_m1 = tuple(pn - 1 if pn > 1 else 1 for pn in self.patch_nums)
self.compute_cis = partial(compute_axial_cis, dim=self.C // self.num_heads)
freqs_cis = []
for i, pn in enumerate(self.patch_nums):
norm_coeff = rope_size / patch_nums_m1[i]
cur_freqs = self.compute_cis(
end_x=pn, end_y=pn, theta=rope_theta, norm_coeff=norm_coeff
)
freqs_cis.append(cur_freqs[None, ...])
self.freqs_cis = torch.cat(freqs_cis, dim=1) # 1, L, C // 2 -- complex
# level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid)
self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
# 4. backbone blocks
self.shared_ada_lin = (
nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6 * self.C))
if shared_aln
else nn.Identity()
)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
self.drop_path_rate = drop_path_rate
# stochastic depth decay rule (linearly increasing)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.ModuleList([])
for block_idx in range(depth):
self.blocks.append(
AdaLNSelfCrossAttn(
cond_dim=self.D,
shared_aln=shared_aln,
block_idx=block_idx,
embed_dim=self.C,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[block_idx],
last_drop_p=0 if block_idx == 0 else dpr[block_idx - 1],
qk_norm=attn_l2_norm,
context_dim=context_dim,
use_swiglu_ffn=use_swiglu_ffn,
norm_eps=norm_eps,
)
)
fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
self.using_fused_add_norm_fn = any(fused_add_norm_fns)
print(
f"\n[constructor] ==== fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n"
f" [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n"
f" [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})",
end="\n\n",
flush=True,
)
# 5. attention mask used in training (for masking out the future)
# it won't be used in inference, since kv cache is enabled
d: torch.Tensor = torch.cat(
[torch.full((pn * pn,), i) for i, pn in enumerate(self.patch_nums)]
).view(1, self.L, 1)
dT = d.transpose(1, 2) # dT: 11L
lvl_1L = dT[:, 0].contiguous()
self.register_buffer("lvl_1L", lvl_1L)
attn_bias_for_masking = torch.where(d >= dT, 0.0, -torch.inf).reshape(
1, 1, self.L, self.L
)
self.register_buffer(
"attn_bias_for_masking", attn_bias_for_masking.contiguous()
)
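        # Example (patch_nums=(1, 2): L=5, levels d=[0, 1, 1, 1, 1]) -- the resulting bias is
        #   [[0, -inf, -inf, -inf, -inf],
        #    [0,    0,    0,    0,    0],
        #    [0,    0,    0,    0,    0],
        #    [0,    0,    0,    0,    0],
        #    [0,    0,    0,    0,    0]]
        # i.e. tokens of scale i attend only to tokens of scales <= i (block-causal across scales).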
# 6. classifier head
self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
self.head = nn.Linear(self.C, self.V)
        # Gradient checkpointing is disabled by default
self.use_gradient_checkpointing = False
def enable_gradient_checkpointing(self):
self.use_gradient_checkpointing = True
def disable_gradient_checkpointing(self):
self.use_gradient_checkpointing = False
def get_logits(
self,
h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
cond_BD: Optional[torch.Tensor],
):
if not isinstance(h_or_h_and_residual, torch.Tensor):
h, resi = h_or_h_and_residual # fused_add_norm must be used
h = resi + self.blocks[-1].drop_path(h)
else: # fused_add_norm is not used
h = h_or_h_and_residual
return self.head(self.head_nm(h.float(), cond_BD).float()).float()
def parse_batch(self, batch, null_batch=None):
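        """
        Assemble the text conditioning from precomputed CLIP embeddings.

        Returns (prompt_embed, pooled_output, attention_bias): prompt_embed concatenates the
        ViT-L/14 and ViT-bigG/14 token embeddings, pooled_output is the ViT-bigG/14 embedding at
        the last non-padded position, and attention_bias is 0 for valid tokens and -inf for padding.
        If null_batch is given (e.g. for classifier-free guidance), its embeddings are appended
        along the batch dimension, doubling the effective batch size.
        """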
embedding_1 = batch["vit_l_14_text_embeddings"]
embedding_2 = batch["vit_bigg_14_text_embeddings"]
attention_mask = batch["vit_bigg_14_text_mask"]
batch_size = embedding_1.size(0)
prompt_embed = torch.concat([embedding_1, embedding_2], dim=-1)
prompt_lens = attention_mask.sum(dim=-1).to(int)
pooled_output = embedding_2[
torch.arange(batch_size, device=embedding_2.device), prompt_lens - 1
]
attention_bias = attention_mask.clone()
attention_bias[attention_mask == 0] = -float("inf")
attention_bias[attention_mask == 1] = 0.0
if null_batch is not None:
B, L, hidden_dim = prompt_embed.shape
pooled_dim = pooled_output.shape[1]
null_context = null_batch['prompt_embed']
null_pooled_embed = null_batch['pooled_embed']
null_attn_bias = null_batch['attn_bias']
null_context = null_context[:, :L].expand(B, L, hidden_dim).to(prompt_embed.device)
null_pooled_embed = null_pooled_embed.expand(B, pooled_dim).to(pooled_output.device)
null_attn_bias = null_attn_bias[:, :L].expand(B, L).to(attention_bias.device)
prompt_embed = torch.cat([prompt_embed, null_context], dim=0)
pooled_output = torch.cat([pooled_output, null_pooled_embed], dim=0)
attention_bias = torch.cat([attention_bias, null_attn_bias], dim=0)
return (
prompt_embed.to(dist.get_device()),
pooled_output.to(dist.get_device()),
attention_bias.to(dist.get_device()),
)
def forward(
self,
x_BLCv_wo_first_l: torch.Tensor,
prompt_embeds: torch.Tensor,
pooled_prompt_embeds: torch.Tensor,
prompt_attn_bias: torch.Tensor,
) -> torch.Tensor: # returns logits_BLV
"""
:param batch: {'image': not used in forward,
'text': image caption,
'vit_l_14_text_embeddings': text embedding from CLIP-ViT-L-14
'vit_bigg_14_text_embeddings': text embedding from CLIP-ViT-Big-G-14
'vit_bigg_14_text_mask': attention mask to get a correct pooled embedding
:param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
:return: logits BLV, V is vocab_size
"""
bg, ed = 0, self.L
B = x_BLCv_wo_first_l.shape[0]
with torch.amp.autocast('cuda', enabled=False):
pooled_prompt_embeds = self.text_pooler(pooled_prompt_embeds)
sos = cond_BD = pooled_prompt_embeds
sos = sos.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(
B, self.first_l, -1
)
x_BLC = torch.cat(
(sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1
)
x_BLC += self.lvl_embed(
self.lvl_1L[:, :ed].expand(B, -1)
) # lvl: BLC; pos: 1LC
if not self.rope:
x_BLC += self.pos_1LC[:, :ed]
attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
cond_BD_or_gss = self.shared_ada_lin(cond_BD)
# hack: get the dtype if mixed precision is used
temp = x_BLC.new_ones(8, 8)
main_type = torch.matmul(temp, temp).dtype
x_BLC = x_BLC.to(dtype=main_type)
cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
attn_bias = attn_bias.to(dtype=main_type)
for block in self.blocks:
if self.use_gradient_checkpointing:
x_BLC = torch.utils.checkpoint.checkpoint(
block,
x=x_BLC,
cond_BD=cond_BD_or_gss,
attn_bias=attn_bias,
context=prompt_embeds,
freqs_cis=self.freqs_cis,
context_attn_bias=prompt_attn_bias,
use_reentrant=False,
)
else:
x_BLC = block(
x=x_BLC,
cond_BD=cond_BD_or_gss,
attn_bias=attn_bias,
context=prompt_embeds,
freqs_cis=self.freqs_cis,
context_attn_bias=prompt_attn_bias,
)
with torch.amp.autocast('cuda', enabled=not self.training):
x_BLC = self.get_logits(x_BLC.float(), cond_BD)
return x_BLC # logits BLV, V is vocab_size
def init_weights(
self,
init_adaln=0.5,
init_adaln_gamma=1e-5,
init_head=0.02,
init_std=0.02,
):
if init_std < 0:
init_std = (1 / self.C / 3) ** 0.5 # init_std < 0: automated
print(f"[init_weights] {type(self).__name__} with {init_std=:g}")
for m in self.modules():
with_weight = hasattr(m, "weight") and m.weight is not None
with_bias = hasattr(m, "bias") and m.bias is not None
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight.data, std=init_std)
if with_bias:
m.bias.data.zero_()
elif isinstance(m, nn.Embedding):
nn.init.trunc_normal_(m.weight.data, std=init_std)
if m.padding_idx is not None:
m.weight.data[m.padding_idx].zero_()
elif isinstance(
m,
(
nn.LayerNorm,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.SyncBatchNorm,
nn.GroupNorm,
nn.InstanceNorm1d,
nn.InstanceNorm2d,
nn.InstanceNorm3d,
),
):
if with_weight:
m.weight.data.fill_(1.0)
if with_bias:
m.bias.data.zero_()
if init_head >= 0:
if isinstance(self.head, nn.Linear):
self.head.weight.data.mul_(init_head)
self.head.bias.data.zero_()
elif isinstance(self.head, nn.Sequential):
self.head[-1].weight.data.mul_(init_head)
self.head[-1].bias.data.zero_()
if isinstance(self.head_nm, AdaLNBeforeHead):
self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
if (
hasattr(self.head_nm.ada_lin[-1], "bias")
and self.head_nm.ada_lin[-1].bias is not None
):
self.head_nm.ada_lin[-1].bias.data.zero_()
depth = len(self.blocks)
for block in self.blocks:
block.attn.proj.weight.data.div_(math.sqrt(2 * depth))
block.cross_attn.proj.weight.data.div_(math.sqrt(2 * depth))
if hasattr(block.ffn, "fc2"):
block.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
if hasattr(block, "ada_lin"):
block.ada_lin[-1].weight.data[2 * self.C :].mul_(init_adaln)
block.ada_lin[-1].weight.data[: 2 * self.C].mul_(init_adaln_gamma)
if (
hasattr(block.ada_lin[-1], "bias")
and block.ada_lin[-1].bias is not None
):
block.ada_lin[-1].bias.data.zero_()
elif hasattr(block, "ada_gss"):
block.ada_gss.data[:, :, 2:].mul_(init_adaln)
block.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)
def extra_repr(self):
return f"drop_path_rate={self.drop_path_rate:g}"
class TVARHF(VAR, PyTorchModelHubMixin):  # tags=["image-generation"]
def __init__(
self,
depth=30,
shared_aln=False,
attn_l2_norm=True,
rope=True,
rope_theta=10000,
rope_size=128,
use_swiglu_ffn=True,
):
heads = depth
width = depth * 64
super().__init__(
depth=depth,
embed_dim=width,
num_heads=heads,
drop_rate=0.0,
attn_drop_rate=0.0,
norm_eps=1e-6,
shared_aln=shared_aln,
attn_l2_norm=attn_l2_norm,
patch_nums=(1, 2, 3, 4, 6, 9, 13, 18, 24, 32),
rope=rope,
rope_theta=rope_theta,
rope_size=rope_size,
use_swiglu_ffn=use_swiglu_ffn,
)
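

# Minimal usage sketch (a hypothetical example, not part of the training/inference pipeline; it
# assumes `dist` is initialized and uses random tensors with the shapes expected by
# parse_batch/forward: 77 context tokens of width 1280 + 768 and a 1280-dim pooled embedding):
#
#   device = dist.get_device()
#   model = TVARHF(depth=30).to(device)
#   model.init_weights()
#   B = 2
#   x = torch.randn(B, model.L - model.first_l, model.Cvae, device=device)
#   prompt_embeds = torch.randn(B, 77, 1280 + 768, device=device)
#   pooled = torch.randn(B, 1280, device=device)
#   attn_bias = torch.zeros(B, 77, device=device)
#   logits = model(x, prompt_embeds, pooled, attn_bias)  # (B, model.L, model.V)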