# What is missing from this implementation
# 1. Global context in res block
# 2. Cross attention of conditional information in resnet block
from functools import partial
import tops
from tops.config import instantiate
import warnings
from typing import Iterable, List, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch import einsum
from einops import rearrange
from dp2 import infer, utils
from .base import BaseGenerator
from sg3_torch_utils.ops import bias_act
from dp2.layers import Sequential
import torch.nn.functional as F
from torchvision.transforms.functional import resize, InterpolationMode
from sg3_torch_utils.ops import conv2d_resample, fma, upfirdn2d


class Upfirdn2d(torch.nn.Module):
    def __init__(self, down=1, up=1, fix_gain=True):
        super().__init__()
        self.register_buffer("resample_filter", upfirdn2d.setup_filter([1, 3, 3, 1]))
        fw, fh = upfirdn2d._get_filter_size(self.resample_filter)
        px0, px1, py0, py1 = upfirdn2d._parse_padding(0)
        self.down = down
        self.up = up
        if up > 1:
            px0 += (fw + up - 1) // 2
            px1 += (fw - up) // 2
            py0 += (fh + up - 1) // 2
            py1 += (fh - up) // 2
        if down > 1:
            px0 += (fw - down + 1) // 2
            px1 += (fw - down) // 2
            py0 += (fh - down + 1) // 2
            py1 += (fh - down) // 2
        self.padding = [px0, px1, py0, py1]
        self.gain = up**2 if fix_gain else 1

    def forward(self, x, *args):
        if isinstance(x, dict):
            x = {k: v for k, v in x.items()}
            x["x"] = upfirdn2d.upfirdn2d(
                x["x"], self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain)
            return x
        x = upfirdn2d.upfirdn2d(x, self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain)
        if len(args) == 0:
            return x
        return (x, *args)


@torch.no_grad()
def spatial_embed_keypoints(keypoints: torch.Tensor, x):
    tops.assert_shape(keypoints, (None, None, 3))
    B, N_K, _ = keypoints.shape
    H, W = x.shape[-2:]
    keypoint_spatial = torch.zeros(keypoints.shape[0], N_K, H, W, device=keypoints.device, dtype=torch.float32)
    x, y, visible = keypoints.chunk(3, dim=2)
    x = (x * W).round().long().clamp(0, W - 1)
    y = (y * H).round().long().clamp(0, H - 1)
    kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1)
    pos = (kp_idx * (H * W) + y * W + x + 1)  # Offset all by 1 to index invisible keypoints as 0
    pos = (pos * visible.round().long()).squeeze(dim=-1)
    keypoint_spatial = torch.zeros(keypoints.shape[0], N_K * H * W + 1, device=keypoints.device, dtype=torch.float32)
    keypoint_spatial.scatter_(1, pos, 1)
    keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W)
    return keypoint_spatial
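
# Example (shapes only, assuming COCO-style keypoints): `keypoints` holds
# (x, y, visible) triplets normalized to [0, 1]. For a feature map x of shape
# [B, C, 64, 64] and keypoints of shape [B, 17, 3],
#
#   heatmap = spatial_embed_keypoints(keypoints, x)
#
# returns a [B, 17, 64, 64] map with a single 1 at each visible keypoint
# location and all zeros for invisible keypoints.
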
def modulated_conv2d(
    x,                      # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,                 # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,                 # Modulation coefficients of shape [batch_size, in_channels].
    noise=None,             # Optional noise tensor to add to the output activations.
    up=1,                   # Integer upsampling factor.
    down=1,                 # Integer downsampling factor.
    padding=0,              # Padding with respect to the upsampled image.
    resample_filter=None,   # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
    demodulate=True,        # Apply weight demodulation?
    flip_weight=True,       # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv=True,     # Perform modulation, convolution, and demodulation as a single fused operation?
):
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    tops.assert_shape(weight, [out_channels, in_channels, kh, kw])  # [OIkk]
    tops.assert_shape(x, [batch_size, in_channels, None, None])  # [NIHW]
    tops.assert_shape(styles, [batch_size, in_channels])  # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1, 2, 3], keepdim=True))  # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True)  # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)  # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)  # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1)  # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight, f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
        if demodulate and noise is not None:
            x = fma.fma(x, dcoefs.reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    with tops.suppress_tracer_warnings():  # this value will be treated as a constant
        batch_size = int(batch_size)

    # Execute as one fused op using grouped convolution.
    tops.assert_shape(x, [batch_size, in_channels, None, None])
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w, f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x
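
# Example call (shapes only, assuming a 3x3 convolution mapping 64 -> 128 channels):
#   x:      [B, 64, H, W]
#   weight: [128, 64, 3, 3]
#   styles: [B, 64]
#   y = modulated_conv2d(x, weight, styles, padding=1)  # -> [B, 128, H, W]
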
class Identity(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x


class LayerNorm(nn.Module):
    def __init__(self, dim, stable=False):
        super().__init__()
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim=-1, keepdim=True).detach()
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=-1, keepdim=True)
        return (x - mean) * (var + eps).rsqrt() * self.g


class FullyConnectedLayer(torch.nn.Module):
    def __init__(
        self,
        in_features,            # Number of input features.
        out_features,           # Number of output features.
        bias=True,              # Apply additive bias before the activation function?
        activation='linear',    # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier=1,        # Learning rate multiplier.
        bias_init=0,            # Initial value for the additive bias.
    ):
        super().__init__()
        self.repr = dict(
            in_features=in_features, out_features=out_features, bias=bias,
            activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init)
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x):
        w = self.weight * self.weight_gain
        b = self.bias
        if b is not None:
            if self.bias_gain != 1:
                b = b * self.bias_gain
        x = F.linear(x, w)
        x = bias_act.bias_act(x, b, act=self.activation)
        return x

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
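
# Note (equalized learning rate, as in StyleGAN2): weights are stored as
# N(0, 1) / lr_multiplier and rescaled by lr_multiplier / sqrt(in_features) on
# every forward pass, so the effective initialization is He-like while gradient
# magnitudes are scaled by lr_multiplier.
#
# Example (assumed usage):
#   fc = FullyConnectedLayer(512, 256, activation="lrelu")
#   y = fc(torch.randn(8, 512))  # -> [8, 256]
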
def checkpoint_fn(fn, *args, **kwargs):
    warnings.simplefilter("ignore")
    return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)


class Conv2d(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        activation='lrelu',
        conv_clamp=None,    # Clamp the output of convolution layers to +-X, None = disable clamping.
        bias=True,
        norm=None,
        lr_multiplier=1,
        bias_init=0,
        w_dim=None,
        gradient_checkpoint_norm=False,
        gain=1,
    ):
        super().__init__()
        self.fused_modconv = False
        if norm == torch.nn.InstanceNorm2d:
            self.norm = torch.nn.InstanceNorm2d(None)
        elif isinstance(norm, torch.nn.Module):
            self.norm = norm
        elif norm == "fused_modconv":
            self.fused_modconv = True
        elif norm:
            self.norm = torch.nn.InstanceNorm2d(None)
        elif norm is not None:
            raise ValueError(f"norm not supported: {norm}")
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.padding = kernel_size // 2
        self.repr = dict(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
            activation=activation, conv_clamp=conv_clamp, bias=bias, fused_modconv=self.fused_modconv
        )
        self.act_gain = bias_act.activation_funcs[activation].def_gain * gain
        self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2))
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]) + bias_init) if bias else None
        self.bias_gain = lr_multiplier
        if w_dim is not None:
            if self.fused_modconv:
                self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
            else:
                self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
                self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0)
        self.gradient_checkpoint_norm = gradient_checkpoint_norm

    def forward(self, x, w=None, gain=1, **kwargs):
        if self.fused_modconv:
            styles = self.affine(w)
            with torch.cuda.amp.autocast(enabled=False):
                x = modulated_conv2d(
                    x=x.half(), weight=self.weight.half(), styles=styles.half(), noise=None,
                    padding=self.padding, flip_weight=True, fused_modconv=False).to(x.dtype)
        else:
            if hasattr(self, "affine"):
                gamma = self.affine(w).view(-1, self.in_channels, 1, 1)
                beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1)
                x = fma.fma(x, gamma, beta)
            w = self.weight * self.weight_gain
            x = F.conv2d(input=x, weight=w, padding=self.padding)
        if hasattr(self, "norm"):
            if self.gradient_checkpoint_norm:
                x = checkpoint_fn(self.norm, x)
            else:
                x = self.norm(x)
        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        b = self.bias * self.bias_gain if self.bias is not None else None
        x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
        return x

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={item}" for key, item in self.repr.items()])


class CrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        context_dim,
        dim_head=64,
        heads=8,
        norm_context=False,
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.InstanceNorm1d(dim)
        self.norm_context = nn.InstanceNorm1d(None) if norm_context else Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias=False),
            nn.InstanceNorm1d(None)
        )

    def forward(self, x, w):
        x = self.norm(x)
        w = self.norm_context(w)
        q, k, v = (self.to_q(x), *self.to_kv(w).chunk(2, dim=-1))
        q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.heads)
        q = q * self.scale

        # similarities
        sim = einsum('b h i d, b h j d -> b h i j', q, k)
        attn = sim.softmax(dim=-1, dtype=torch.float32)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class SG2ResidualBlock(torch.nn.Module):
    def __init__(
        self,
        in_channels,                    # Number of input channels, 0 = first block.
        out_channels,                   # Number of output channels.
        conv_clamp=None,                # Clamp the output of convolution layers to +-X, None = disable clamping.
        skip_gain=np.sqrt(.5),
        cross_attention: bool = False,
        cross_attention_len: int = None,
        use_adain: bool = True,
        **layer_kwargs,                 # Arguments for conv layer.
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        w_dim = layer_kwargs.pop("w_dim") if "w_dim" in layer_kwargs else None
        if use_adain:
            layer_kwargs["w_dim"] = w_dim
        self.conv0 = Conv2d(in_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs)
        self.conv1 = Conv2d(out_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs, gain=skip_gain)
        self.skip = Conv2d(in_channels, out_channels, kernel_size=1, bias=False, gain=skip_gain)

        if cross_attention and w_dim is not None:
            self.cross_attention_len = cross_attention_len
            self.cross_attn = CrossAttention(
                dim=out_channels, context_dim=w_dim // self.cross_attention_len, gain=skip_gain)

    def forward(self, x, w=None, **layer_kwargs):
        y = self.skip(x)
        x = self.conv0(x, w, **layer_kwargs)
        x = self.conv1(x, w, **layer_kwargs)
        if hasattr(self, "cross_attn"):
            h = x.shape[-2]
            x = rearrange(x, "b c h w -> b (h w) c")
            w = rearrange(w, "b (n c) -> b n c", n=self.cross_attention_len)
            x = self.cross_attn(x, w=w) + x
            x = rearrange(x, "b (h w) c -> b c h w", h=h)
        return y + x


def default(val, d):
    if val is not None:
        return val
    return d() if callable(d) else d


def cast_tuple(val, length=None):
    if isinstance(val, Iterable) and not isinstance(val, str):
        val = tuple(val)
    output = val if isinstance(val, tuple) else ((val,) * default(length, 1))
    if length is not None:
        assert len(output) == length, (output, length)
    return output
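
# Behaviour of the helpers above (illustrative):
#   default(None, 3)       -> 3
#   default(5, 3)          -> 5
#   cast_tuple(True, 3)    -> (True, True, True)
#   cast_tuple([1, 2], 2)  -> (1, 2)
#   cast_tuple((1, 2), 3)  -> AssertionError (length mismatch)
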
class Attention(nn.Module):
    # This is a version of Multi-Query Attention
    # ("Fast Transformer Decoding: One Write-Head is All You Need").
    # Ablated in: https://arxiv.org/pdf/2203.07814.pdf
    # and https://arxiv.org/pdf/2204.02311.pdf
    def __init__(self, dim, norm, attn_fix_gain, gradient_checkpoint, dim_head=64, heads=8,
                 cosine_sim_attn=False, fix_attention_again=False, gain=None):
        super().__init__()
        self.scale = dim_head**-0.5 if not cosine_sim_attn else 1.0
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1
        self.gradient_checkpoint = gradient_checkpoint
        self.heads = heads
        self.dim = dim
        self.fix_attention_again = fix_attention_again
        inner_dim = dim_head * heads
        if norm == "LN":
            self.norm = LayerNorm(dim)
        elif norm == "IN":
            self.norm = nn.InstanceNorm1d(dim)
        elif norm is None:
            self.norm = nn.Identity()
        else:
            raise ValueError(f"Norm not supported: {norm}")

        self.to_q = FullyConnectedLayer(dim, inner_dim, bias=False)
        self.to_kv = FullyConnectedLayer(dim, dim_head * 2, bias=False)

        self.to_out = nn.Sequential(
            FullyConnectedLayer(inner_dim, dim, bias=False),
            LayerNorm(dim) if norm == "LN" else nn.InstanceNorm1d(dim)
        )
        if fix_attention_again:
            assert gain is not None
            self.gain = gain
        else:
            self.gain = np.sqrt(.5) if attn_fix_gain else 1

    def run_function(self, x, attn_bias):
        b, c, h, w = x.shape
        x = rearrange(x, "b c h w -> b (h w) c")
        in_ = x
        b, n, device = *x.shape[:2], x.device
        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))

        q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
        q = q * self.scale

        # calculate query / key similarities
        sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale

        if attn_bias is not None:
            attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)")
            sim = sim + attn_bias
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        if self.fix_attention_again:
            out = self.to_out(out) * self.gain + in_
        else:
            out = (self.to_out(out) + in_) * self.gain
        out = rearrange(out, "b (h w) c -> b c h w", h=h)
        return out

    def forward(self, x, *args, attn_bias=None, **kwargs):
        if self.gradient_checkpoint:
            return checkpoint_fn(self.run_function, x, attn_bias)
        return self.run_function(x, attn_bias)

    def get_attention(self, x, attn_bias=None):
        b, c, h, w = x.shape
        x = rearrange(x, "b c h w -> b (h w) c")
        in_ = x
        b, n, device = *x.shape[:2], x.device
        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
        q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
        q = q * self.scale
        # calculate query / key similarities
        sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale
        if attn_bias is not None:
            attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)")
            sim = sim + attn_bias
        attn = sim.softmax(dim=-1)
        return attn, None


class BiasedAttention(Attention):
    def __init__(self, *args, head_wise: bool = True, **kwargs):
        super().__init__(*args, **kwargs)
        out_ch = self.heads if head_wise else 1
        self.conv = Conv2d(self.dim + 2, out_ch, activation="linear", kernel_size=3, bias_init=0)
        nn.init.zeros_(self.conv.weight.data)

    def forward(self, x, mask):
        mask = resize(mask, size=x.shape[-2:])
        bias = self.conv(torch.cat((x, mask, 1 - mask), dim=1))
        return super().forward(x=x, attn_bias=bias)

    def get_attention(self, x, mask):
        mask = resize(mask, size=x.shape[-2:])
        bias = self.conv(torch.cat((x, mask, 1 - mask), dim=1))
        return super().get_attention(x, bias)[0], bias

class UNet(BaseGenerator):
    def __init__(
        self,
        im_channels: int,
        dim: int,
        dim_mults: tuple,
        num_resnet_blocks,                  # Number of resnet blocks per resolution
        n_middle_blocks: int,
        z_channels: int,
        conv_clamp: int,
        layer_attn,
        w_dim: int,
        norm_enc: bool,
        norm_dec: str,
        stylenet: nn.Module,
        enc_style: bool,                    # Toggle style injection in encoder
        use_maskrcnn_mask: bool,
        skip_all_unets: bool,
        fix_resize: bool,
        comodulate: bool,
        comod_net: nn.Module,
        lr_comod: float,
        dec_style: bool,
        input_keypoints: bool,
        n_keypoints: int,
        input_keypoint_indices: Tuple[int],
        use_adain: bool,
        cross_attention: bool,
        cross_attention_len: int,
        gradient_checkpoint_norm: bool,
        attn_cls: partial,
        mask_out_train: bool,
        fix_gain_again: bool,
    ) -> None:
        super().__init__(z_channels)
        self.enc_style = enc_style
        self.n_keypoints = n_keypoints
        self.input_keypoint_indices = list(input_keypoint_indices)
        self.input_keypoints = input_keypoints
        self.mask_out_train = mask_out_train
        n_layers = len(dim_mults)
        self.n_layers = n_layers
        layer_attn = cast_tuple(layer_attn, n_layers)
        num_resnet_blocks = cast_tuple(num_resnet_blocks, n_layers)
        self._cnum = dim
        self._image_channels = im_channels
        self._z_channels = z_channels
        encoder_layers = []
        condition_ch = im_channels
        self.from_rgb = Conv2d(
            condition_ch + 2 + 2 * int(use_maskrcnn_mask) + self.input_keypoints * len(input_keypoint_indices),
            dim, 7)
        self.use_maskrcnn_mask = use_maskrcnn_mask
        self.skip_all_unets = skip_all_unets
        dims = [dim * m for m in dim_mults]

        enc_blk = partial(
            SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_enc,
            use_adain=use_adain and self.enc_style, w_dim=w_dim,
            cross_attention=cross_attention, cross_attention_len=cross_attention_len,
            gradient_checkpoint_norm=gradient_checkpoint_norm
        )
        dec_blk = partial(
            SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_dec,
            use_adain=use_adain and dec_style, w_dim=w_dim,
            cross_attention=cross_attention, cross_attention_len=cross_attention_len,
            gradient_checkpoint_norm=gradient_checkpoint_norm
        )
        # Currently up/down sampling is done by bilinear upsampling.
        # This can be simplified by replacing it with a strided upsampling layer...
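
        # Skip gains below follow a variance-preserving heuristic: when N branches
        # are summed (residual, unet skip connection, attention), each branch is
        # scaled by sqrt(1/N) so the variance of the sum stays roughly constant.
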
        self.encoder_attns = nn.ModuleList()
        for lidx in range(n_layers):
            gain = np.sqrt(1/3) if layer_attn[lidx] and fix_gain_again else np.sqrt(.5)
            dim_in = dims[lidx]
            dim_out = dims[min(lidx + 1, n_layers - 1)]
            res_blocks = nn.ModuleList()
            for i in range(num_resnet_blocks[lidx]):
                is_last = num_resnet_blocks[lidx] - 1 == i
                cur_dim = dim_out if is_last else dim_in
                block = enc_blk(dim_in, cur_dim, skip_gain=gain)
                res_blocks.append(block)
            if layer_attn[lidx]:
                self.encoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain))
            else:
                self.encoder_attns.append(Identity())
            encoder_layers.append(res_blocks)
        self.encoder = torch.nn.ModuleList(encoder_layers)

        # initialize decoder
        decoder_layers = []
        self.unet_layers = torch.nn.ModuleList()
        self.decoder_attns = torch.nn.ModuleList()
        for lidx in range(n_layers):
            dim_in = dims[min(-lidx, -1)]
            dim_out = dims[-1 - lidx]
            res_blocks = nn.ModuleList()
            unet_skips = nn.ModuleList()
            for i in range(num_resnet_blocks[-lidx - 1]):
                is_first = i == 0
                has_unet = is_first or skip_all_unets
                is_last = i == num_resnet_blocks[-lidx - 1] - 1
                cur_dim = dim_in if is_first else dim_out
                if has_unet and is_last and layer_attn[-lidx - 1] and fix_gain_again:  # x + residual + unet + layer attn
                    gain = np.sqrt(1/4)
                elif has_unet:  # x + residual + unet
                    gain = np.sqrt(1/3)
                elif layer_attn[-lidx - 1] and fix_gain_again:  # x + residual + attention
                    gain = np.sqrt(1/3)
                else:  # x + residual
                    gain = np.sqrt(1/2)  # Only residual block
                block = dec_blk(cur_dim, dim_out, skip_gain=gain)
                res_blocks.append(block)
                if has_unet:
                    unet_block = Conv2d(
                        cur_dim, cur_dim, kernel_size=1, conv_clamp=conv_clamp,
                        norm=nn.InstanceNorm2d(None),
                        gradient_checkpoint_norm=gradient_checkpoint_norm, gain=gain)
                    unet_skips.append(unet_block)
                else:
                    unet_skips.append(torch.nn.Identity())
            if layer_attn[-lidx - 1]:
                self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain))
            else:
                self.decoder_attns.append(Identity())
            decoder_layers.append(res_blocks)
            self.unet_layers.append(unet_skips)

        middle_blocks = []
        for i in range(n_middle_blocks):
            block = dec_blk(dims[-1], dims[-1])
            middle_blocks.append(block)
        if n_middle_blocks != 0:
            self.middle_blocks = Sequential(*middle_blocks)
        self.decoder = torch.nn.ModuleList(decoder_layers)
        self.to_rgb = Conv2d(dim, im_channels, 1, activation="linear", conv_clamp=conv_clamp)
        self.stylenet = stylenet
        self.downsample = Upfirdn2d(down=2, fix_gain=fix_resize)
        self.upsample = Upfirdn2d(up=2, fix_gain=fix_resize)
        self.comodulate = comodulate
        if comodulate:
            assert not self.enc_style
            self.to_y = nn.Sequential(
                Conv2d(dims[-1], dims[-1], lr_multiplier=lr_comod, gradient_checkpoint_norm=gradient_checkpoint_norm),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                FullyConnectedLayer(dims[-1], 512, activation="lrelu", lr_multiplier=lr_comod)
            )
            self.comod_net = comod_net

    def forward(self, condition, mask, maskrcnn_mask=None, z=None, w=None, update_emas=False, keypoints=None,
                return_decoder_features=False, **kwargs):
        if z is None:
            z = self.get_z(condition)
        if w is None:
            w = self.stylenet(z, update_emas=update_emas)
        if self.use_maskrcnn_mask:
            x = torch.cat((condition, mask, 1 - mask, maskrcnn_mask, 1 - maskrcnn_mask), dim=1)
        else:
            x = torch.cat((condition, mask, 1 - mask), dim=1)

        if self.input_keypoints:
            keypoints = keypoints[:, self.input_keypoint_indices]
            one_hot_pose = spatial_embed_keypoints(keypoints, x)
            x = torch.cat((x, one_hot_pose), dim=1)
        x = self.from_rgb(x)
        x, unet_features = self.forward_enc(x, mask, w)
        x, decoder_features = self.forward_dec(x, mask, w, unet_features)
        x = self.to_rgb(x)
        unmasked = x
        if self.mask_out_train:
            x = mask * condition + (1 - mask) * x
        out = dict(img=x, unmasked=unmasked)
        if return_decoder_features:
            out["decoder_features"] = decoder_features
        return out
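
    # Expected inputs to forward() (shapes assumed from the usage in this file):
    #   condition:     [B, im_channels, H, W] image with the region to generate masked out
    #   mask:          [B, 1, H, W], 1 = keep the conditioning pixel, 0 = generate
    #   maskrcnn_mask: [B, 1, H, W], only used when use_maskrcnn_mask is True
    #   keypoints:     [B, n_keypoints, 3] (x, y, visible), only used when input_keypoints is True
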
    def forward_enc(self, x, mask, w):
        unet_features = []
        for i, res_blocks in enumerate(self.encoder):
            is_last = i == len(self.encoder) - 1
            for block in res_blocks:
                x = block(x, w=w)
                unet_features.append(x)
            x = self.encoder_attns[i](x, mask=mask)
            if not is_last:
                x = self.downsample(x)
        if self.comodulate:
            y = self.to_y(x)
            y = torch.cat((w, y), dim=-1)
            w = self.comod_net(y)
        return x, unet_features

    def forward_dec(self, x, mask, w, unet_features):
        if hasattr(self, "middle_blocks"):
            x = self.middle_blocks(x, w=w)
        features = []
        unet_features = iter(reversed(unet_features))
        for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)):
            is_last = i == len(self.decoder) - 1
            for skip, block in zip(unet_skip, res_blocks):
                skip_x = next(unet_features)
                if not isinstance(skip, torch.nn.Identity):
                    skip_x = skip(skip_x)
                    x = x + skip_x
                x = block(x, w=w)
            x = self.decoder_attns[i](x, mask=mask)
            features.append(x)
            if not is_last:
                x = self.upsample(x)
        return x, features

    def get_w(self, z, update_emas):
        return self.stylenet(z, update_emas=update_emas)

    @torch.no_grad()
    def sample(self, truncation_value, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        w = self.get_w(self.get_z(kwargs["condition"]), False)
        w = self.stylenet.w_avg.to(w.dtype).lerp(w, truncation_value)
        return self.forward(**kwargs, w=w)

    def update_w(self, *args, **kwargs):
        self.style_net.update_w(*args, **kwargs)

    @property
    def style_net(self):
        return self.stylenet

    @torch.no_grad()
    def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        w = self.get_w(self.get_z(kwargs["condition"]), False)
        if w_indices is None:
            w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
        w_centers = self.style_net.w_centers[w_indices].to(w.device)
        w = w_centers.to(w.dtype).lerp(w, truncation_value)
        return self.forward(**kwargs, w=w)


def get_stem_unet_kwargs(cfg):
    if "stem_cfg" in cfg.generator:  # If the stem has another stem, recursively apply get_stem_unet_kwargs
        return get_stem_unet_kwargs(cfg.generator.stem_cfg)
    return dict(cfg.generator)


class GrowingUnet(BaseGenerator):
    def __init__(
        self,
        coarse_stem_cfg: str,   # This can be a coarse generator or None
        sr_cfg: str,            # Can be a previous progressive u-net, Unet or None
        residual: bool,
        new_dataset: bool,      # The "new dataset" creates condition first -> resizes
        **unet_kwargs):
        kwargs = dict()
        if coarse_stem_cfg is not None:
            coarse_stem_cfg = utils.load_config(coarse_stem_cfg)
            kwargs = get_stem_unet_kwargs(coarse_stem_cfg)
        if sr_cfg is not None:
            sr_cfg = utils.load_config(sr_cfg)
            sr_stem_unet_kwargs = get_stem_unet_kwargs(sr_cfg)
            kwargs.update(sr_stem_unet_kwargs)
        kwargs.update(unet_kwargs)
        kwargs["stylenet"] = None
        kwargs.pop("_target_")
        if "sr_cfg" in kwargs:  # Unet kwargs are inherited, do not pass this to the new u-net
            del kwargs["sr_cfg"]
        if "coarse_stem_cfg" in kwargs:
            del kwargs["coarse_stem_cfg"]
        super().__init__(z_channels=kwargs["z_channels"])
        if coarse_stem_cfg is not None:
            z_channels = coarse_stem_cfg.generator.z_channels
            super().__init__(z_channels)
            self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval()
            self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize)
            utils.set_requires_grad(self.coarse_stem, False)
        else:
            assert not residual
        if sr_cfg is not None:
            self.sr_stem = infer.build_trained_generator(sr_cfg, map_location="cpu").eval()
            del self.sr_stem.from_rgb
            del self.sr_stem.to_rgb
            if hasattr(self.sr_stem, "coarse_stem"):
                del self.sr_stem.coarse_stem
            if isinstance(self.sr_stem, UNet):
                del self.sr_stem.encoder[0][0]  # Delete first residual block
                del self.sr_stem.decoder[-1][-1]  # Delete last residual block
            else:
                assert isinstance(self.sr_stem, GrowingUnet)
                del self.sr_stem.unet.encoder[0][0]  # Delete first residual block
                del self.sr_stem.unet.decoder[-1][-1]  # Delete last residual block
            utils.set_requires_grad(self.sr_stem, False)

        args = kwargs.pop("_args_")
        if hasattr(self, "sr_stem"):
            # Growing the SR stem - Add a new layer to match sr
            n_layers = len(kwargs["dim_mults"])
            dim_mult = sr_stem_unet_kwargs["dim"] / (kwargs["dim"] * max(kwargs["dim_mults"]))
            kwargs["dim_mults"] = [*kwargs["dim_mults"], int(dim_mult)]
            kwargs["layer_attn"] = [*cast_tuple(kwargs["layer_attn"], n_layers), False]
            kwargs["num_resnet_blocks"] = [*cast_tuple(kwargs["num_resnet_blocks"], n_layers), 1]
        self.unet = UNet(
            *args,
            **kwargs
        )
        self.from_rgb = self.unet.from_rgb
        self.to_rgb = self.unet.to_rgb
        self.residual = residual
        self.new_dataset = new_dataset
        if residual:
            nn.init.zeros_(self.to_rgb.weight.data)
        del self.unet.from_rgb, self.unet.to_rgb

    def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, **kwargs):
        # Downsample for stem
        if z is None:
            z = self.get_z(img)
        if w is None:
            w = self.style_net(z)
        if hasattr(self, "coarse_stem"):
            with torch.no_grad():
                if self.new_dataset:
                    img_stem = utils.denormalize_img(img) * 255
                    condition_stem = img_stem * mask + (1 - mask) * 127
                    condition_stem = condition_stem.round()
                    condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True)
                    condition_stem = condition_stem / 255 * 2 - 1
                    mask_stem = (torch.nn.functional.adaptive_max_pool2d(
                        1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float()
                    maskrcnn_stem = (resize(
                        maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float()
                else:
                    mask_stem = (resize(mask, self.coarse_stem.imsize, antialias=True) > .99).float()
                    maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, antialias=True) > .5).float()
                    img_stem = utils.denormalize_img(img) * 255
                    img_stem = resize(img_stem, self.coarse_stem.imsize, antialias=True).round()
                    img_stem = img_stem / 255 * 2 - 1
                    condition_stem = img_stem * mask_stem
                stem_out = self.coarse_stem(
                    condition=condition_stem, mask=mask_stem, maskrcnn_mask=maskrcnn_stem,
                    w=w, keypoints=keypoints)
                x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True)
            condition = condition * mask + (1 - mask) * x_lr
        if self.unet.use_maskrcnn_mask:
            x = torch.cat((condition, mask, 1 - mask, maskrcnn_mask, 1 - maskrcnn_mask), dim=1)
        else:
            x = torch.cat((condition, mask, 1 - mask), dim=1)
        if self.unet.input_keypoints:
            keypoints = keypoints[:, self.unet.input_keypoint_indices]
            one_hot_pose = spatial_embed_keypoints(keypoints, x)
            x = torch.cat((x, one_hot_pose), dim=1)
        x = self.from_rgb(x)
        x, unet_features = self.forward_enc(x, mask, w)
        x = self.forward_dec(x, mask, w, unet_features)
        if self.residual:
            x = self.to_rgb(x) + condition
        else:
            x = self.to_rgb(x)
        return dict(
            img=condition * mask + (1 - mask) * x,
            unmasked=x,
            x_lowres=[condition]
        )

    def forward_enc(self, x, mask, w):
        x, unet_features = self.unet.forward_enc(x, mask, w)
        if hasattr(self, "sr_stem"):
            x, unet_features_stem = self.sr_stem.forward_enc(x, mask, w)
        else:
            unet_features_stem = None
        return x, [unet_features, unet_features_stem]

    def forward_dec(self, x, mask, w, unet_features):
        unet_features, unet_features_stem = unet_features
        if hasattr(self, "sr_stem"):
            x = self.sr_stem.forward_dec(x, mask, w, unet_features_stem)
        x, unet_features = self.unet.forward_dec(x, mask, w, unet_features)
        return x

    def get_z(self, *args, **kwargs):
        if hasattr(self, "coarse_stem"):
            return self.coarse_stem.get_z(*args, **kwargs)
        if hasattr(self, "sr_stem"):
            return self.sr_stem.get_z(*args, **kwargs)
        raise AttributeError()

    @property
    def style_net(self):
        if hasattr(self, "coarse_stem"):
            return self.coarse_stem.style_net
        if hasattr(self, "sr_stem"):
            return self.sr_stem.style_net
        raise AttributeError()

    def update_w(self, *args, **kwargs):
        self.style_net.update_w(*args, **kwargs)

    def get_w(self, z, update_emas):
        return self.style_net(z, update_emas=update_emas)

    @torch.no_grad()
    def sample(self, truncation_value, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        w = self.get_w(self.get_z(kwargs["condition"]), False)
        w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
        return self.forward(**kwargs, w=w)

    @torch.no_grad()
    def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        w = self.get_w(self.get_z(kwargs["condition"]), False)
        if w_indices is None:
            w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
        w_centers = self.style_net.w_centers[w_indices].to(w.device)
        w = w_centers.to(w.dtype).lerp(w, truncation_value)
        return self.forward(**kwargs, w=w)
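
# GrowingUnet wraps a frozen, pretrained stem (a coarse generator and/or a previous
# super-resolution U-Net) and trains a new U-Net around it; the stem's output is
# resized and pasted into the masked region of the condition before the new network
# runs. CascadedUnet below takes a different route: it keeps only the coarse stem and
# injects the stem's decoder features directly into the encoder of the new U-Net at
# matching resolutions (see encoder_unet_skips).
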

class CascadedUnet(BaseGenerator):
    def __init__(
        self,
        coarse_stem_cfg: str,   # This can be a coarse generator or None
        residual: bool,
        new_dataset: bool,      # The "new dataset" creates condition first -> resizes
        imsize: tuple,
        cascade: bool,
        **unet_kwargs):
        kwargs = dict()
        coarse_stem_cfg = utils.load_config(coarse_stem_cfg)
        kwargs = get_stem_unet_kwargs(coarse_stem_cfg)
        kwargs.update(unet_kwargs)
        super().__init__(z_channels=kwargs["z_channels"])
        self.input_keypoints = kwargs["input_keypoints"]
        self.input_keypoint_indices = kwargs["input_keypoint_indices"]
        self.use_maskrcnn_mask = kwargs["use_maskrcnn_mask"]
        self.imsize = imsize
        self.residual = residual
        self.new_dataset = new_dataset

        # Setup coarse stem
        stem_dims = [m * coarse_stem_cfg.generator.dim for m in coarse_stem_cfg.generator.dim_mults]
        self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval()
        self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize)
        utils.set_requires_grad(self.coarse_stem, False)
        self.stem_res_to_layer_idx = {
            self.coarse_stem.imsize[0] // 2**i: i
            for i in range(len(stem_dims))
        }

        dim = kwargs["dim"]
        dim_mults = kwargs["dim_mults"]
        n_layers = len(dim_mults)
        dims = [dim * s for s in dim_mults]
        layer_attn = cast_tuple(kwargs["layer_attn"], n_layers)
        num_resnet_blocks = cast_tuple(kwargs["num_resnet_blocks"], n_layers)
        attn_cls = kwargs["attn_cls"]
        if not isinstance(attn_cls, partial):
            attn_cls = instantiate(attn_cls)

        dec_blk = partial(
            SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_dec"],
            use_adain=kwargs["use_adain"] and kwargs["dec_style"], w_dim=kwargs["w_dim"],
            cross_attention=kwargs["cross_attention"], cross_attention_len=kwargs["cross_attention_len"],
            gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"]
        )
        enc_blk = partial(
            SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_enc"],
            use_adain=kwargs["use_adain"] and kwargs["enc_style"], w_dim=kwargs["w_dim"],
            cross_attention=kwargs["cross_attention"], cross_attention_len=kwargs["cross_attention_len"],
            gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"]
        )
        # Currently up/down sampling is done by bilinear upsampling.
        # This can be simplified by replacing it with a strided upsampling layer...
        self.encoder_attns = nn.ModuleList()
        self.encoder_unet_skips = nn.ModuleDict()
        self.encoder = nn.ModuleList()
        for lidx in range(n_layers):
            has_stem_feature = imsize[0] // 2**lidx in self.stem_res_to_layer_idx and cascade
            next_layer_has_stem_features = lidx + 1 < n_layers and imsize[0] // 2**(lidx + 1) in self.stem_res_to_layer_idx and cascade
            dim_in = dims[lidx]
            dim_out = dims[min(lidx + 1, n_layers - 1)]
            res_blocks = nn.ModuleList()
            if has_stem_feature:
                prev_layer_has_attention = lidx != 0 and layer_attn[lidx - 1]
                stem_lidx = self.stem_res_to_layer_idx[imsize[0] // 2**lidx]
                self.encoder_unet_skips.add_module(
                    str(imsize[0] // 2**lidx),
                    Conv2d(
                        stem_dims[stem_lidx], dim_in, kernel_size=1,
                        conv_clamp=kwargs["conv_clamp"], norm=nn.InstanceNorm2d(None),
                        gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"],
                        gain=np.sqrt(1/4) if prev_layer_has_attention else np.sqrt(1/3)  # This + previous residual + attention
                    )
                )
            for i in range(num_resnet_blocks[lidx]):
                is_last = num_resnet_blocks[lidx] - 1 == i
                cur_dim = dim_out if is_last else dim_in
                if not is_last:
                    gain = np.sqrt(.5)
                elif next_layer_has_stem_features and layer_attn[lidx]:
                    gain = np.sqrt(1/4)
                elif layer_attn[lidx] or next_layer_has_stem_features:
                    gain = np.sqrt(1/3)
                else:
                    gain = np.sqrt(.5)
                block = enc_blk(dim_in, cur_dim, skip_gain=gain)
                res_blocks.append(block)
            if layer_attn[lidx]:
                self.encoder_attns.append(attn_cls(dim=dim_out, gain=gain, fix_attention_again=True))
            else:
                self.encoder_attns.append(Identity())
            self.encoder.append(res_blocks)

        # initialize decoder
        self.decoder = torch.nn.ModuleList()
        self.unet_layers = torch.nn.ModuleList()
        self.decoder_attns = torch.nn.ModuleList()
        for lidx in range(n_layers):
            dim_in = dims[min(-lidx, -1)]
            dim_out = dims[-1 - lidx]
            res_blocks = nn.ModuleList()
            unet_skips = nn.ModuleList()
            for i in range(num_resnet_blocks[-lidx - 1]):
                is_first = i == 0
                has_unet = is_first or kwargs["skip_all_unets"]
                is_last = i == num_resnet_blocks[-lidx - 1] - 1
                cur_dim = dim_in if is_first else dim_out
                if has_unet and is_last and layer_attn[-lidx - 1]:  # x + residual + unet + layer attn
                    gain = np.sqrt(1/4)
                elif has_unet:  # x + residual + unet
                    gain = np.sqrt(1/3)
                elif layer_attn[-lidx - 1]:  # x + residual + attention
                    gain = np.sqrt(1/3)
                else:  # x + residual
                    gain = np.sqrt(1/2)  # Only residual block
                block = dec_blk(cur_dim, dim_out, skip_gain=gain)
                res_blocks.append(block)
                if kwargs["skip_all_unets"] or is_first:
                    unet_block = Conv2d(
                        cur_dim, cur_dim, kernel_size=1, conv_clamp=kwargs["conv_clamp"],
                        norm=nn.InstanceNorm2d(None),
                        gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"], gain=gain)
                    unet_skips.append(unet_block)
                else:
                    unet_skips.append(torch.nn.Identity())
            if layer_attn[-lidx - 1]:
                self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=True, gain=gain))
            else:
                self.decoder_attns.append(Identity())
            self.decoder.append(res_blocks)
            self.unet_layers.append(unet_skips)

        self.from_rgb = Conv2d(
            3 + 2 + 2 * int(kwargs["use_maskrcnn_mask"]) + self.input_keypoints * len(kwargs["input_keypoint_indices"]),
            dim, 7)
        self.to_rgb = Conv2d(dim, 3, 1, activation="linear", conv_clamp=kwargs["conv_clamp"])
        self.downsample = Upfirdn2d(down=2, fix_gain=True)
        self.upsample = Upfirdn2d(up=2, fix_gain=True)
        self.cascade = cascade
        if residual:
            nn.init.zeros_(self.to_rgb.weight.data)
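
    # When cascade is enabled, decoder features from the frozen coarse stem are
    # injected into the encoder at matching resolutions through the 1x1 convolutions
    # in encoder_unet_skips (keyed by resolution).
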
    def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None,
                return_decoder_features=False, **kwargs):
        # Downsample for stem
        if z is None:
            z = self.get_z(img)
        with torch.no_grad():
            # Forward pass stem
            if w is None:
                w = self.style_net(z)
            img_stem = utils.denormalize_img(img) * 255
            condition_stem = img_stem * mask + (1 - mask) * 127
            condition_stem = condition_stem.round()
            condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True)
            condition_stem = condition_stem / 255 * 2 - 1
            mask_stem = (torch.nn.functional.adaptive_max_pool2d(
                1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float()
            maskrcnn_stem = (resize(
                maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float()
            stem_out = self.coarse_stem(
                condition=condition_stem, mask=mask_stem, maskrcnn_mask=maskrcnn_stem,
                w=w, keypoints=keypoints, return_decoder_features=True)
            stem_features = stem_out["decoder_features"]
            x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True)
        condition = condition * mask + (1 - mask) * x_lr
        if self.use_maskrcnn_mask:
            x = torch.cat((condition, mask, 1 - mask, maskrcnn_mask, 1 - maskrcnn_mask), dim=1)
        else:
            x = torch.cat((condition, mask, 1 - mask), dim=1)
        if self.input_keypoints:
            keypoints = keypoints[:, self.input_keypoint_indices]
            one_hot_pose = spatial_embed_keypoints(keypoints, x)
            x = torch.cat((x, one_hot_pose), dim=1)
        x = self.from_rgb(x)
        x, unet_features = self.forward_enc(x, mask, w, stem_features)
        x, decoder_features = self.forward_dec(x, mask, w, unet_features)
        if self.residual:
            x = self.to_rgb(x) + condition
        else:
            x = self.to_rgb(x)
        out = dict(
            img=condition * mask + (1 - mask) * x,  # TODO: Probably do not want masked here... or ??
            unmasked=x,
            x_lowres=[condition]
        )
        if return_decoder_features:
            out["decoder_features"] = decoder_features
        return out

    def forward_enc(self, x, mask, w, stem_features: List[torch.Tensor]):
        unet_features = []
        stem_features.reverse()
        for i, res_blocks in enumerate(self.encoder):
            is_last = i == len(self.encoder) - 1
            res = self.imsize[0] // 2**i
            if str(res) in self.encoder_unet_skips.keys() and self.cascade:
                y = stem_features[self.stem_res_to_layer_idx[res]]
                y = self.encoder_unet_skips[str(res)](y)
                x = y + x
            for block in res_blocks:
                x = block(x, w=w)
                unet_features.append(x)
            x = self.encoder_attns[i](x, mask)
            if not is_last:
                x = self.downsample(x)
        return x, unet_features

    def forward_dec(self, x, mask, w, unet_features):
        features = []
        unet_features = iter(reversed(unet_features))
        for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)):
            is_last = i == len(self.decoder) - 1
            for skip, block in zip(unet_skip, res_blocks):
                skip_x = next(unet_features)
                if not isinstance(skip, torch.nn.Identity):
                    skip_x = skip(skip_x)
                    x = x + skip_x
                x = block(x, w=w)
            x = self.decoder_attns[i](x, mask)
            features.append(x)
            if not is_last:
                x = self.upsample(x)
        return x, features

    def get_z(self, *args, **kwargs):
        return self.coarse_stem.get_z(*args, **kwargs)

    @property
    def style_net(self):
        return self.coarse_stem.style_net

    def update_w(self, *args, **kwargs):
        self.style_net.update_w(*args, **kwargs)

    def get_w(self, z, update_emas):
        return self.style_net(z, update_emas=update_emas)

    @torch.no_grad()
    def sample(self, truncation_value, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        w = self.get_w(self.get_z(kwargs["condition"]), False)
        w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
        return self.forward(**kwargs, w=w)

    @torch.no_grad()
    def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        w = self.get_w(self.get_z(kwargs["condition"]), False)
        if w_indices is None:
            w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
        w_centers = self.style_net.w_centers[w_indices].to(w.device)
        w = w_centers.to(w.dtype).lerp(w, truncation_value)
        return self.forward(**kwargs, w=w)
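

if __name__ == "__main__":
    # Minimal shape check for the multi-query Attention block. This is an
    # illustrative sketch only (not part of the training pipeline) and assumes the
    # module's dependencies (tops, sg3_torch_utils, einops) are installed.
    attn = Attention(dim=64, norm="LN", attn_fix_gain=True, gradient_checkpoint=False)
    x = torch.randn(2, 64, 16, 16)
    y = attn(x)
    print(y.shape)  # expected: torch.Size([2, 64, 16, 16])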