import math
import warnings
|
import torch |
|
import torch.nn as nn |
|
import torch.distributed as dist |
|
import torch.nn.functional as F |
|
from typing import Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import torchvision |
|
import torchvision.utils |
|
from diffusers.models.embeddings import Timesteps, TimestepEmbedding |
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm |
|
|
|
|
|
class ImageHead(nn.Module):
    """Projects transformer hidden states to image-token logits.

    Depends on tensor-parallel utilities (`AttrDict`, `PG`, `DropoutAddRMSNorm`,
    `DropoutAddLayerNorm`, `ColumnParallelLinear`) that are assumed to be
    provided by the surrounding codebase.
    """
|
|
|
def __init__(self, decoder_cfg, gpt_cfg, layer_id=None): |
|
super().__init__() |
|
self.layer_id = layer_id |
|
cfg = ( |
|
AttrDict( |
|
norm_type="layernorm", |
|
is_exp_norm=False, |
|
sequence_parallel=False, |
|
use_userbuffer=False, |
|
norm_eps=1e-5, |
|
norm_bias=True, |
|
gradient_accumulation_fusion=True, |
|
use_fp32_head_weight=False, |
|
) |
|
+ gpt_cfg |
|
) |
|
group = PG.tensor_parallel_group() |
|
assert cfg.norm_type in [ |
|
"layernorm", |
|
"rmsnorm", |
|
], f"Norm type:{cfg.norm_type} not supported" |
|
if cfg.norm_type == "rmsnorm": |
|
self.norm = DropoutAddRMSNorm( |
|
cfg.n_embed, |
|
prenorm=False, |
|
eps=cfg.norm_eps, |
|
is_exp_norm=cfg.is_exp_norm, |
|
sequence_parallel=cfg.sequence_parallel, |
|
) |
|
else: |
|
self.norm = DropoutAddLayerNorm( |
|
cfg.n_embed, |
|
prenorm=False, |
|
eps=cfg.norm_eps, |
|
is_exp_norm=cfg.is_exp_norm, |
|
sequence_parallel=cfg.sequence_parallel, |
|
bias=cfg.norm_bias, |
|
) |
|
|
|
        multiple_of = 256
        if decoder_cfg.in_channels % multiple_of != 0:
            warnings.warn(
                f"It is recommended to set decoder_cfg.in_channels to a multiple of "
                f"{multiple_of}; otherwise matrix-multiplication performance will suffer."
            )
|
|
|
dtype = default_dtype = torch.get_default_dtype() |
|
if cfg.use_fp32_head_weight: |
|
dtype = torch.float32 |
|
            print(
                "Using fp32 head weights! This is incompatible with the original "
                "bf16 head weights.\n",
                end="",
                flush=True,
            )
|
torch.set_default_dtype(dtype) |
|
self.head = ColumnParallelLinear( |
|
cfg.n_embed, |
|
decoder_cfg.in_channels, |
|
bias=True, |
|
group=group, |
|
sequence_parallel=cfg.sequence_parallel, |
|
use_userbuffer=cfg.use_userbuffer, |
|
gradient_accumulation_fusion=cfg.gradient_accumulation_fusion, |
|
use_fp32_output=False, |
|
) |
|
torch.set_default_dtype(default_dtype) |
|
|
|
self.use_fp32_head_weight = cfg.use_fp32_head_weight |
|
|
|
def forward( |
|
self, input_args, images_split_mask: Optional[torch.BoolTensor] = None, **kwargs |
|
): |
|
residual = None |
|
if isinstance(input_args, tuple): |
|
x, residual = input_args |
|
else: |
|
x = input_args |
|
|
|
x = self.norm(x, residual) |
|
|
|
if self.use_fp32_head_weight: |
|
assert ( |
|
self.head.weight.dtype == torch.float32 |
|
), f"head.weight is {self.head.weight.dtype}" |
|
x = x.float() |
|
|
|
        if images_split_mask is None:
            logits = self.head(x)
        else:
            # Project only the image-token positions: gather them into a
            # (bs * n_images, tokens_per_image, n_embed) tensor before the head.
            bs, n_images = images_split_mask.shape[:2]
|
n_embed = x.shape[-1] |
|
|
|
images_embed = torch.masked_select( |
|
x.unsqueeze(1), images_split_mask.unsqueeze(-1) |
|
) |
|
images_embed = images_embed.view((bs * n_images, -1, n_embed)) |
|
logits = self.head(images_embed) |
|
|
|
return logits |
|
|
|
|
|
class GlobalResponseNorm(nn.Module):
    """Global Response Normalization (GRN) from ConvNeXt-V2; expects
    channels-last inputs of shape (batch, height, width, channels)."""
|
|
|
def __init__(self, dim): |
|
super().__init__() |
|
self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim)) |
|
self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim)) |
|
|
|
def forward(self, x): |
|
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) |
|
nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) |
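        # The fused addcmul below computes bias + (weight * nx + 1) * x, i.e. the
        # standard GRN update x + weight * (x * nx) + bias from ConvNeXt-V2.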
|
|
|
return torch.addcmul(self.bias, (self.weight * nx + 1), x, value=1) |
|
|
|
|
|
class Downsample2D(nn.Module): |
|
"""A 2D downsampling layer with an optional convolution. |
|
|
|
Parameters: |
|
channels (`int`): |
|
number of channels in the inputs and outputs. |
|
use_conv (`bool`, default `False`): |
|
option to use a convolution. |
|
out_channels (`int`, optional): |
|
number of output channels. Defaults to `channels`. |
|
padding (`int`, default `1`): |
|
padding for the convolution. |
|
name (`str`, default `conv`): |
|
name of the downsampling 2D layer. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
channels: int, |
|
use_conv: bool = False, |
|
out_channels: Optional[int] = None, |
|
padding: int = 1, |
|
name: str = "conv", |
|
kernel_size=3, |
|
stride=2, |
|
norm_type=None, |
|
eps=None, |
|
elementwise_affine=None, |
|
bias=True, |
|
): |
|
super().__init__() |
|
self.channels = channels |
|
self.out_channels = out_channels or channels |
|
self.use_conv = use_conv |
|
self.padding = padding |
|
self.name = name |
|
|
|
if norm_type == "ln_norm": |
|
self.norm = nn.LayerNorm(channels, eps, elementwise_affine) |
|
elif norm_type == "rms_norm": |
|
self.norm = RMSNorm(channels, eps) |
|
elif norm_type is None: |
|
self.norm = None |
|
else: |
|
raise ValueError(f"unknown norm_type: {norm_type}") |
|
|
|
if use_conv: |
|
conv = nn.Conv2d( |
|
self.channels, |
|
self.out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=padding, |
|
bias=bias, |
|
) |
|
else: |
|
assert self.channels == self.out_channels |
|
conv = nn.AvgPool2d(kernel_size=stride, stride=stride) |
|
|
|
|
|
if name == "conv": |
|
self.Conv2d_0 = conv |
|
self.conv = conv |
|
elif name == "Conv2d_0": |
|
self.conv = conv |
|
else: |
|
self.conv = conv |
|
|
|
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: |
|
|
|
assert hidden_states.shape[1] == self.channels |
|
|
|
if self.norm is not None: |
|
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute( |
|
0, 3, 1, 2 |
|
) |
|
|
|
if self.use_conv and self.padding == 0: |
|
pad = (0, 1, 0, 1) |
|
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) |
|
|
|
assert hidden_states.shape[1] == self.channels |
|
|
|
hidden_states = self.conv(hidden_states) |
|
|
|
return hidden_states |
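
# Hedged usage sketch (illustrative, not part of the original code): with
# use_conv=True and the default stride of 2, Downsample2D halves the spatial
# resolution and can change the channel count.
#
#     down = Downsample2D(channels=64, out_channels=128, use_conv=True,
#                         norm_type="rms_norm", eps=1e-6)
#     y = down(torch.randn(2, 64, 32, 32))  # -> torch.Size([2, 128, 16, 16])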
|
|
|
|
|
class Upsample2D(nn.Module): |
|
"""A 2D upsampling layer with an optional convolution. |
|
|
|
Parameters: |
|
channels (`int`): |
|
number of channels in the inputs and outputs. |
|
use_conv (`bool`, default `False`): |
|
option to use a convolution. |
|
use_conv_transpose (`bool`, default `False`): |
|
option to use a convolution transpose. |
|
out_channels (`int`, optional): |
|
number of output channels. Defaults to `channels`. |
|
name (`str`, default `conv`): |
|
name of the upsampling 2D layer. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
channels: int, |
|
use_conv: bool = False, |
|
use_conv_transpose: bool = False, |
|
out_channels: Optional[int] = None, |
|
name: str = "conv", |
|
kernel_size: Optional[int] = None, |
|
padding=1, |
|
stride=2, |
|
norm_type=None, |
|
eps=None, |
|
elementwise_affine=None, |
|
bias=True, |
|
interpolate=True, |
|
): |
|
super().__init__() |
|
self.channels = channels |
|
self.out_channels = out_channels or channels |
|
self.use_conv = use_conv |
|
self.use_conv_transpose = use_conv_transpose |
|
self.name = name |
|
self.interpolate = interpolate |
|
self.stride = stride |
|
|
|
if norm_type == "ln_norm": |
|
self.norm = nn.LayerNorm(channels, eps, elementwise_affine) |
|
elif norm_type == "rms_norm": |
|
self.norm = RMSNorm(channels, eps) |
|
elif norm_type is None: |
|
self.norm = None |
|
else: |
|
raise ValueError(f"unknown norm_type: {norm_type}") |
|
|
|
conv = None |
|
if use_conv_transpose: |
|
if kernel_size is None: |
|
kernel_size = 4 |
|
conv = nn.ConvTranspose2d( |
|
channels, |
|
self.out_channels, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=padding, |
|
bias=bias, |
|
) |
|
elif use_conv: |
|
if kernel_size is None: |
|
kernel_size = 3 |
|
conv = nn.Conv2d( |
|
self.channels, |
|
self.out_channels, |
|
kernel_size=kernel_size, |
|
padding=padding, |
|
bias=bias, |
|
) |
|
|
|
|
|
if name == "conv": |
|
self.conv = conv |
|
else: |
|
self.Conv2d_0 = conv |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
output_size: Optional[int] = None, |
|
*args, |
|
**kwargs, |
|
) -> torch.Tensor: |
|
|
|
assert hidden_states.shape[1] == self.channels |
|
|
|
if self.norm is not None: |
|
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute( |
|
0, 3, 1, 2 |
|
) |
|
|
|
if self.use_conv_transpose: |
|
return self.conv(hidden_states) |
|
|
|
|
|
|
|
|
|
        # Nearest-neighbor interpolation does not support bfloat16 on some PyTorch
        # versions, so temporarily upcast to float32.
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # Workaround for a PyTorch issue where interpolation on large,
        # non-contiguous batches can produce incorrect results.
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()
|
|
|
|
|
|
|
if self.interpolate: |
|
if output_size is None: |
|
hidden_states = F.interpolate( |
|
hidden_states, scale_factor=self.stride, mode="nearest" |
|
) |
|
else: |
|
hidden_states = F.interpolate( |
|
hidden_states, size=output_size, mode="nearest" |
|
) |
|
|
|
|
|
        # Cast back to the original dtype if we upcast for interpolation.
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)
|
|
|
|
|
if self.use_conv: |
|
if self.name == "conv": |
|
hidden_states = self.conv(hidden_states) |
|
else: |
|
hidden_states = self.Conv2d_0(hidden_states) |
|
|
|
return hidden_states |
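
# Hedged usage sketch (illustrative, not part of the original code): with
# interpolate=True and the default stride of 2, Upsample2D doubles the spatial
# resolution via nearest-neighbor interpolation and refines it with a 3x3 conv.
#
#     up = Upsample2D(channels=128, out_channels=64, use_conv=True,
#                     norm_type="rms_norm", eps=1e-6)
#     y = up(torch.randn(2, 128, 16, 16))  # -> torch.Size([2, 64, 32, 32])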
|
|
|
|
|
class ConvNextBlock(nn.Module):
    """ConvNeXt-style residual block whose output is modulated by a scale and
    shift predicted from `cond_embeds`."""
|
def __init__( |
|
self, |
|
channels, |
|
norm_eps, |
|
elementwise_affine, |
|
use_bias, |
|
hidden_dropout, |
|
hidden_size, |
|
res_ffn_factor: int = 4, |
|
): |
|
super().__init__() |
|
self.depthwise = nn.Conv2d( |
|
channels, |
|
channels, |
|
kernel_size=7, |
|
padding=3, |
|
groups=channels, |
|
bias=use_bias, |
|
) |
|
self.norm = RMSNorm(channels, norm_eps) |
|
self.channelwise_linear_1 = nn.Linear( |
|
channels, int(channels * res_ffn_factor), bias=use_bias |
|
) |
|
self.channelwise_act = nn.GELU() |
|
self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor)) |
|
self.channelwise_linear_2 = nn.Linear( |
|
int(channels * res_ffn_factor), channels, bias=use_bias |
|
) |
|
self.channelwise_dropout = nn.Dropout(hidden_dropout) |
|
self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias) |
|
|
|
def forward(self, x, cond_embeds): |
|
x_res = x |
|
|
|
x = self.depthwise(x) |
|
|
|
x = x.permute(0, 2, 3, 1) |
|
x = self.norm(x) |
|
x = self.channelwise_linear_1(x) |
|
x = self.channelwise_act(x) |
|
x = self.channelwise_norm(x) |
|
x = self.channelwise_linear_2(x) |
|
x = self.channelwise_dropout(x) |
|
x = x.permute(0, 3, 1, 2) |
|
|
|
x = x + x_res |
|
|
|
        # Condition via a learned per-channel scale and shift:
        # x <- x * (1 + scale) + shift.
        scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
        x = torch.addcmul(
            shift[:, :, None, None], x, (1 + scale)[:, :, None, None], value=1
        )
|
|
|
return x |
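
# Hedged usage sketch (illustrative, not part of the original code): the block
# preserves the input shape; `cond_embeds` must have `hidden_size` features.
#
#     block = ConvNextBlock(channels=256, norm_eps=1e-6, elementwise_affine=True,
#                           use_bias=True, hidden_dropout=0.0, hidden_size=512)
#     y = block(torch.randn(2, 256, 16, 16), torch.randn(2, 512))
#     # -> torch.Size([2, 256, 16, 16])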
|
|
|
|
|
class Patchify(nn.Module):
    """Patch-embeds an image with a strided convolution followed by RMSNorm."""
|
def __init__( |
|
self, |
|
in_channels, |
|
block_out_channels, |
|
patch_size, |
|
bias, |
|
elementwise_affine, |
|
eps, |
|
kernel_size=None, |
|
): |
|
super().__init__() |
|
if kernel_size is None: |
|
kernel_size = patch_size |
|
self.patch_conv = nn.Conv2d( |
|
in_channels, |
|
block_out_channels, |
|
kernel_size=kernel_size, |
|
stride=patch_size, |
|
bias=bias, |
|
) |
|
self.norm = RMSNorm(block_out_channels, eps) |
|
|
|
def forward(self, x): |
|
embeddings = self.patch_conv(x) |
|
embeddings = embeddings.permute(0, 2, 3, 1) |
|
embeddings = self.norm(embeddings) |
|
embeddings = embeddings.permute(0, 3, 1, 2) |
|
return embeddings |
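
# Hedged usage sketch (illustrative, not part of the original code): a
# patch_size of 16 turns a 64x64 image into a 4x4 grid of patch embeddings.
#
#     patchify = Patchify(in_channels=3, block_out_channels=768, patch_size=16,
#                         bias=True, elementwise_affine=True, eps=1e-6)
#     feats = patchify(torch.randn(2, 3, 64, 64))  # -> torch.Size([2, 768, 4, 4])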
|
|
|
|
|
class Unpatchify(nn.Module):
    """Maps features back to pixel space with a 1x1 convolution and pixel shuffle."""
|
def __init__( |
|
self, in_channels, out_channels, patch_size, bias, elementwise_affine, eps |
|
): |
|
super().__init__() |
|
self.norm = RMSNorm(in_channels, eps) |
|
self.unpatch_conv = nn.Conv2d( |
|
in_channels, |
|
out_channels * patch_size * patch_size, |
|
kernel_size=1, |
|
bias=bias, |
|
) |
|
self.pixel_shuffle = nn.PixelShuffle(patch_size) |
|
self.patch_size = patch_size |
|
|
|
def forward(self, x): |
|
|
|
x = x.permute(0, 2, 3, 1) |
|
x = self.norm(x) |
|
x = x.permute(0, 3, 1, 2) |
|
x = self.unpatch_conv(x) |
|
x = self.pixel_shuffle(x) |
|
return x |
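
# Hedged usage sketch (illustrative, not part of the original code): each
# Unpatchify doubles the spatial resolution here because patch_size=2.
#
#     unpatchify = Unpatchify(in_channels=768, out_channels=3, patch_size=2,
#                             bias=True, elementwise_affine=True, eps=1e-6)
#     img = unpatchify(torch.randn(2, 768, 16, 16))  # -> torch.Size([2, 3, 32, 32])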
|
|
|
|
|
class UVitBlock(nn.Module):
    """A stack of ConvNextBlocks followed by an optional Downsample2D or Upsample2D."""
|
def __init__( |
|
self, |
|
channels, |
|
out_channels, |
|
num_res_blocks, |
|
stride, |
|
hidden_size, |
|
hidden_dropout, |
|
elementwise_affine, |
|
norm_eps, |
|
use_bias, |
|
downsample: bool, |
|
upsample: bool, |
|
res_ffn_factor: int = 4, |
|
seq_len=None, |
|
concat_input=False, |
|
original_input_channels=None, |
|
use_zero=True, |
|
norm_type="RMS", |
|
): |
|
super().__init__() |
|
|
|
self.res_blocks = nn.ModuleList() |
|
for i in range(num_res_blocks): |
|
conv_block = ConvNextBlock( |
|
channels, |
|
norm_eps, |
|
elementwise_affine, |
|
use_bias, |
|
hidden_dropout, |
|
hidden_size, |
|
res_ffn_factor=res_ffn_factor, |
|
) |
|
|
|
self.res_blocks.append(conv_block) |
|
|
|
if downsample: |
|
self.downsample = Downsample2D( |
|
channels=channels, |
|
out_channels=out_channels, |
|
use_conv=True, |
|
name="Conv2d_0", |
|
kernel_size=3, |
|
padding=1, |
|
stride=stride, |
|
norm_type="rms_norm", |
|
eps=norm_eps, |
|
elementwise_affine=elementwise_affine, |
|
bias=use_bias, |
|
) |
|
else: |
|
self.downsample = None |
|
|
|
if upsample: |
|
self.upsample = Upsample2D( |
|
channels=channels, |
|
out_channels=out_channels, |
|
use_conv_transpose=False, |
|
use_conv=True, |
|
kernel_size=3, |
|
padding=1, |
|
stride=stride, |
|
name="conv", |
|
norm_type="rms_norm", |
|
eps=norm_eps, |
|
elementwise_affine=elementwise_affine, |
|
bias=use_bias, |
|
interpolate=True, |
|
) |
|
else: |
|
self.upsample = None |
|
|
|
def forward(self, x, emb, recompute=False): |
|
for res_block in self.res_blocks: |
|
x = res_block(x, emb) |
|
|
|
if self.downsample is not None: |
|
x = self.downsample(x) |
|
|
|
if self.upsample is not None: |
|
x = self.upsample(x) |
|
|
|
return x |
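
# Hedged usage sketch (illustrative, not part of the original code): the
# ConvNext blocks run at the input channel count; only the optional
# down-/upsample layer changes channels and resolution.
#
#     block = UVitBlock(channels=256, out_channels=512, num_res_blocks=2, stride=2,
#                       hidden_size=512, hidden_dropout=0.0, elementwise_affine=True,
#                       norm_eps=1e-6, use_bias=True, downsample=True, upsample=False)
#     y = block(torch.randn(2, 256, 16, 16), torch.randn(2, 512))
#     # -> torch.Size([2, 512, 8, 8])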
|
|
|
|
|
class ShallowUViTEncoder(nn.Module):
    """Shallow U-ViT encoder: embeds timesteps, patch-embeds the image with a
    strided convolution, and optionally refines the result with a middle UVitBlock."""
|
def __init__( |
|
self, |
|
input_channels=3, |
|
stride=4, |
|
kernel_size=7, |
|
padding=None, |
|
block_out_channels=(768,), |
|
layers_in_middle=2, |
|
hidden_size=2048, |
|
elementwise_affine=True, |
|
use_bias=True, |
|
norm_eps=1e-6, |
|
dropout=0.0, |
|
use_mid_block=True, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
|
|
self.time_proj = Timesteps( |
|
block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0 |
|
) |
|
self.time_embed = TimestepEmbedding( |
|
block_out_channels[0], hidden_size, sample_proj_bias=use_bias |
|
) |
|
|
|
if padding is None: |
|
padding = math.ceil(kernel_size - stride) |
|
self.in_conv = nn.Conv2d( |
|
in_channels=input_channels, |
|
out_channels=block_out_channels[0], |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=padding, |
|
) |
|
if use_mid_block: |
|
self.mid_block = UVitBlock( |
|
block_out_channels[-1], |
|
block_out_channels[-1], |
|
num_res_blocks=layers_in_middle, |
|
hidden_size=hidden_size, |
|
hidden_dropout=dropout, |
|
elementwise_affine=elementwise_affine, |
|
norm_eps=norm_eps, |
|
use_bias=use_bias, |
|
downsample=False, |
|
upsample=False, |
|
stride=1, |
|
res_ffn_factor=4, |
|
) |
|
else: |
|
self.mid_block = None |
|
|
|
def get_num_extra_tensors(self): |
|
return 2 |
|
|
|
def forward(self, x, timesteps): |
|
|
|
bs = x.shape[0] |
|
dtype = x.dtype |
|
|
|
t_emb = self.time_proj(timesteps.flatten()).view(bs, -1).to(dtype) |
|
t_emb = self.time_embed(t_emb) |
|
x_emb = self.in_conv(x) |
|
|
|
if self.mid_block is not None: |
|
x_emb = self.mid_block(x_emb, t_emb) |
|
|
|
hs = [x_emb] |
|
return x_emb, t_emb, hs |
|
|
|
|
|
class ShallowUViTDecoder(nn.Module):
    """Shallow U-ViT decoder: fuses the encoder skip feature, runs an optional
    middle UVitBlock, and upsamples back to pixel space with Unpatchify blocks."""
|
def __init__( |
|
self, |
|
in_channels=768, |
|
out_channels=3, |
|
block_out_channels: Tuple[int] = (768,), |
|
upsamples=2, |
|
layers_in_middle=2, |
|
hidden_size=2048, |
|
elementwise_affine=True, |
|
norm_eps=1e-6, |
|
use_bias=True, |
|
dropout=0.0, |
|
use_mid_block=True, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
if use_mid_block: |
|
self.mid_block = UVitBlock( |
|
in_channels + block_out_channels[-1], |
|
                block_out_channels[-1],
|
num_res_blocks=layers_in_middle, |
|
hidden_size=hidden_size, |
|
hidden_dropout=dropout, |
|
elementwise_affine=elementwise_affine, |
|
norm_eps=norm_eps, |
|
use_bias=use_bias, |
|
downsample=False, |
|
upsample=False, |
|
stride=1, |
|
res_ffn_factor=4, |
|
) |
|
else: |
|
self.mid_block = None |
|
self.out_convs = nn.ModuleList() |
|
for rank in range(upsamples): |
|
if rank == upsamples - 1: |
|
curr_out_channels = out_channels |
|
else: |
|
curr_out_channels = block_out_channels[-1] |
|
if rank == 0: |
|
curr_in_channels = block_out_channels[-1] + in_channels |
|
else: |
|
curr_in_channels = block_out_channels[-1] |
|
self.out_convs.append( |
|
Unpatchify( |
|
curr_in_channels, |
|
curr_out_channels, |
|
patch_size=2, |
|
bias=use_bias, |
|
elementwise_affine=elementwise_affine, |
|
eps=norm_eps, |
|
) |
|
) |
|
self.input_norm = RMSNorm(in_channels, norm_eps) |
|
|
|
def forward(self, x, hs, t_emb): |
|
|
|
x = x.permute(0, 2, 3, 1) |
|
x = self.input_norm(x) |
|
x = x.permute(0, 3, 1, 2) |
|
|
|
x = torch.cat([x, hs.pop()], dim=1) |
|
if self.mid_block is not None: |
|
x = self.mid_block(x, t_emb) |
|
for out_conv in self.out_convs: |
|
x = out_conv(x) |
|
assert len(hs) == 0 |
|
return x |
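

if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not part of the original source):
    # push a random image through the shallow U-ViT encoder/decoder pair with
    # their default configurations and check the round-trip shapes. Batch size,
    # image size, and timestep range are illustrative assumptions.
    encoder = ShallowUViTEncoder()
    decoder = ShallowUViTDecoder()

    images = torch.randn(2, 3, 64, 64)
    timesteps = torch.randint(0, 1000, (2,))

    x_emb, t_emb, hs = encoder(images, timesteps)
    print("encoder features:", tuple(x_emb.shape), "time embedding:", tuple(t_emb.shape))

    # In the full model the decoder input comes from the transformer; here the
    # encoder output is fed back in directly just to exercise the shapes.
    out = decoder(x_emb, hs, t_emb)
    print("decoder output:", tuple(out.shape))  # expected: (2, 3, 64, 64)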
|
|