# deep_privacy2/dp2/generator/imagen3_old.py
# What is missing from this implementation
# 1. Global context in res block
# 2. Cross attention of conditional information in resnet block
#
from functools import partial
import tops
from tops.config import instantiate
import warnings
from typing import Iterable, List, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch import einsum
from einops import rearrange
from dp2 import infer, utils
from .base import BaseGenerator
from sg3_torch_utils.ops import bias_act
from dp2.layers import Sequential
import torch.nn.functional as F
from torchvision.transforms.functional import resize, InterpolationMode
from sg3_torch_utils.ops import conv2d_resample, fma, upfirdn2d
class Upfirdn2d(torch.nn.Module):
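    """FIR up/down-sampling (upfirdn2d) with a fixed [1, 3, 3, 1] low-pass filter.

    Wraps sg3_torch_utils' upfirdn2d so it can be dropped into a Sequential: dict inputs
    are resampled on their "x" entry, and extra positional args are passed through untouched.
    """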
def __init__(self, down=1, up=1, fix_gain=True):
super().__init__()
self.register_buffer("resample_filter", upfirdn2d.setup_filter([1, 3, 3, 1]))
fw, fh = upfirdn2d._get_filter_size(self.resample_filter)
px0, px1, py0, py1 = upfirdn2d._parse_padding(0)
self.down = down
self.up = up
if up > 1:
px0 += (fw + up - 1) // 2
px1 += (fw - up) // 2
py0 += (fh + up - 1) // 2
py1 += (fh - up) // 2
if down > 1:
px0 += (fw - down + 1) // 2
px1 += (fw - down) // 2
py0 += (fh - down + 1) // 2
py1 += (fh - down) // 2
self.padding = [px0,px1,py0,py1]
self.gain = up**2 if fix_gain else 1
def forward(self, x, *args):
if isinstance(x, dict):
x = {k: v for k, v in x.items()}
x["x"] = upfirdn2d.upfirdn2d(x["x"], self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain)
return x
x = upfirdn2d.upfirdn2d(x, self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain)
if len(args) == 0:
return x
return (x, *args)
@torch.no_grad()
def spatial_embed_keypoints(keypoints: torch.Tensor, x):
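    """Embed keypoints as one-hot spatial heatmaps matching the spatial size of x.

    keypoints has shape [B, N_K, 3] with rows (x, y, visible), where x and y are
    normalized to [0, 1]. Returns a float tensor of shape [B, N_K, H, W] with a single 1
    at each visible keypoint's rounded pixel location and zeros elsewhere.
    """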
tops.assert_shape(keypoints, (None, None, 3))
B, N_K, _ = keypoints.shape
H, W = x.shape[-2:]
x, y, visible = keypoints.chunk(3, dim=2)
x = (x * W).round().long().clamp(0, W-1)
y = (y * H).round().long().clamp(0, H-1)
kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1)
pos = (kp_idx*(H*W) + y*W + x + 1)
# Offset all by 1 to index invisible keypoints as 0
pos = (pos * visible.round().long()).squeeze(dim=-1)
keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32)
keypoint_spatial.scatter_(1, pos, 1)
keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W)
return keypoint_spatial
def modulated_conv2d(
x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
styles, # Modulation coefficients of shape [batch_size, in_channels].
noise = None, # Optional noise tensor to add to the output activations.
up = 1, # Integer upsampling factor.
down = 1, # Integer downsampling factor.
padding = 0, # Padding with respect to the upsampled image.
resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
demodulate = True, # Apply weight demodulation?
flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
):
batch_size = x.shape[0]
out_channels, in_channels, kh, kw = weight.shape
tops.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
tops.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
tops.assert_shape(styles, [batch_size, in_channels]) # [NI]
# Pre-normalize inputs to avoid FP16 overflow.
if x.dtype == torch.float16 and demodulate:
weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
# Calculate per-sample weights and demodulation coefficients.
w = None
dcoefs = None
if demodulate or fused_modconv:
w = weight.unsqueeze(0) # [NOIkk]
w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
if demodulate:
dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
if demodulate and fused_modconv:
w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
# Execute by scaling the activations before and after the convolution.
if not fused_modconv:
x = x * styles.reshape(batch_size, -1, 1, 1)
x = conv2d_resample.conv2d_resample(x=x, w=weight, f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
if demodulate and noise is not None:
x = fma.fma(x, dcoefs.reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
elif demodulate:
x = x * dcoefs.reshape(batch_size, -1, 1, 1)
elif noise is not None:
x = x.add_(noise.to(x.dtype))
return x
with tops.suppress_tracer_warnings(): # this value will be treated as a constant
batch_size = int(batch_size)
# Execute as one fused op using grouped convolution.
tops.assert_shape(x, [batch_size, in_channels, None, None])
x = x.reshape(1, -1, *x.shape[2:])
w = w.reshape(-1, in_channels, kh, kw)
x = conv2d_resample.conv2d_resample(x=x, w=w, f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
x = x.reshape(batch_size, -1, *x.shape[2:])
if noise is not None:
x = x.add_(noise)
return x
class Identity(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, *args, **kwargs):
return x
class LayerNorm(nn.Module):
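    """LayerNorm over the last dimension with a learned scale and no bias.

    If stable=True, the input is first divided by its detached per-row max as a
    numerical-stability trick before normalization.
    """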
def __init__(self, dim, stable=False):
super().__init__()
self.stable = stable
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
if self.stable:
x = x / x.amax(dim=-1, keepdim=True).detach()
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
mean = torch.mean(x, dim=-1, keepdim=True)
return (x - mean) * (var + eps).rsqrt() * self.g
class FullyConnectedLayer(torch.nn.Module):
def __init__(self,
in_features, # Number of input features.
out_features, # Number of output features.
bias = True, # Apply additive bias before the activation function?
activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
lr_multiplier = 1, # Learning rate multiplier.
bias_init = 0, # Initial value for the additive bias.
):
super().__init__()
self.repr = dict(
in_features=in_features, out_features=out_features, bias=bias,
activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init)
self.activation = activation
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
self.weight_gain = lr_multiplier / np.sqrt(in_features)
self.bias_gain = lr_multiplier
self.in_features = in_features
self.out_features = out_features
def forward(self, x):
w = self.weight * self.weight_gain
b = self.bias
if b is not None:
if self.bias_gain != 1:
b = b * self.bias_gain
x = F.linear(x, w)
x = bias_act.bias_act(x, b, act=self.activation)
return x
def extra_repr(self) -> str:
return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
def checkpoint_fn(fn, *args, **kwargs):
    # Silence checkpointing warnings locally instead of mutating the global warning filters.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
class Conv2d(torch.nn.Module):
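    """Conv2d with equalized learning rate, optional normalization and style conditioning.

    When w_dim is given, the layer is conditioned on the latent w either through weight
    modulation/demodulation (norm="fused_modconv") or through a per-channel affine
    (gamma/beta) predicted from w. Bias and activation are applied with bias_act.
    """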
def __init__(
self,
in_channels,
out_channels,
kernel_size=3,
activation='lrelu',
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
bias=True,
norm=None,
lr_multiplier=1,
bias_init=0,
w_dim=None,
gradient_checkpoint_norm=False,
gain=1,
):
super().__init__()
self.fused_modconv = False
if norm == torch.nn.InstanceNorm2d:
self.norm = torch.nn.InstanceNorm2d(None)
elif isinstance(norm, torch.nn.Module):
self.norm = norm
elif norm == "fused_modconv":
self.fused_modconv = True
elif norm:
self.norm = torch.nn.InstanceNorm2d(None)
elif norm is not None:
raise ValueError(f"norm not supported: {norm}")
self.activation = activation
self.conv_clamp = conv_clamp
self.out_channels = out_channels
self.in_channels = in_channels
self.padding = kernel_size // 2
self.repr = dict(
in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size,
activation=activation, conv_clamp=conv_clamp, bias=bias,
fused_modconv=self.fused_modconv
)
self.act_gain = bias_act.activation_funcs[activation].def_gain * gain
self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2))
self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
self.bias = torch.nn.Parameter(torch.zeros([out_channels])+bias_init) if bias else None
self.bias_gain = lr_multiplier
if w_dim is not None:
if self.fused_modconv:
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
else:
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0)
self.gradient_checkpoint_norm = gradient_checkpoint_norm
def forward(self, x, w=None, gain=1, **kwargs):
if self.fused_modconv:
styles = self.affine(w)
with torch.cuda.amp.autocast(enabled=False):
x = modulated_conv2d(x=x.half(), weight=self.weight.half(), styles=styles.half(), noise=None,
padding=self.padding, flip_weight=True, fused_modconv=False).to(x.dtype)
else:
if hasattr(self, "affine"):
gamma = self.affine(w).view(-1, self.in_channels, 1, 1)
beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1)
                x = fma.fma(x, gamma, beta)
w = self.weight * self.weight_gain
x = F.conv2d(input=x, weight=w, padding=self.padding,)
if hasattr(self, "norm"):
if self.gradient_checkpoint_norm:
x = checkpoint_fn(self.norm, x)
else:
x = self.norm(x)
act_gain = self.act_gain * gain
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
b = self.bias * self.bias_gain if self.bias is not None else None
x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
return x
def extra_repr(self) -> str:
return ", ".join([f"{key}={item}" for key, item in self.repr.items()])
class CrossAttention(nn.Module):
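    """Cross-attention from spatial features (queries) to conditional tokens w (keys/values)."""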
def __init__(
self,
dim,
context_dim,
dim_head=64,
heads=8,
norm_context=False,
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm = nn.InstanceNorm1d(dim)
self.norm_context = nn.InstanceNorm1d(None) if norm_context else Identity()
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim, bias=False),
nn.InstanceNorm1d(None)
)
def forward(self, x, w):
x = self.norm(x)
w = self.norm_context(w)
q, k, v = (self.to_q(x), *self.to_kv(w).chunk(2, dim = -1))
q = rearrange(q, "b n (h d) -> b h n d", h = self.heads)
k = rearrange(k, "b n (h d) -> b h n d", h = self.heads)
v = rearrange(v, "b n (h d) -> b h n d", h = self.heads)
q = q * self.scale
# similarities
sim = einsum('b h i d, b h j d -> b h i j', q, k)
attn = sim.softmax(dim = -1, dtype = torch.float32)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class SG2ResidualBlock(torch.nn.Module):
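    """StyleGAN2-style residual block: two convolutions plus a 1x1 skip connection,
    with optional AdaIN-style conditioning on w and optional cross-attention over w."""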
def __init__(
self,
in_channels, # Number of input channels, 0 = first block.
out_channels, # Number of output channels.
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping.
skip_gain=np.sqrt(.5),
cross_attention: bool = False,
cross_attention_len: int = None,
use_adain: bool = True,
**layer_kwargs, # Arguments for conv layer.
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
w_dim = layer_kwargs.pop("w_dim") if "w_dim" in layer_kwargs else None
if use_adain:
layer_kwargs["w_dim"] = w_dim
self.conv0 = Conv2d(in_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs)
self.conv1 = Conv2d(out_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs, gain=skip_gain)
self.skip = Conv2d(in_channels, out_channels, kernel_size=1, bias=False, gain=skip_gain)
        if cross_attention and w_dim is not None:
            self.cross_attention_len = cross_attention_len
            self.cross_attn = CrossAttention(
                dim=out_channels, context_dim=w_dim // self.cross_attention_len)
def forward(self, x, w=None, **layer_kwargs):
y = self.skip(x)
x = self.conv0(x, w, **layer_kwargs)
x = self.conv1(x, w, **layer_kwargs)
if hasattr(self, "cross_attn"):
h = x.shape[-2]
x = rearrange(x, "b c h w -> b (h w) c")
w = rearrange(w, "b (n c) -> b n c", n=self.cross_attention_len)
x = self.cross_attn(x, w=w) + x
x = rearrange(x, "b (h w) c -> b c h w", h=h)
return y + x
def default(val, d):
if val is not None:
return val
return d() if callable(d) else d
def cast_tuple(val, length=None):
if isinstance(val, Iterable) and not isinstance(val, str):
val = tuple(val)
output = val if isinstance(val, tuple) else ((val,) * default(length, 1))
if length is not None:
assert len(output) == length, (output, length)
return output
class Attention(nn.Module):
    # This is a version of Multi-Query Attention, introduced in:
# Fast Transformer Decoding: One Write-Head is All You Need
# Ablated in: https://arxiv.org/pdf/2203.07814.pdf
# and https://arxiv.org/pdf/2204.02311.pdf
def __init__(self, dim, norm, attn_fix_gain, gradient_checkpoint, dim_head=64, heads=8, cosine_sim_attn=False, fix_attention_again=False, gain=None):
super().__init__()
self.scale = dim_head**-0.5 if not cosine_sim_attn else 1.0
self.cosine_sim_attn = cosine_sim_attn
self.cosine_sim_scale = 16 if cosine_sim_attn else 1
self.gradient_checkpoint = gradient_checkpoint
self.heads = heads
self.dim = dim
self.fix_attention_again = fix_attention_again
inner_dim = dim_head * heads
if norm == "LN":
self.norm = LayerNorm(dim)
elif norm == "IN":
self.norm = nn.InstanceNorm1d(dim)
elif norm is None:
self.norm = nn.Identity()
else:
raise ValueError(f"Norm not supported: {norm}")
self.to_q = FullyConnectedLayer(dim, inner_dim, bias=False)
self.to_kv = FullyConnectedLayer(dim, dim_head*2, bias=False)
self.to_out = nn.Sequential(
FullyConnectedLayer(inner_dim, dim, bias=False),
LayerNorm(dim) if norm == "LN" else nn.InstanceNorm1d(dim)
)
if fix_attention_again:
assert gain is not None
self.gain = gain
else:
self.gain = np.sqrt(.5) if attn_fix_gain else 1
def run_function(self, x, attn_bias):
b, c, h, w = x.shape
x = rearrange(x, "b c h w -> b (h w) c")
in_ = x
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
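        # Multi-query attention: keys and values are shared across all heads; only the queries are split per head.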
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
q = q * self.scale
# calculate query / key similarities
sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale
        if attn_bias is not None:
            attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)")
            sim = sim + attn_bias
attn = sim.softmax(dim=-1)
out = einsum("b h i j, b j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
if self.fix_attention_again:
out = self.to_out(out)*self.gain + in_
else:
out = (self.to_out(out) + in_) * self.gain
out = rearrange(out, "b (h w) c -> b c h w", h=h)
return out
def forward(self, x, *args, attn_bias=None, **kwargs):
if self.gradient_checkpoint:
return checkpoint_fn(self.run_function, x, attn_bias)
return self.run_function(x, attn_bias)
def get_attention(self, x, attn_bias=None):
b, c, h, w = x.shape
x = rearrange(x, "b c h w -> b (h w) c")
in_ = x
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
q = rearrange(q, "b n (h d) -> b h n d", h=self.heads)
q = q * self.scale
# calculate query / key similarities
sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale
        if attn_bias is not None:
            attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)")
            sim = sim + attn_bias
attn = sim.softmax(dim=-1)
return attn, None
class BiasedAttention(Attention):
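    """Attention with a learned additive bias on the logits, predicted from the features and the mask."""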
def __init__(self, *args, head_wise: bool=True, **kwargs):
super().__init__(*args, **kwargs)
out_ch = self.heads if head_wise else 1
self.conv = Conv2d(self.dim+2, out_ch, activation="linear", kernel_size=3, bias_init=0)
nn.init.zeros_(self.conv.weight.data)
def forward(self, x, mask):
mask = resize(mask, size=x.shape[-2:])
bias = self.conv(torch.cat((x, mask, 1-mask), dim=1))
return super().forward(x=x, attn_bias=bias)
def get_attention(self, x, mask):
mask = resize(mask, size=x.shape[-2:])
bias = self.conv(torch.cat((x, mask, 1-mask), dim=1))
return super().get_attention(x, bias)[0], bias
class UNet(BaseGenerator):
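    """Image-to-image U-Net generator.

    The condition, masks and (optionally) keypoint heatmaps are encoded over n_layers
    resolutions and decoded with U-Net skip connections. Blocks can be conditioned on the
    style latent w through AdaIN and/or cross-attention, with optional self-attention at
    selected resolutions.
    """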
def __init__(
self,
im_channels: int,
dim: int,
dim_mults: tuple,
num_resnet_blocks, # Number of resnet blocks per resolution
n_middle_blocks: int,
z_channels: int,
conv_clamp: int,
layer_attn,
w_dim: int,
norm_enc: bool,
norm_dec: str,
stylenet: nn.Module,
enc_style: bool, # Toggle style injection in encoder
use_maskrcnn_mask: bool,
skip_all_unets: bool,
fix_resize:bool,
comodulate: bool,
comod_net: nn.Module,
lr_comod: float,
dec_style: bool,
input_keypoints: bool,
n_keypoints: int,
input_keypoint_indices: Tuple[int],
use_adain: bool,
cross_attention: bool,
cross_attention_len: int,
gradient_checkpoint_norm: bool,
attn_cls: partial,
mask_out_train: bool,
fix_gain_again: bool,
) -> None:
super().__init__(z_channels)
self.enc_style = enc_style
self.n_keypoints = n_keypoints
self.input_keypoint_indices = list(input_keypoint_indices)
self.input_keypoints = input_keypoints
self.mask_out_train = mask_out_train
n_layers = len(dim_mults)
self.n_layers = n_layers
layer_attn = cast_tuple(layer_attn, n_layers)
num_resnet_blocks = cast_tuple(num_resnet_blocks, n_layers)
self._cnum = dim
self._image_channels = im_channels
self._z_channels = z_channels
encoder_layers = []
condition_ch = im_channels
        self.from_rgb = Conv2d(
            condition_ch + 2 + 2*int(use_maskrcnn_mask) + self.input_keypoints*len(input_keypoint_indices),
            dim, 7)
self.use_maskrcnn_mask = use_maskrcnn_mask
self.skip_all_unets = skip_all_unets
dims = [dim*m for m in dim_mults]
enc_blk = partial(
SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_enc,
use_adain=use_adain and self.enc_style,
w_dim=w_dim,
cross_attention=cross_attention,
cross_attention_len=cross_attention_len,
gradient_checkpoint_norm=gradient_checkpoint_norm
)
dec_blk = partial(
SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_dec,
use_adain=use_adain and dec_style,
w_dim=w_dim,
cross_attention=cross_attention,
cross_attention_len=cross_attention_len,
gradient_checkpoint_norm=gradient_checkpoint_norm
)
        # Up/down sampling is done with a fixed low-pass FIR filter (upfirdn2d).
        # This could be simplified by replacing it with a strided up/down-sampling layer.
self.encoder_attns = nn.ModuleList()
for lidx in range(n_layers):
gain = np.sqrt(1/3) if layer_attn[lidx] and fix_gain_again else np.sqrt(.5)
dim_in = dims[lidx]
dim_out = dims[min(lidx+1, n_layers-1)]
res_blocks = nn.ModuleList()
for i in range(num_resnet_blocks[lidx]):
is_last = num_resnet_blocks[lidx] - 1 == i
cur_dim = dim_out if is_last else dim_in
block = enc_blk(dim_in, cur_dim, skip_gain=gain)
res_blocks.append(block)
if layer_attn[lidx]:
self.encoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain))
else:
self.encoder_attns.append(Identity())
encoder_layers.append(res_blocks)
self.encoder = torch.nn.ModuleList(encoder_layers)
# initialize decoder
decoder_layers = []
self.unet_layers = torch.nn.ModuleList()
self.decoder_attns = torch.nn.ModuleList()
for lidx in range(n_layers):
dim_in = dims[min(-lidx, -1)]
dim_out = dims[-1-lidx]
res_blocks = nn.ModuleList()
unet_skips = nn.ModuleList()
for i in range(num_resnet_blocks[-lidx-1]):
is_first = i == 0
has_unet = is_first or skip_all_unets
is_last = i == num_resnet_blocks[-lidx-1] - 1
cur_dim = dim_in if is_first else dim_out
if has_unet and is_last and layer_attn[-lidx-1] and fix_gain_again: # x + residual + unet + layer attn
gain = np.sqrt(1/4)
elif has_unet: # x + residual + unet
gain = np.sqrt(1/3)
elif layer_attn[-lidx-1] and fix_gain_again: # x + residual + attention
gain = np.sqrt(1/3)
else: # x + residual
gain = np.sqrt(1/2) # Only residual block
block = dec_blk(cur_dim, dim_out, skip_gain=gain)
res_blocks.append(block)
if has_unet:
unet_block = Conv2d(
cur_dim, cur_dim, kernel_size=1, conv_clamp=conv_clamp,
norm=nn.InstanceNorm2d(None),
gradient_checkpoint_norm=gradient_checkpoint_norm,
gain=gain)
unet_skips.append(unet_block)
else:
unet_skips.append(torch.nn.Identity())
if layer_attn[-lidx-1]:
self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain))
else:
self.decoder_attns.append(Identity())
decoder_layers.append(res_blocks)
self.unet_layers.append(unet_skips)
middle_blocks = []
for i in range(n_middle_blocks):
block = dec_blk(dims[-1], dims[-1])
middle_blocks.append(block)
if n_middle_blocks != 0:
self.middle_blocks = Sequential(*middle_blocks)
self.decoder = torch.nn.ModuleList(decoder_layers)
self.to_rgb = Conv2d(dim, im_channels, 1, activation="linear", conv_clamp=conv_clamp)
self.stylenet = stylenet
self.downsample = Upfirdn2d(down=2, fix_gain=fix_resize)
self.upsample = Upfirdn2d(up=2, fix_gain=fix_resize)
self.comodulate = comodulate
if comodulate:
assert not self.enc_style
self.to_y = nn.Sequential(
Conv2d(dims[-1], dims[-1], lr_multiplier=lr_comod, gradient_checkpoint_norm=gradient_checkpoint_norm),
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
FullyConnectedLayer(dims[-1], 512, activation="lrelu", lr_multiplier=lr_comod)
)
self.comod_net = comod_net
def forward(self, condition, mask, maskrcnn_mask=None, z=None, w=None, update_emas=False, keypoints=None, return_decoder_features=False, **kwargs):
if z is None:
z = self.get_z(condition)
if w is None:
w = self.stylenet(z, update_emas=update_emas)
if self.use_maskrcnn_mask:
x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1)
else:
x = torch.cat((condition, mask, 1-mask), dim=1)
if self.input_keypoints:
keypoints = keypoints[:, self.input_keypoint_indices]
one_hot_pose = spatial_embed_keypoints(keypoints, x)
x = torch.cat((x, one_hot_pose), dim=1)
x = self.from_rgb(x)
x, unet_features = self.forward_enc(x, mask, w)
x, decoder_features = self.forward_dec(x, mask, w, unet_features)
x = self.to_rgb(x)
unmasked = x
if self.mask_out_train:
x = mask * condition + (1-mask) * x
out = dict(img=x, unmasked=unmasked)
if return_decoder_features:
out["decoder_features"] = decoder_features
return out
def forward_enc(self, x, mask, w):
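        """Run the encoder, collecting the per-block activations used as U-Net skip connections."""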
unet_features = []
for i, res_blocks in enumerate(self.encoder):
is_last = i == len(self.encoder) - 1
for block in res_blocks:
x = block(x, w=w)
unet_features.append(x)
x = self.encoder_attns[i](x, mask=mask)
if not is_last:
x = self.downsample(x)
if self.comodulate:
y = self.to_y(x)
y = torch.cat((w, y), dim=-1)
w = self.comod_net(y)
return x, unet_features
def forward_dec(self, x, mask, w, unet_features):
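        """Run the decoder, adding projected encoder skip features before each residual block.

        Returns the output features and the per-resolution decoder features.
        """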
if hasattr(self, "middle_blocks"):
x = self.middle_blocks(x, w=w)
features = []
unet_features = iter(reversed(unet_features))
for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)):
is_last = i == len(self.decoder) - 1
for skip, block in zip(unet_skip, res_blocks):
skip_x = next(unet_features)
if not isinstance(skip, torch.nn.Identity):
skip_x = skip(skip_x)
x = x + skip_x
x = block(x, w=w)
x = self.decoder_attns[i](x, mask=mask)
features.append(x)
if not is_last:
x = self.upsample(x)
return x, features
def get_w(self, z, update_emas):
return self.stylenet(z, update_emas=update_emas)
@torch.no_grad()
def sample(self, truncation_value, **kwargs):
if truncation_value is None:
return self.forward(**kwargs)
truncation_value = max(0, truncation_value)
truncation_value = min(truncation_value, 1)
w = self.get_w(self.get_z(kwargs["condition"]), False)
w = self.stylenet.w_avg.to(w.dtype).lerp(w, truncation_value)
return self.forward(**kwargs, w=w)
def update_w(self, *args, **kwargs):
self.style_net.update_w(*args, **kwargs)
@property
def style_net(self):
return self.stylenet
@torch.no_grad()
def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
if truncation_value is None:
return self.forward(**kwargs)
truncation_value = max(0, truncation_value)
truncation_value = min(truncation_value, 1)
w = self.get_w(self.get_z(kwargs["condition"]), False)
if w_indices is None:
w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
w_centers = self.style_net.w_centers[w_indices].to(w.device)
w = w_centers.to(w.dtype).lerp(w, truncation_value)
return self.forward(**kwargs, w=w)
def get_stem_unet_kwargs(cfg):
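    """Recursively collect the U-Net keyword arguments of the innermost stem generator config."""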
if "stem_cfg" in cfg.generator: # If the stem has another stem, recursively apply get_stem_unet_kwargs
return get_stem_unet_kwargs(cfg.generator.stem_cfg)
return dict(cfg.generator)
class GrowingUnet(BaseGenerator):
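    """U-Net that grows from previously trained stems.

    An optional frozen coarse generator (coarse_stem) inpaints a low-resolution image that
    is pasted into the masked-out region of the condition. An optional frozen, previously
    trained U-Net (sr_stem) is reused as the inner part of the encoder/decoder, with its
    outermost blocks removed, while new outer layers are trained at the higher resolution.
    """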
def __init__(
self,
        coarse_stem_cfg: str,  # Config for a coarse generator stem, or None.
        sr_cfg: str,  # Config for a previous progressive U-Net / UNet stem, or None.
        residual: bool,
        new_dataset: bool,  # The "new dataset" pipeline creates the condition first, then resizes it for the stem.
**unet_kwargs):
kwargs = dict()
if coarse_stem_cfg is not None:
coarse_stem_cfg = utils.load_config(coarse_stem_cfg)
kwargs = get_stem_unet_kwargs(coarse_stem_cfg)
if sr_cfg is not None:
sr_cfg = utils.load_config(sr_cfg)
sr_stem_unet_kwargs = get_stem_unet_kwargs(sr_cfg)
kwargs.update(sr_stem_unet_kwargs)
kwargs.update(unet_kwargs)
kwargs["stylenet"] = None
kwargs.pop("_target_")
if "sr_cfg" in kwargs: # Unet kwargs are inherited, do not pass this to the new u-net
del kwargs["sr_cfg"]
if "coarse_stem_cfg" in kwargs:
del kwargs["coarse_stem_cfg"]
super().__init__(z_channels=kwargs["z_channels"])
if coarse_stem_cfg is not None:
z_channels = coarse_stem_cfg.generator.z_channels
super().__init__(z_channels)
self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval()
self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize)
utils.set_requires_grad(self.coarse_stem, False)
else:
assert not residual
if sr_cfg is not None:
self.sr_stem = infer.build_trained_generator(sr_cfg, map_location="cpu").eval()
del self.sr_stem.from_rgb
del self.sr_stem.to_rgb
if hasattr(self.sr_stem, "coarse_stem"):
del self.sr_stem.coarse_stem
if isinstance(self.sr_stem, UNet):
del self.sr_stem.encoder[0][0] # Delete first residual block
del self.sr_stem.decoder[-1][-1] # Delete last residual block
else:
assert isinstance(self.sr_stem, GrowingUnet)
del self.sr_stem.unet.encoder[0][0] # Delete first residual block
del self.sr_stem.unet.decoder[-1][-1] # Delete last residual block
utils.set_requires_grad(self.sr_stem, False)
args = kwargs.pop("_args_")
if hasattr(self, "sr_stem"): # Growing the SR stem - Add a new layer to match sr
n_layers = len(kwargs["dim_mults"])
dim_mult = sr_stem_unet_kwargs["dim"] / (kwargs["dim"] * max(kwargs["dim_mults"]))
kwargs["dim_mults"] = [*kwargs["dim_mults"], int(dim_mult)]
kwargs["layer_attn"] = [*cast_tuple(kwargs["layer_attn"], n_layers), False]
kwargs["num_resnet_blocks"] = [*cast_tuple(kwargs["num_resnet_blocks"], n_layers), 1]
self.unet = UNet(
*args,
**kwargs
)
self.from_rgb = self.unet.from_rgb
self.to_rgb = self.unet.to_rgb
self.residual = residual
self.new_dataset = new_dataset
if residual:
nn.init.zeros_(self.to_rgb.weight.data)
del self.unet.from_rgb, self.unet.to_rgb
def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, **kwargs):
# Downsample for stem
if z is None:
z = self.get_z(img)
if w is None:
w = self.style_net(z)
if hasattr(self, "coarse_stem"):
with torch.no_grad():
if self.new_dataset:
img_stem = utils.denormalize_img(img)*255
condition_stem = img_stem * mask + (1-mask)*127
condition_stem = condition_stem.round()
condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True)
condition_stem = condition_stem / 255 *2 - 1
mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float()
maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float()
else:
mask_stem = (resize(mask, self.coarse_stem.imsize, antialias=True) > .99).float()
maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, antialias=True) > .5).float()
img_stem = utils.denormalize_img(img)*255
img_stem = resize(img_stem, self.coarse_stem.imsize, antialias=True).round()
img_stem = img_stem / 255 * 2 - 1
condition_stem = img_stem * mask_stem
stem_out = self.coarse_stem(
condition=condition_stem, mask=mask_stem,
maskrcnn_mask=maskrcnn_stem, w=w,
keypoints=keypoints)
x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True)
condition = condition*mask + (1-mask) * x_lr
if self.unet.use_maskrcnn_mask:
x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1)
else:
x = torch.cat((condition, mask, 1-mask), dim=1)
if self.unet.input_keypoints:
keypoints = keypoints[:, self.unet.input_keypoint_indices]
one_hot_pose = spatial_embed_keypoints(keypoints, x)
x = torch.cat((x, one_hot_pose), dim=1)
x = self.from_rgb(x)
x, unet_features = self.forward_enc(x, mask, w)
x = self.forward_dec(x, mask, w, unet_features)
if self.residual:
x = self.to_rgb(x) + condition
else:
x = self.to_rgb(x)
return dict(
img=condition * mask + (1-mask) * x,
unmasked=x,
x_lowres=[condition]
)
def forward_enc(self, x, mask, w):
x, unet_features = self.unet.forward_enc(x, mask, w)
if hasattr(self, "sr_stem"):
x, unet_features_stem = self.sr_stem.forward_enc(x, mask, w)
else:
unet_features_stem = None
return x, [unet_features, unet_features_stem]
def forward_dec(self, x, mask, w, unet_features):
unet_features, unet_features_stem = unet_features
if hasattr(self, "sr_stem"):
x = self.sr_stem.forward_dec(x, mask, w, unet_features_stem)
x, unet_features = self.unet.forward_dec(x, mask, w, unet_features)
return x
def get_z(self, *args, **kwargs):
if hasattr(self, "coarse_stem"):
return self.coarse_stem.get_z(*args, **kwargs)
if hasattr(self, "sr_stem"):
return self.sr_stem.get_z(*args, **kwargs)
raise AttributeError()
@property
def style_net(self):
if hasattr(self, "coarse_stem"):
return self.coarse_stem.style_net
if hasattr(self, "sr_stem"):
return self.sr_stem.style_net
raise AttributeError()
def update_w(self, *args, **kwargs):
self.style_net.update_w(*args, **kwargs)
def get_w(self, z, update_emas):
return self.style_net(z, update_emas=update_emas)
@torch.no_grad()
def sample(self, truncation_value, **kwargs):
if truncation_value is None:
return self.forward(**kwargs)
truncation_value = max(0, truncation_value)
truncation_value = min(truncation_value, 1)
w = self.get_w(self.get_z(kwargs["condition"]), False)
w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
return self.forward(**kwargs, w=w)
@torch.no_grad()
def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
if truncation_value is None:
return self.forward(**kwargs)
truncation_value = max(0, truncation_value)
truncation_value = min(truncation_value, 1)
w = self.get_w(self.get_z(kwargs["condition"]), False)
if w_indices is None:
w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
w_centers = self.style_net.w_centers[w_indices].to(w.device)
w = w_centers.to(w.dtype).lerp(w, truncation_value)
return self.forward(**kwargs, w=w)
class CascadedUnet(BaseGenerator):
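    """U-Net cascaded on a frozen coarse generator.

    The coarse stem inpaints a low-resolution image that is pasted into the masked-out
    region of the condition; when cascade=True, the coarse decoder features are projected
    and added to the encoder activations at matching resolutions.
    """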
def __init__(
self,
        coarse_stem_cfg: str,  # Config for the coarse generator stem (required).
        residual: bool,
        new_dataset: bool,  # The "new dataset" pipeline creates the condition first, then resizes it for the stem.
        imsize: tuple,
        cascade: bool,
**unet_kwargs):
kwargs = dict()
coarse_stem_cfg = utils.load_config(coarse_stem_cfg)
kwargs = get_stem_unet_kwargs(coarse_stem_cfg)
kwargs.update(unet_kwargs)
super().__init__(z_channels=kwargs["z_channels"])
self.input_keypoints = kwargs["input_keypoints"]
self.input_keypoint_indices = kwargs["input_keypoint_indices"]
self.use_maskrcnn_mask = kwargs["use_maskrcnn_mask"]
self.imsize = imsize
self.residual = residual
self.new_dataset = new_dataset
# Setup coarse stem
stem_dims = [m*coarse_stem_cfg.generator.dim for m in coarse_stem_cfg.generator.dim_mults]
self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval()
self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize)
utils.set_requires_grad(self.coarse_stem, False)
        # Map feature-map resolution (at the stem's scale) to the corresponding stem layer index.
        self.stem_res_to_layer_idx = {
            self.coarse_stem.imsize[0] // 2**i: i
            for i in range(len(stem_dims))
        }
dim = kwargs["dim"]
dim_mults = kwargs["dim_mults"]
n_layers = len(dim_mults)
dims = [dim*s for s in dim_mults]
layer_attn = cast_tuple(kwargs["layer_attn"], n_layers)
num_resnet_blocks = cast_tuple(kwargs["num_resnet_blocks"], n_layers)
attn_cls = kwargs["attn_cls"]
if not isinstance(attn_cls, partial):
attn_cls = instantiate(attn_cls)
dec_blk = partial(
SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_dec"],
use_adain=kwargs["use_adain"] and kwargs["dec_style"],
w_dim=kwargs["w_dim"],
cross_attention=kwargs["cross_attention"],
cross_attention_len=kwargs["cross_attention_len"],
gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"]
)
enc_blk = partial(
SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_enc"],
use_adain=kwargs["use_adain"] and kwargs["enc_style"],
w_dim=kwargs["w_dim"],
cross_attention=kwargs["cross_attention"],
cross_attention_len=kwargs["cross_attention_len"],
gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"]
)
        # Up/down sampling is done with a fixed low-pass FIR filter (upfirdn2d).
        # This could be simplified by replacing it with a strided up/down-sampling layer.
self.encoder_attns = nn.ModuleList()
self.encoder_unet_skips = nn.ModuleDict()
self.encoder = nn.ModuleList()
for lidx in range(n_layers):
            has_stem_feature = imsize[0] // 2**lidx in self.stem_res_to_layer_idx and cascade
            next_layer_has_stem_features = lidx+1 < n_layers and imsize[0] // 2**(lidx+1) in self.stem_res_to_layer_idx and cascade
dim_in = dims[lidx]
dim_out = dims[min(lidx+1, n_layers-1)]
res_blocks = nn.ModuleList()
if has_stem_feature:
                prev_layer_has_attention = lidx != 0 and layer_attn[lidx-1]
                stem_lidx = self.stem_res_to_layer_idx[imsize[0] // 2**lidx]
                self.encoder_unet_skips.add_module(
                    str(imsize[0] // 2**lidx),
Conv2d(
stem_dims[stem_lidx], dim_in, kernel_size=1,
conv_clamp=kwargs["conv_clamp"],
norm=nn.InstanceNorm2d(None),
gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"],
gain=np.sqrt(1/4) if prev_layer_has_attention else np.sqrt(1/3) # This + previous residual + attention
)
)
for i in range(num_resnet_blocks[lidx]):
is_last = num_resnet_blocks[lidx] - 1 == i
cur_dim = dim_out if is_last else dim_in
if not is_last:
gain = np.sqrt(.5)
elif next_layer_has_stem_features and layer_attn[lidx]:
gain = np.sqrt(1/4)
elif layer_attn[lidx] or next_layer_has_stem_features:
gain = np.sqrt(1/3)
else:
gain = np.sqrt(.5)
block = enc_blk(dim_in, cur_dim, skip_gain=gain)
res_blocks.append(block)
if layer_attn[lidx]:
self.encoder_attns.append(attn_cls(dim=dim_out, gain=gain, fix_attention_again=True))
else:
self.encoder_attns.append(Identity())
self.encoder.append(res_blocks)
# initialize decoder
self.decoder = torch.nn.ModuleList()
self.unet_layers = torch.nn.ModuleList()
self.decoder_attns = torch.nn.ModuleList()
for lidx in range(n_layers):
dim_in = dims[min(-lidx, -1)]
dim_out = dims[-1-lidx]
res_blocks = nn.ModuleList()
unet_skips = nn.ModuleList()
for i in range(num_resnet_blocks[-lidx-1]):
is_first = i == 0
has_unet = is_first or kwargs["skip_all_unets"]
is_last = i == num_resnet_blocks[-lidx-1] - 1
cur_dim = dim_in if is_first else dim_out
if has_unet and is_last and layer_attn[-lidx-1]: # x + residual + unet + layer attn
gain = np.sqrt(1/4)
elif has_unet: # x + residual + unet
gain = np.sqrt(1/3)
elif layer_attn[-lidx-1]: # x + residual + attention
gain = np.sqrt(1/3)
else: # x + residual
gain = np.sqrt(1/2) # Only residual block
block = dec_blk(cur_dim, dim_out, skip_gain=gain)
res_blocks.append(block)
if kwargs["skip_all_unets"] or is_first:
unet_block = Conv2d(
cur_dim, cur_dim, kernel_size=1, conv_clamp=kwargs["conv_clamp"],
norm=nn.InstanceNorm2d(None),
gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"],
gain=gain)
unet_skips.append(unet_block)
else:
unet_skips.append(torch.nn.Identity())
if layer_attn[-lidx-1]:
self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=True, gain=gain))
else:
self.decoder_attns.append(Identity())
self.decoder.append(res_blocks)
self.unet_layers.append(unet_skips)
        self.from_rgb = Conv2d(
            3 + 2 + 2*int(kwargs["use_maskrcnn_mask"]) + self.input_keypoints*len(kwargs["input_keypoint_indices"]),
            dim, 7)
self.to_rgb = Conv2d(dim, 3, 1, activation="linear", conv_clamp=kwargs["conv_clamp"])
self.downsample = Upfirdn2d(down=2, fix_gain=True)
self.upsample = Upfirdn2d(up=2, fix_gain=True)
self.cascade = cascade
if residual:
nn.init.zeros_(self.to_rgb.weight.data)
def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, return_decoder_features=False, **kwargs):
# Downsample for stem
if z is None:
z = self.get_z(img)
with torch.no_grad(): # Forward pass stem
if w is None:
w = self.style_net(z)
img_stem = utils.denormalize_img(img)*255
condition_stem = img_stem * mask + (1-mask)*127
condition_stem = condition_stem.round()
condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True)
condition_stem = condition_stem / 255 *2 - 1
mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float()
maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float()
stem_out = self.coarse_stem(
condition=condition_stem, mask=mask_stem,
maskrcnn_mask=maskrcnn_stem, w=w,
keypoints=keypoints,
return_decoder_features=True)
stem_features = stem_out["decoder_features"]
x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True)
condition = condition*mask + (1-mask) * x_lr
if self.use_maskrcnn_mask:
x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1)
else:
x = torch.cat((condition, mask, 1-mask), dim=1)
if self.input_keypoints:
keypoints = keypoints[:, self.input_keypoint_indices]
one_hot_pose = spatial_embed_keypoints(keypoints, x)
x = torch.cat((x, one_hot_pose), dim=1)
x = self.from_rgb(x)
x, unet_features = self.forward_enc(x, mask, w, stem_features)
x, decoder_features = self.forward_dec(x, mask, w, unet_features)
if self.residual:
x = self.to_rgb(x) + condition
else:
x = self.to_rgb(x)
        out = dict(
            img=condition * mask + (1-mask) * x,  # TODO: reconsider whether the output should be composited with the mask here.
unmasked=x,
x_lowres=[condition]
)
if return_decoder_features:
out["decoder_features"] = decoder_features
return out
def forward_enc(self, x, mask, w, stem_features: List[torch.Tensor]):
unet_features = []
stem_features.reverse()
for i, res_blocks in enumerate(self.encoder):
is_last = i == len(self.encoder) - 1
            res = self.imsize[0] // 2**i
            if str(res) in self.encoder_unet_skips.keys() and self.cascade:
                y = stem_features[self.stem_res_to_layer_idx[res]]
                y = self.encoder_unet_skips[str(res)](y)
                x = y + x
for block in res_blocks:
x = block(x, w=w)
unet_features.append(x)
x = self.encoder_attns[i](x, mask)
if not is_last:
x = self.downsample(x)
return x, unet_features
def forward_dec(self, x, mask, w, unet_features):
features = []
unet_features = iter(reversed(unet_features))
for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)):
is_last = i == len(self.decoder) - 1
for skip, block in zip(unet_skip, res_blocks):
skip_x = next(unet_features)
if not isinstance(skip, torch.nn.Identity):
skip_x = skip(skip_x)
x = x + skip_x
x = block(x, w=w)
x = self.decoder_attns[i](x, mask)
features.append(x)
if not is_last:
x = self.upsample(x)
return x, features
def get_z(self, *args, **kwargs):
return self.coarse_stem.get_z(*args, **kwargs)
@property
def style_net(self):
return self.coarse_stem.style_net
def update_w(self, *args, **kwargs):
self.style_net.update_w(*args, **kwargs)
def get_w(self, z, update_emas):
return self.style_net(z, update_emas=update_emas)
@torch.no_grad()
def sample(self, truncation_value, **kwargs):
if truncation_value is None:
return self.forward(**kwargs)
truncation_value = max(0, truncation_value)
truncation_value = min(truncation_value, 1)
w = self.get_w(self.get_z(kwargs["condition"]), False)
w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value)
return self.forward(**kwargs, w=w)
@torch.no_grad()
def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
if truncation_value is None:
return self.forward(**kwargs)
truncation_value = max(0, truncation_value)
truncation_value = min(truncation_value, 1)
w = self.get_w(self.get_z(kwargs["condition"]), False)
if w_indices is None:
w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w)))
w_centers = self.style_net.w_centers[w_indices].to(w.device)
w = w_centers.to(w.dtype).lerp(w, truncation_value)
return self.forward(**kwargs, w=w)