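# Image/video autoencoder: a convolutional encoder-decoder (2D for images, 3D
# for videos, selected via cnn_param) with an optional multi-scale BSQ
# quantizer between the two. Inputs are B C H W images or B C T H W videos.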
import argparse
import os

import imageio
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torchvision
from einops import rearrange
from safetensors.torch import load_file
from torch import Tensor, nn
from torchvision import transforms

from .conv import Conv
from .multiscale_bsq import MultiScaleBSQ

ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16}
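
# Normalize: GroupNorm / SyncBatchNorm / Identity wrapper. With
# norm_axis="spatial", 5D video tensors are folded to (B*T) C H W so each frame
# is normalized independently; "spatial-temporal" normalizes the full 5D tensor.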
class Normalize(nn.Module):
    def __init__(self, in_channels, norm_type, norm_axis="spatial"):
        super().__init__()
        self.norm_axis = norm_axis
        assert norm_type in ['group', 'batch', 'no']
        if norm_type == 'group':
            if in_channels % 32 == 0:
                self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
            elif in_channels % 24 == 0:
                self.norm = nn.GroupNorm(num_groups=24, num_channels=in_channels, eps=1e-6, affine=True)
            else:
                raise NotImplementedError
        elif norm_type == 'batch':
            self.norm = nn.SyncBatchNorm(in_channels, track_running_stats=False)  # RuntimeError (inplace grad) if track_running_stats is set to True
        elif norm_type == 'no':
            self.norm = nn.Identity()

    def forward(self, x):
        if self.norm_axis == "spatial":
            if x.ndim == 4:
                x = self.norm(x)
            else:
                B, C, T, H, W = x.shape
                x = rearrange(x, "B C T H W -> (B T) C H W")
                x = self.norm(x)
                x = rearrange(x, "(B T) C H W -> B C T H W", T=T)
        elif self.norm_axis == "spatial-temporal":
            x = self.norm(x)
        else:
            raise NotImplementedError
        return x
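
# swish / SiLU, x * sigmoid(x), with a CPU fallback for activations too large
# for GPU memory.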
def swish(x: Tensor) -> Tensor:
    try:
        return x * torch.sigmoid(x)
    except RuntimeError:  # e.g. CUDA OOM: retry on CPU, then move the result back
        device = x.device
        x = x.cpu().pin_memory()
        return (x * torch.sigmoid(x)).to(device=device)
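
# AttnBlock: single-head self-attention over spatial positions only; frames are
# folded into the batch, so no information flows across time here.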
class AttnBlock(nn.Module):
    def __init__(self, in_channels, norm_type='group', cnn_param=None):
        super().__init__()
        self.in_channels = in_channels
        self.norm = Normalize(in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        self.q = Conv(in_channels, in_channels, kernel_size=1)
        self.k = Conv(in_channels, in_channels, kernel_size=1)
        self.v = Conv(in_channels, in_channels, kernel_size=1)
        self.proj_out = Conv(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        B, _, T, _, _ = h_.shape
        h_ = self.norm(h_)
        h_ = rearrange(h_, "B C T H W -> (B T) C H W")  # spatial attention only
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
        return rearrange(h_, "(b t) 1 (h w) c -> b c t h w", h=h, w=w, c=c, b=B, t=T)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))
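
# ResnetBlock: two norm-swish-conv stages with a 1x1 shortcut when the channel
# count changes. res_conv_2d forces the first ("half") or both ("full") convs
# to stay 2D even inside the 3D model.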
class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, norm_type='group', cnn_param=None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.norm1 = Normalize(in_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        if cnn_param["res_conv_2d"] in ["half", "full"]:
            self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
        self.norm2 = Normalize(out_channels, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        if cnn_param["res_conv_2d"] in ["full"]:
            self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
        if self.in_channels != self.out_channels:
            self.nin_shortcut = Conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + h
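
# Downsample: stride-2 conv with manual right/bottom zero-padding, since torch
# convs do not support asymmetric padding. The 3D variant never pads time.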
class Downsample(nn.Module):
    def __init__(self, in_channels, cnn_type="2d", spatial_down=False, temporal_down=False):
        super().__init__()
        assert spatial_down
        if cnn_type == "2d":
            self.pad = (0, 1, 0, 1)
        if cnn_type == "3d":
            self.pad = (0, 1, 0, 1, 0, 0)  # pad to the right for the h- and w-axes; no padding for the t-axis
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=2, padding=0, cnn_type=cnn_type, temporal_down=temporal_down)

    def forward(self, x: Tensor):
        x = nn.functional.pad(x, self.pad, mode="constant", value=0)
        x = self.conv(x)
        return x
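
# Upsample: nearest-neighbor interpolation followed by a 3x3 conv, or conv +
# PixelShuffle when use_pxsl is set; temporal_up doubles T in addition to H and W.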
class Upsample(nn.Module):
    def __init__(self, in_channels, cnn_type="2d", spatial_up=False, temporal_up=False, use_pxsl=False):
        super().__init__()
        if cnn_type == "2d":
            self.scale_factor = 2
            self.causal_offset = 0
        else:
            assert spatial_up
            if temporal_up:
                self.scale_factor = (2, 2, 2)
                self.causal_offset = -1
            else:
                self.scale_factor = (1, 2, 2)
                self.causal_offset = 0
        self.use_pxsl = use_pxsl
        if self.use_pxsl:
            self.conv = Conv(in_channels, in_channels * 4, kernel_size=3, stride=1, padding=1, cnn_type=cnn_type, causal_offset=self.causal_offset)
            self.pxsl = nn.PixelShuffle(2)
        else:
            self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_type, causal_offset=self.causal_offset)

    def forward(self, x: Tensor):
        if self.use_pxsl:
            x = self.conv(x)
            x = self.pxsl(x)
        else:
            try:
                x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
            except RuntimeError:  # e.g. CUDA OOM on a large tensor
                # shard across channel
                _xs = []
                for i in range(x.shape[1]):
                    _x = F.interpolate(x[:, i:i + 1, ...], scale_factor=self.scale_factor, mode="nearest")
                    _xs.append(_x)
                x = torch.cat(_xs, dim=1)
            x = self.conv(x)
        return x
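
# Encoder: conv_in, then num_resolutions stages of ResnetBlocks with
# downsampling, then middle blocks and conv_out. patch_size / temporal_patch_size
# set how many stages downsample each axis; the defaults 8 / 4 give the
# 4x8x8 (T x H x W) reduction referenced below.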
class Encoder(nn.Module):
    def __init__(
        self,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        in_channels=3,
        patch_size=8,
        temporal_patch_size=4,
        norm_type='group',
        cnn_param=None,
        use_checkpoint=False,
        use_vae=True,
    ):
        super().__init__()
        self.max_down = np.log2(patch_size)
        self.temporal_max_down = np.log2(temporal_patch_size)
        self.temporal_down_offset = self.max_down - self.temporal_max_down
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels
        self.cnn_param = cnn_param
        self.use_checkpoint = use_checkpoint
        # downsampling
        # self.conv_in = Conv(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
        # cnn_param["cnn_type"] is "2d" for images, "3d" for videos
        if cnn_param["conv_in_out_2d"] == "yes":  # "yes" for video
            self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type, cnn_param=cnn_param))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            # downsample: stride=1, stride=2, stride=2 for the 4x8x8 Video VAE
            spatial_down = i_level < self.max_down
            temporal_down = self.temporal_down_offset <= i_level < self.max_down
            if spatial_down or temporal_down:
                down.downsample = Downsample(block_in, cnn_type=cnn_param["cnn_type"], spatial_down=spatial_down, temporal_down=temporal_down)
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
        if cnn_param["cnn_attention"] == "yes":
            self.mid.attn_1 = AttnBlock(block_in, norm_type, cnn_param=cnn_param)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
        # end
        self.norm_out = Normalize(block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        if cnn_param["conv_inner_2d"] == "yes":
            self.conv_out = Conv(block_in, (int(use_vae) + 1) * z_channels, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv_out = Conv(block_in, (int(use_vae) + 1) * z_channels, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])

    def forward(self, x, return_hidden=False):
        if not self.use_checkpoint:
            return self._forward(x, return_hidden=return_hidden)
        else:
            return checkpoint.checkpoint(self._forward, x, return_hidden, use_reentrant=False)

    def _forward(self, x: Tensor, return_hidden=False) -> Tensor:
        # downsampling
        h0 = self.conv_in(x)
        hs = [h0]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if hasattr(self.down[i_level], "downsample"):
                hs.append(self.down[i_level].downsample(hs[-1]))
        # middle
        h = hs[-1]
        hs_mid = [h]
        h = self.mid.block_1(h)
        if self.cnn_param["cnn_attention"] == "yes":
            h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        hs_mid.append(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        if return_hidden:
            return h, hs, hs_mid
        else:
            return h
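
# Decoder: mirror of the Encoder with num_res_blocks + 1 blocks per stage and
# upsampling offset by one level relative to the encoder (see the flux
# autoencoder link below).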
class Decoder(nn.Module):
    def __init__(
        self,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        out_ch=3,
        patch_size=8,
        temporal_patch_size=4,
        norm_type="group",
        cnn_param=None,
        use_checkpoint=False,
        use_freq_dec=False,  # use frequency features for decoder
        use_pxsf=False,
    ):
        super().__init__()
        self.max_up = np.log2(patch_size)
        self.temporal_max_up = np.log2(temporal_patch_size)
        self.temporal_up_offset = self.max_up - self.temporal_max_up
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.ffactor = 2 ** (self.num_resolutions - 1)
        self.cnn_param = cnn_param
        self.use_checkpoint = use_checkpoint
        self.use_freq_dec = use_freq_dec
        self.use_pxsf = use_pxsf
        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        # z to block_in
        if cnn_param["conv_inner_2d"] == "yes":
            self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
        if cnn_param["cnn_attention"] == "yes":
            self.mid.attn_1 = AttnBlock(block_in, norm_type=norm_type, cnn_param=cnn_param)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type, cnn_param=cnn_param)
        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type, cnn_param=cnn_param))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            # upsample: stride=1, stride=2, stride=2 for the 4x8x8 Video VAE, offset by 1 level compared with the encoder
            # https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/modules/autoencoder.py#L228
            spatial_up = 1 <= i_level <= self.max_up
            temporal_up = spatial_up and i_level >= self.temporal_up_offset + 1
            if spatial_up or temporal_up:
                up.upsample = Upsample(block_in, cnn_type=cnn_param["cnn_type"], spatial_up=spatial_up, temporal_up=temporal_up, use_pxsl=self.use_pxsf)
            self.up.insert(0, up)  # prepend to get consistent order
        # end
        self.norm_out = Normalize(block_in, norm_type, norm_axis=cnn_param["cnn_norm_axis"])
        if cnn_param["conv_in_out_2d"] == "yes":
            self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type="2d")
        else:
            self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1, cnn_type=cnn_param["cnn_type"])

    def forward(self, z):
        if not self.use_checkpoint:
            return self._forward(z)
        else:
            return checkpoint.checkpoint(self._forward, z, use_reentrant=False)

    def _forward(self, z: Tensor) -> Tensor:
        # z to block_in
        h = self.conv_in(z)
        # middle
        h = self.mid.block_1(h)
        if self.cnn_param["cnn_attention"] == "yes":
            h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if hasattr(self.up[i_level], "upsample"):
                h = self.up[i_level].upsample(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
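
# AutoEncoder: the full tokenizer. With use_vae=True the encoder emits
# 2 * z_channels channels; otherwise the latent is quantized by MultiScaleBSQ
# before decoding.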
class AutoEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        cnn_param = dict(
            cnn_type=args.cnn_type,
            conv_in_out_2d=args.conv_in_out_2d,
            res_conv_2d=args.res_conv_2d,
            cnn_attention=args.cnn_attention,
            cnn_norm_axis=args.cnn_norm_axis,
            conv_inner_2d=args.conv_inner_2d,
        )
        self.encoder = Encoder(
            ch=args.base_ch,
            ch_mult=args.encoder_ch_mult,
            num_res_blocks=args.num_res_blocks,
            z_channels=args.codebook_dim,
            patch_size=args.patch_size,
            temporal_patch_size=args.temporal_patch_size,
            cnn_param=cnn_param,
            use_checkpoint=args.use_checkpoint,
            use_vae=args.use_vae,
        )
        self.decoder = Decoder(
            ch=args.base_ch,
            ch_mult=args.decoder_ch_mult,
            num_res_blocks=args.num_res_blocks,
            z_channels=args.codebook_dim,
            patch_size=args.patch_size,
            temporal_patch_size=args.temporal_patch_size,
            cnn_param=cnn_param,
            use_checkpoint=args.use_checkpoint,
            use_freq_dec=args.use_freq_dec,
            use_pxsf=args.use_pxsf,  # PixelShuffle for upsampling
        )
        self.z_drop = nn.Dropout(args.z_drop)
        self.scale_factor = 0.3611
        self.shift_factor = 0.1159
        self.codebook_dim = self.embed_dim = args.codebook_dim
        self.gan_feat_weight = args.gan_feat_weight
        self.video_perceptual_weight = args.video_perceptual_weight
        self.recon_loss_type = args.recon_loss_type
        self.l1_weight = args.l1_weight
        self.use_vae = args.use_vae
        self.kl_weight = args.kl_weight
        self.lfq_weight = args.lfq_weight
        self.image_gan_weight = args.image_gan_weight  # image GAN loss weight
        self.video_gan_weight = args.video_gan_weight  # video GAN loss weight
        self.perceptual_weight = args.perceptual_weight
        self.flux_weight = args.flux_weight
        self.cycle_weight = args.cycle_weight
        self.cycle_feat_weight = args.cycle_feat_weight
        self.cycle_gan_weight = args.cycle_gan_weight
        self.flux_image_encoder = None
        if not args.use_vae:
            if args.quantizer_type == 'MultiScaleBSQ':
                self.quantizer = MultiScaleBSQ(
                    dim=args.codebook_dim,  # input feature dimension, defaults to log2(codebook_size) if not defined
                    codebook_size=args.codebook_size,  # codebook size, must be a power of 2
                    entropy_loss_weight=args.entropy_loss_weight,  # how much weight to place on entropy loss
                    diversity_gamma=args.diversity_gamma,  # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
                    preserve_norm=args.preserve_norm,  # preserve norm of the input for BSQ
                    ln_before_quant=args.ln_before_quant,  # use layer norm before quantization
                    ln_init_by_sqrt=args.ln_init_by_sqrt,  # layer norm init value 1/sqrt(d)
                    commitment_loss_weight=args.commitment_loss_weight,  # loss weight of commitment loss
                    new_quant=args.new_quant,
                    use_decay_factor=args.use_decay_factor,
                    mask_out=args.mask_out,
                    use_stochastic_depth=args.use_stochastic_depth,
                    drop_rate=args.drop_rate,
                    schedule_mode=args.schedule_mode,
                    keep_first_quant=args.keep_first_quant,
                    keep_last_quant=args.keep_last_quant,
                    remove_residual_detach=args.remove_residual_detach,
                    use_out_phi=args.use_out_phi,
                    use_out_phi_res=args.use_out_phi_res,
                    random_flip=args.random_flip,
                    flip_prob=args.flip_prob,
                    flip_mode=args.flip_mode,
                    max_flip_lvl=args.max_flip_lvl,
                    random_flip_1lvl=args.random_flip_1lvl,
                    flip_lvl_idx=args.flip_lvl_idx,
                    drop_when_test=args.drop_when_test,
                    drop_lvl_idx=args.drop_lvl_idx,
                    drop_lvl_num=args.drop_lvl_num,
                )
                self.quantize = self.quantizer
                self.vocab_size = args.codebook_size
            else:
                raise NotImplementedError(f"{args.quantizer_type} not supported")
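
    # forward: encode (optionally under bf16 autocast) -> multi-scale quantize ->
    # decode. vq_output["commitment_loss"] folds the entropy penalty into the
    # commitment loss, scaled by lfq_weight.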
    def forward(self, x):
        is_image = x.ndim == 4
        if not is_image:
            B, C, T, H, W = x.shape
        else:
            B, C, H, W = x.shape
            T = 1
        enc_dtype = ptdtype[self.args.encoder_dtype]
        with torch.amp.autocast("cuda", dtype=enc_dtype):
            h, hs, hs_mid = self.encoder(x, return_hidden=True)  # B C H W or B C T H W
        hs = [_h.detach() for _h in hs]
        hs_mid = [_h.detach() for _h in hs_mid]
        h = h.to(dtype=torch.float32)
        # Multiscale LFQ
        z, all_indices, all_loss = self.quantizer(h)
        x_recon = self.decoder(z)
        vq_output = {
            "commitment_loss": torch.mean(all_loss) * self.lfq_weight,  # here commitment loss is the sum of commitment loss and entropy penalty
            "encodings": all_indices,
        }
        return x_recon, vq_output
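
    # Inference helpers: encode() returns per-scale indices from the quantizer;
    # decode() clamps reconstructions to [-1, 1]; decode_from_indices() rebuilds
    # the latent by summing each scale's codes, upsampled to the last entry of
    # scale_schedule.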
    def encode_for_raw_features(self, x, scale_schedule, return_residual_norm_per_scale=False):
        is_image = x.ndim == 4
        if not is_image:
            B, C, T, H, W = x.shape
        else:
            B, C, H, W = x.shape
            T = 1
        enc_dtype = ptdtype[self.args.encoder_dtype]
        with torch.amp.autocast("cuda", dtype=enc_dtype):
            h, hs, hs_mid = self.encoder(x, return_hidden=True)  # B C H W or B C T H W
        hs = [_h.detach() for _h in hs]
        hs_mid = [_h.detach() for _h in hs_mid]
        h = h.to(dtype=torch.float32)
        return h, hs, hs_mid

    def encode(self, x, scale_schedule, return_residual_norm_per_scale=False):
        h, hs, hs_mid = self.encode_for_raw_features(x, scale_schedule, return_residual_norm_per_scale)
        # Multiscale LFQ
        z, all_indices, all_bit_indices, residual_norm_per_scale, all_loss, var_input = self.quantizer(h, scale_schedule=scale_schedule, return_residual_norm_per_scale=return_residual_norm_per_scale)
        return h, z, all_indices, all_bit_indices, residual_norm_per_scale, var_input

    def decode(self, z):
        x_recon = self.decoder(z)
        x_recon = torch.clamp(x_recon, min=-1, max=1)
        return x_recon

    def decode_from_indices(self, all_indices, scale_schedule, label_type):
        summed_codes = 0
        for idx_Bl in all_indices:
            codes = self.quantizer.lfq.indices_to_codes(idx_Bl, label_type)
            summed_codes += F.interpolate(codes, size=scale_schedule[-1], mode=self.quantizer.z_interplote_up)
        assert summed_codes.shape[-3] == 1
        x_recon = self.decoder(summed_codes.squeeze(-3))
        x_recon = torch.clamp(x_recon, min=-1, max=1)
        return summed_codes, x_recon
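
    # Lightning-style CLI hook. Note it only registers the flux/cycle/z_drop
    # options; the architecture fields read in __init__ (base_ch, encoder_ch_mult,
    # codebook_dim, ...) are expected to be added to the parser elsewhere.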
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--flux_weight", type=float, default=0)
        parser.add_argument("--cycle_weight", type=float, default=0)
        parser.add_argument("--cycle_feat_weight", type=float, default=0)
        parser.add_argument("--cycle_gan_weight", type=float, default=0)
        parser.add_argument("--cycle_loop", type=int, default=0)
        parser.add_argument("--z_drop", type=float, default=0.)
        return parser
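
# Usage sketch (not part of the original file; assumes `args` also carries the
# architecture fields read in __init__, e.g. base_ch, encoder_ch_mult,
# decoder_ch_mult, num_res_blocks, codebook_dim, patch_size, ...):
#
#     parser = argparse.ArgumentParser()
#     parser = AutoEncoder.add_model_specific_args(parser)
#     args = parser.parse_args()  # plus the architecture fields above
#     model = AutoEncoder(args)
#     video = torch.randn(1, 3, 16, 256, 256)  # B C T H W
#     recon, vq_output = model(video)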
