import argparse

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from torch import Tensor, nn

from .conv import Conv
from .multiscale_bsq import MultiScaleBSQ

ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16}


class Normalize(nn.Module):
    def __init__(self, in_channels, norm_type, norm_axis="spatial"):
        super().__init__()
        self.norm_axis = norm_axis
        assert norm_type in ['group', 'batch', 'no']
        if norm_type == 'group':
            if in_channels % 32 == 0:
                self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
            elif in_channels % 24 == 0:
                self.norm = nn.GroupNorm(num_groups=24, num_channels=in_channels, eps=1e-6, affine=True)
            else:
                raise NotImplementedError
        elif norm_type == 'batch':
            # RuntimeError (inplace grad) if track_running_stats is set to True
            self.norm = nn.SyncBatchNorm(in_channels, track_running_stats=False)
        elif norm_type == 'no':
            self.norm = nn.Identity()

    def forward(self, x):
        if self.norm_axis == "spatial":
            if x.ndim == 4:
                x = self.norm(x)
            else:
                B, C, T, H, W = x.shape
                x = rearrange(x, "B C T H W -> (B T) C H W")
                x = self.norm(x)
                x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
        elif self.norm_axis == "spatial-temporal":
            x = self.norm(x)
        else:
            raise NotImplementedError
        return x


def swish(x: Tensor) -> Tensor:
    try:
        return x * torch.sigmoid(x)
    except Exception:
        # Fallback for oversized tensors: compute on CPU via pinned memory, then move back.
        device = x.device
        x = x.cpu().pin_memory()
        return (x * torch.sigmoid(x)).to(device=device)


class AttnBlock(nn.Module):
    def __init__(self, in_channels, norm_type='group', cnn_param=None):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        self.q = Conv(in_channels, in_channels, kernel_size=1)
        self.k = Conv(in_channels, in_channels, kernel_size=1)
        self.v = Conv(in_channels, in_channels, kernel_size=1)
        self.proj_out = Conv(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        B, _, T, _, _ = h_.shape
        h_ = self.norm(h_)
        h_ = rearrange(h_, "B C T H W -> (B T) C H W")  # spatial attention only
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
        return rearrange(h_, "(b t) 1 (h w) c -> b c t h w", h=h, w=w, c=c, b=B, t=T)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, norm_type='group', cnn_param=None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.norm1 = Normalize(in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        if cnn_param["res_conv_2d"] in ["half", "full"]:
            self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
        self.norm2 = Normalize(out_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        if cnn_param["res_conv_2d"] in ["full"]:
            self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
        if self.in_channels != self.out_channels:
            self.nin_shortcut = Conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + h
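

# Illustrative sketch (not part of the original module): how the blocks above
# fit together. The `cnn_param` values below are plausible placeholders, and
# the repo's `Conv` wrapper is assumed to behave like nn.Conv2d for
# cnn_type="2d"; call _demo_resnet_block() manually to try it.
def _demo_resnet_block():
    cnn_param = dict(cnn_type="2d", cnn_norm_axis="spatial", res_conv_2d="no", cnn_attention="no")
    block = ResnetBlock(in_channels=64, out_channels=128, cnn_param=cnn_param)
    x = torch.randn(2, 64, 32, 32)  # (B, C, H, W) image features
    y = block(x)  # residual path plus 1x1 shortcut, since C changes
    assert y.shape == (2, 128, 32, 32)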


class Downsample(nn.Module):
    def __init__(self, in_channels, cnn_type="2d", spatial_down=False, temporal_down=False):
        super().__init__()
        assert spatial_down
        if cnn_type == "2d":
            self.pad = (0, 1, 0, 1)
        elif cnn_type == "3d":
            # Pad to the right of the h- and w-axes only; no padding on the t-axis.
            self.pad = (0, 1, 0, 1, 0, 0)
        # No asymmetric padding in torch conv, so we do it ourselves.
        self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=2, padding=0, cnn_type=cnn_type, temporal_down=temporal_down)

    def forward(self, x: Tensor):
        x = nn.functional.pad(x, self.pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels, cnn_type="2d", spatial_up=False, temporal_up=False, use_pxsl=False):
        super().__init__()
        if cnn_type == "2d":
            self.scale_factor = 2
            self.causal_offset = 0
        else:
            assert spatial_up
            if temporal_up:
                self.scale_factor = (2, 2, 2)
                self.causal_offset = -1
            else:
                self.scale_factor = (1, 2, 2)
                self.causal_offset = 0
        self.use_pxsl = use_pxsl
        if self.use_pxsl:
            self.conv = Conv(in_channels, in_channels * 4, kernel_size=3, stride=1, padding=1, cnn_type=cnn_type, causal_offset=self.causal_offset)
            self.pxsl = nn.PixelShuffle(2)
        else:
            self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_type, causal_offset=self.causal_offset)

    def forward(self, x: Tensor):
        if self.use_pxsl:
            x = self.conv(x)
            x = self.pxsl(x)
        else:
            try:
                x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
            except Exception:
                # Fallback for oversized tensors: interpolate one channel at a time.
                _xs = []
                for i in range(x.shape[1]):
                    _x = F.interpolate(x[:, i:i + 1, ...], scale_factor=self.scale_factor, mode="nearest")
                    _xs.append(_x)
                x = torch.cat(_xs, dim=1)
            x = self.conv(x)
        return x
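

# Illustrative sketch (not part of the original module): why Downsample pads
# asymmetrically. F.pad takes pairs last-dim-first, so (0, 1, 0, 1) pads only
# the right of W and the bottom of H, and (0, 1, 0, 1, 0, 0) adds no temporal
# padding, keeping the 3d conv causal. A stride-2, kernel-3 conv then halves
# even sizes exactly, with windows anchored at the top-left. Plain torch ops.
def _demo_asymmetric_padding():
    x = torch.randn(1, 1, 8, 8)
    x = F.pad(x, (0, 1, 0, 1))  # right/bottom only: 8x8 -> 9x9
    y = F.conv2d(x, torch.randn(1, 1, 3, 3), stride=2)
    assert y.shape[-2:] == (4, 4)  # floor((9 - 3) / 2) + 1 = 4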


class Encoder(nn.Module):
    def __init__(
        self,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        in_channels=3,
        patch_size=8,
        temporal_patch_size=4,
        norm_type='group',
        cnn_param=None,
        use_checkpoint=False,
        use_vae=True,
    ):
        super().__init__()
        self.max_down = np.log2(patch_size)
        self.temporal_max_down = np.log2(temporal_patch_size)
        self.temporal_down_offset = self.max_down - self.temporal_max_down
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        self.cnn_param = cnn_param
        self.use_checkpoint = use_checkpoint

        # downsampling
        # cnn_param["cnn_type"] = "2d" for images, "3d" for videos
        if cnn_param["conv_in_out_2d"] == "yes":  # "yes" for video
            self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])

        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type, cnn_param=cnn_param))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            # downsample: stride=1, stride=2, stride=2 for a 4x8x8 video VAE
            spatial_down = i_level < self.max_down
            temporal_down = i_level < self.max_down and i_level >= self.temporal_down_offset
            if spatial_down or temporal_down:
                down.downsample = Downsample(block_in, cnn_type=cnn_param["cnn_type"], spatial_down=spatial_down, temporal_down=temporal_down)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
        if cnn_param["cnn_attention"] == "yes":
            self.mid.attn_1 = AttnBlock(block_in, norm_type, cnn_param=cnn_param)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)

        # end
        self.norm_out = Normalize(block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        if cnn_param["conv_inner_2d"] == "yes":
            self.conv_out = Conv(block_in, (int(use_vae) + 1) * z_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv_out = Conv(block_in, (int(use_vae) + 1) * z_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])

    def forward(self, x, return_hidden=False):
        if not self.use_checkpoint:
            return self._forward(x, return_hidden=return_hidden)
        else:
            return checkpoint.checkpoint(self._forward, x, return_hidden, use_reentrant=False)

    def _forward(self, x: Tensor, return_hidden=False) -> Tensor:
        # downsampling
        h0 = self.conv_in(x)
        hs = [h0]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if hasattr(self.down[i_level], "downsample"):
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        hs_mid = [h]
        h = self.mid.block_1(h)
        if self.cnn_param["cnn_attention"] == "yes":
            h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        hs_mid.append(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        if return_hidden:
            return h, hs, hs_mid
        else:
            return h
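

# Illustrative sketch (not part of the original module): the downsample
# schedule Encoder.__init__ derives for the 4x8x8 video VAE mentioned above
# (4x temporal, 8x spatial). Pure arithmetic from the constructor's formulas.
def _demo_down_schedule(num_resolutions=4, patch_size=8, temporal_patch_size=4):
    max_down = np.log2(patch_size)  # 3 spatial halvings
    offset = max_down - np.log2(temporal_patch_size)  # temporal halvings start 1 level late
    for i_level in range(num_resolutions):
        spatial = i_level < max_down
        temporal = spatial and i_level >= offset
        print(f"level {i_level}: spatial_down={spatial}, temporal_down={temporal}")
    # level 0: spatial only; levels 1-2: spatial+temporal; level 3: no downsample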


class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        out_ch=3,
        patch_size=8,
        temporal_patch_size=4,
        norm_type="group",
        cnn_param=None,
        use_checkpoint=False,
        use_freq_dec=False,  # use frequency features for decoder
        use_pxsf=False,
    ):
        super().__init__()
        self.max_up = np.log2(patch_size)
        self.temporal_max_up = np.log2(temporal_patch_size)
        self.temporal_up_offset = self.max_up - self.temporal_max_up
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.ffactor = 2 ** (self.num_resolutions - 1)
        self.cnn_param = cnn_param
        self.use_checkpoint = use_checkpoint
        self.use_freq_dec = use_freq_dec
        self.use_pxsf = use_pxsf

        # compute block_in at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]

        # z to block_in
        if cnn_param["conv_inner_2d"] == "yes":
            self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
        if cnn_param["cnn_attention"] == "yes":
            self.mid.attn_1 = AttnBlock(block_in, norm_type=norm_type, cnn_param=cnn_param)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type, cnn_param=cnn_param))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            # upsample: stride=1, stride=2, stride=2 for a 4x8x8 video VAE; levels are
            # offset by 1 compared with the encoder, cf.
            # https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/modules/autoencoder.py#L228
            spatial_up = 1 <= i_level <= self.max_up
            temporal_up = spatial_up and i_level >= self.temporal_up_offset + 1
            if spatial_up or temporal_up:
                up.upsample = Upsample(block_in, cnn_type=cnn_param["cnn_type"], spatial_up=spatial_up, temporal_up=temporal_up, use_pxsl=self.use_pxsf)
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        if cnn_param["conv_in_out_2d"] == "yes":
            self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])

    def forward(self, z):
        if not self.use_checkpoint:
            return self._forward(z)
        else:
            return checkpoint.checkpoint(self._forward, z, use_reentrant=False)

    def _forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        if self.cnn_param["cnn_attention"] == "yes":
            h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if hasattr(self.up[i_level], "upsample"):
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
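

# Illustrative sketch (not part of the original module): the decoder's mirror
# of the encoder schedule, shifted up by one level as in the flux autoencoder
# linked above. Pure arithmetic from the constructor's formulas.
def _demo_up_schedule(num_resolutions=4, patch_size=8, temporal_patch_size=4):
    max_up = np.log2(patch_size)
    offset = max_up - np.log2(temporal_patch_size)
    for i_level in reversed(range(num_resolutions)):
        spatial = 1 <= i_level <= max_up
        temporal = spatial and i_level >= offset + 1
        print(f"level {i_level}: spatial_up={spatial}, temporal_up={temporal}")
    # levels 3-2: spatial+temporal; level 1: spatial only; level 0: no upsample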


class AutoEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        cnn_param = dict(
            cnn_type=args.cnn_type,
            conv_in_out_2d=args.conv_in_out_2d,
            res_conv_2d=args.res_conv_2d,
            cnn_attention=args.cnn_attention,
            cnn_norm_axis=args.cnn_norm_axis,
            conv_inner_2d=args.conv_inner_2d,
        )
        self.encoder = Encoder(
            ch=args.base_ch,
            ch_mult=args.encoder_ch_mult,
            num_res_blocks=args.num_res_blocks,
            z_channels=args.codebook_dim,
            patch_size=args.patch_size,
            temporal_patch_size=args.temporal_patch_size,
            cnn_param=cnn_param,
            use_checkpoint=args.use_checkpoint,
            use_vae=args.use_vae,
        )
        self.decoder = Decoder(
            ch=args.base_ch,
            ch_mult=args.decoder_ch_mult,
            num_res_blocks=args.num_res_blocks,
            z_channels=args.codebook_dim,
            patch_size=args.patch_size,
            temporal_patch_size=args.temporal_patch_size,
            cnn_param=cnn_param,
            use_checkpoint=args.use_checkpoint,
            use_freq_dec=args.use_freq_dec,
            use_pxsf=args.use_pxsf,  # pixel shuffle for upsampling
        )
        self.z_drop = nn.Dropout(args.z_drop)
        self.scale_factor = 0.3611
        self.shift_factor = 0.1159
        self.codebook_dim = self.embed_dim = args.codebook_dim

        self.gan_feat_weight = args.gan_feat_weight
        self.video_perceptual_weight = args.video_perceptual_weight
        self.recon_loss_type = args.recon_loss_type
        self.l1_weight = args.l1_weight
        self.use_vae = args.use_vae
        self.kl_weight = args.kl_weight
        self.lfq_weight = args.lfq_weight
        self.image_gan_weight = args.image_gan_weight  # image GAN loss weight
        self.video_gan_weight = args.video_gan_weight  # video GAN loss weight
        self.perceptual_weight = args.perceptual_weight
        self.flux_weight = args.flux_weight
        self.cycle_weight = args.cycle_weight
        self.cycle_feat_weight = args.cycle_feat_weight
        self.cycle_gan_weight = args.cycle_gan_weight

        self.flux_image_encoder = None

        if not args.use_vae:
            if args.quantizer_type == 'MultiScaleBSQ':
                self.quantizer = MultiScaleBSQ(
                    dim=args.codebook_dim,  # input feature dimension, defaults to log2(codebook_size) if not defined
                    codebook_size=args.codebook_size,  # codebook size, must be a power of 2
                    entropy_loss_weight=args.entropy_loss_weight,  # how much weight to place on entropy loss
                    diversity_gamma=args.diversity_gamma,  # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
                    preserve_norm=args.preserve_norm,  # preserve norm of the input for BSQ
                    ln_before_quant=args.ln_before_quant,  # use layer norm before quantization
                    ln_init_by_sqrt=args.ln_init_by_sqrt,  # layer norm init value 1/sqrt(d)
                    commitment_loss_weight=args.commitment_loss_weight,  # loss weight of commitment loss
                    new_quant=args.new_quant,
                    use_decay_factor=args.use_decay_factor,
                    mask_out=args.mask_out,
                    use_stochastic_depth=args.use_stochastic_depth,
                    drop_rate=args.drop_rate,
                    schedule_mode=args.schedule_mode,
                    keep_first_quant=args.keep_first_quant,
                    keep_last_quant=args.keep_last_quant,
                    remove_residual_detach=args.remove_residual_detach,
                    use_out_phi=args.use_out_phi,
                    use_out_phi_res=args.use_out_phi_res,
                    random_flip=args.random_flip,
                    flip_prob=args.flip_prob,
                    flip_mode=args.flip_mode,
                    max_flip_lvl=args.max_flip_lvl,
                    random_flip_1lvl=args.random_flip_1lvl,
                    flip_lvl_idx=args.flip_lvl_idx,
                    drop_when_test=args.drop_when_test,
                    drop_lvl_idx=args.drop_lvl_idx,
                    drop_lvl_num=args.drop_lvl_num,
                )
                self.quantize = self.quantizer
                self.vocab_size = args.codebook_size
            else:
                raise NotImplementedError(f"{args.quantizer_type} not supported")

    def forward(self, x):
        is_image = x.ndim == 4
        if not is_image:
            B, C, T, H, W = x.shape
        else:
            B, C, H, W = x.shape
            T = 1
        enc_dtype = ptdtype[self.args.encoder_dtype]

        with torch.amp.autocast("cuda", dtype=enc_dtype):
            h, hs, hs_mid = self.encoder(x, return_hidden=True)  # B C H W or B C T H W
        hs = [_h.detach() for _h in hs]
        hs_mid = [_h.detach() for _h in hs_mid]
        h = h.to(dtype=torch.float32)

        # multiscale LFQ
        z, all_indices, all_loss = self.quantizer(h)
        x_recon = self.decoder(z)
        vq_output = {
            "commitment_loss": torch.mean(all_loss) * self.lfq_weight,  # here commitment loss is sum of commitment loss and entropy penalty
            "encodings": all_indices,
        }
        return x_recon, vq_output

    def encode_for_raw_features(self, x, scale_schedule, return_residual_norm_per_scale=False):
        is_image = x.ndim == 4
        if not is_image:
            B, C, T, H, W = x.shape
        else:
            B, C, H, W = x.shape
            T = 1

        enc_dtype = ptdtype[self.args.encoder_dtype]
        with torch.amp.autocast("cuda", dtype=enc_dtype):
            h, hs, hs_mid = self.encoder(x, return_hidden=True)  # B C H W or B C T H W
        hs = [_h.detach() for _h in hs]
        hs_mid = [_h.detach() for _h in hs_mid]
        h = h.to(dtype=torch.float32)
        return h, hs, hs_mid

    def encode(self, x, scale_schedule, return_residual_norm_per_scale=False):
        h, hs, hs_mid = self.encode_for_raw_features(x, scale_schedule, return_residual_norm_per_scale)
        # multiscale LFQ
        z, all_indices, all_bit_indices, residual_norm_per_scale, all_loss, var_input = self.quantizer(h, scale_schedule=scale_schedule, return_residual_norm_per_scale=return_residual_norm_per_scale)
        return h, z, all_indices, all_bit_indices, residual_norm_per_scale, var_input

    def decode(self, z):
        x_recon = self.decoder(z)
        x_recon = torch.clamp(x_recon, min=-1, max=1)
        return x_recon

    def decode_from_indices(self, all_indices, scale_schedule, label_type):
        summed_codes = 0
        for idx_Bl in all_indices:
            codes = self.quantizer.lfq.indices_to_codes(idx_Bl, label_type)
            summed_codes += F.interpolate(codes, size=scale_schedule[-1], mode=self.quantizer.z_interplote_up)
        assert summed_codes.shape[-3] == 1
        x_recon = self.decoder(summed_codes.squeeze(-3))
        x_recon = torch.clamp(x_recon, min=-1, max=1)
        return summed_codes, x_recon

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--flux_weight", type=float, default=0)
        parser.add_argument("--cycle_weight", type=float, default=0)
        parser.add_argument("--cycle_feat_weight", type=float, default=0)
        parser.add_argument("--cycle_gan_weight", type=float, default=0)
        parser.add_argument("--cycle_loop", type=int, default=0)
        parser.add_argument("--z_drop", type=float, default=0.)
        return parser
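

# Illustrative sketch (not part of the original module): decode_from_indices
# rebuilds the latent by upsampling each scale's codes to the final scale and
# summing them. Plain-tensor analogue with an assumed "trilinear" mode; the
# real mode comes from self.quantizer.z_interplote_up, and real codes come
# from self.quantizer.lfq.indices_to_codes.
def _demo_multiscale_sum():
    final_size = (1, 8, 8)  # (T, H, W) of the last scale in scale_schedule
    codes_per_scale = [torch.randn(1, 16, 1, s, s) for s in (1, 2, 4, 8)]
    summed = 0
    for codes in codes_per_scale:
        summed = summed + F.interpolate(codes, size=final_size, mode="trilinear")
    assert summed.shape == (1, 16, 1, 8, 8)  # squeeze(-3) would feed the 2d decoder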