# What is missing from this implementation:
# 1. Global context in the residual blocks
# 2. Cross attention of conditional information in the residual blocks
#
from functools import partial | |
import tops | |
from tops.config import instantiate | |
import warnings | |
from typing import Iterable, List, Tuple | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch import einsum | |
from einops import rearrange | |
from dp2 import infer, utils | |
from .base import BaseGenerator | |
from sg3_torch_utils.ops import bias_act | |
from dp2.layers import Sequential | |
import torch.nn.functional as F | |
from torchvision.transforms.functional import resize, InterpolationMode | |
from sg3_torch_utils.ops import conv2d_resample, fma, upfirdn2d | |
class Upfirdn2d(torch.nn.Module): | |
def __init__(self, down=1, up=1, fix_gain=True): | |
super().__init__() | |
self.register_buffer("resample_filter", upfirdn2d.setup_filter([1, 3, 3, 1])) | |
fw, fh = upfirdn2d._get_filter_size(self.resample_filter) | |
px0, px1, py0, py1 = upfirdn2d._parse_padding(0) | |
self.down = down | |
self.up = up | |
if up > 1: | |
px0 += (fw + up - 1) // 2 | |
px1 += (fw - up) // 2 | |
py0 += (fh + up - 1) // 2 | |
py1 += (fh - up) // 2 | |
if down > 1: | |
px0 += (fw - down + 1) // 2 | |
px1 += (fw - down) // 2 | |
py0 += (fh - down + 1) // 2 | |
py1 += (fh - down) // 2 | |
self.padding = [px0,px1,py0,py1] | |
self.gain = up**2 if fix_gain else 1 | |
def forward(self, x, *args): | |
if isinstance(x, dict): | |
x = {k: v for k, v in x.items()} | |
x["x"] = upfirdn2d.upfirdn2d(x["x"], self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain) | |
return x | |
x = upfirdn2d.upfirdn2d(x, self.resample_filter, down=self.down, padding=self.padding, up=self.up, gain=self.gain) | |
if len(args) == 0: | |
return x | |
return (x, *args) | |
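# Illustrative usage sketch (not part of the original module): Upfirdn2d(down=2)
# halves the spatial resolution with the [1, 3, 3, 1] low-pass filter, while
# Upfirdn2d(up=2) doubles it (output scaled by up**2 when fix_gain=True):
#   down, up = Upfirdn2d(down=2), Upfirdn2d(up=2)
#   x = torch.randn(1, 4, 16, 16)
#   down(x).shape       # -> torch.Size([1, 4, 8, 8])
#   up(down(x)).shape   # -> torch.Size([1, 4, 16, 16])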
def spatial_embed_keypoints(keypoints: torch.Tensor, x): | |
tops.assert_shape(keypoints, (None, None, 3)) | |
B, N_K, _ = keypoints.shape | |
H, W = x.shape[-2:] | |
x, y, visible = keypoints.chunk(3, dim=2) | |
x = (x * W).round().long().clamp(0, W-1) | |
y = (y * H).round().long().clamp(0, H-1) | |
kp_idx = torch.arange(0, N_K, 1, device=keypoints.device, dtype=torch.long).view(1, -1, 1).repeat(B, 1, 1) | |
pos = (kp_idx*(H*W) + y*W + x + 1) | |
# Offset all by 1 to index invisible keypoints as 0 | |
pos = (pos * visible.round().long()).squeeze(dim=-1) | |
keypoint_spatial = torch.zeros(keypoints.shape[0], N_K*H*W+1, device=keypoints.device, dtype=torch.float32) | |
keypoint_spatial.scatter_(1, pos, 1) | |
keypoint_spatial = keypoint_spatial[:, 1:].view(-1, N_K, H, W) | |
return keypoint_spatial | |
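# Illustrative example (assumed shapes, not from the original code): a single visible
# keypoint at normalized (0.5, 0.5) yields a one-hot map with a 1 at the centre pixel,
# while invisible keypoints (visible=0) produce an all-zero channel:
#   kp = torch.tensor([[[0.5, 0.5, 1.0]]])        # [B=1, N_K=1, (x, y, visible)]
#   grid = torch.zeros(1, 1, 4, 4)                # only the spatial size of `grid` is used
#   spatial_embed_keypoints(kp, grid).nonzero()   # -> indices (0, 0, 2, 2)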
def modulated_conv2d( | |
x, # Input tensor of shape [batch_size, in_channels, in_height, in_width]. | |
weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width]. | |
styles, # Modulation coefficients of shape [batch_size, in_channels]. | |
noise = None, # Optional noise tensor to add to the output activations. | |
up = 1, # Integer upsampling factor. | |
down = 1, # Integer downsampling factor. | |
padding = 0, # Padding with respect to the upsampled image. | |
resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter(). | |
demodulate = True, # Apply weight demodulation? | |
flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d). | |
fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation? | |
): | |
batch_size = x.shape[0] | |
out_channels, in_channels, kh, kw = weight.shape | |
tops.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk] | |
tops.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW] | |
tops.assert_shape(styles, [batch_size, in_channels]) # [NI] | |
# Pre-normalize inputs to avoid FP16 overflow. | |
if x.dtype == torch.float16 and demodulate: | |
weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk | |
styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I | |
# Calculate per-sample weights and demodulation coefficients. | |
w = None | |
dcoefs = None | |
if demodulate or fused_modconv: | |
w = weight.unsqueeze(0) # [NOIkk] | |
w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk] | |
if demodulate: | |
dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] | |
if demodulate and fused_modconv: | |
w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk] | |
# Execute by scaling the activations before and after the convolution. | |
if not fused_modconv: | |
x = x * styles.reshape(batch_size, -1, 1, 1) | |
x = conv2d_resample.conv2d_resample(x=x, w=weight, f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) | |
if demodulate and noise is not None: | |
x = fma.fma(x, dcoefs.reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) | |
elif demodulate: | |
x = x * dcoefs.reshape(batch_size, -1, 1, 1) | |
elif noise is not None: | |
x = x.add_(noise.to(x.dtype)) | |
return x | |
with tops.suppress_tracer_warnings(): # this value will be treated as a constant | |
batch_size = int(batch_size) | |
# Execute as one fused op using grouped convolution. | |
tops.assert_shape(x, [batch_size, in_channels, None, None]) | |
x = x.reshape(1, -1, *x.shape[2:]) | |
w = w.reshape(-1, in_channels, kh, kw) | |
x = conv2d_resample.conv2d_resample(x=x, w=w, f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight) | |
x = x.reshape(batch_size, -1, *x.shape[2:]) | |
if noise is not None: | |
x = x.add_(noise) | |
return x | |
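# Minimal usage sketch of modulated_conv2d (shapes assumed for illustration only):
#   x = torch.randn(2, 8, 16, 16)                       # [N, I, H, W]
#   weight = torch.randn(4, 8, 3, 3)                    # [O, I, kh, kw]
#   styles = torch.randn(2, 8)                          # [N, I] per-sample modulation
#   y = modulated_conv2d(x, weight, styles, padding=1)  # -> [2, 4, 16, 16]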
class Identity(nn.Module): | |
def __init__(self) -> None: | |
super().__init__() | |
def forward(self, x, *args, **kwargs): | |
return x | |
class LayerNorm(nn.Module): | |
def __init__(self, dim, stable=False): | |
super().__init__() | |
self.stable = stable | |
self.g = nn.Parameter(torch.ones(dim)) | |
def forward(self, x): | |
if self.stable: | |
x = x / x.amax(dim=-1, keepdim=True).detach() | |
eps = 1e-5 if x.dtype == torch.float32 else 1e-3 | |
var = torch.var(x, dim=-1, unbiased=False, keepdim=True) | |
mean = torch.mean(x, dim=-1, keepdim=True) | |
return (x - mean) * (var + eps).rsqrt() * self.g | |
class FullyConnectedLayer(torch.nn.Module): | |
def __init__(self, | |
in_features, # Number of input features. | |
out_features, # Number of output features. | |
bias = True, # Apply additive bias before the activation function? | |
activation = 'linear', # Activation function: 'relu', 'lrelu', etc. | |
lr_multiplier = 1, # Learning rate multiplier. | |
bias_init = 0, # Initial value for the additive bias. | |
): | |
super().__init__() | |
self.repr = dict( | |
in_features=in_features, out_features=out_features, bias=bias, | |
activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init) | |
self.activation = activation | |
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) | |
self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None | |
self.weight_gain = lr_multiplier / np.sqrt(in_features) | |
self.bias_gain = lr_multiplier | |
self.in_features = in_features | |
self.out_features = out_features | |
def forward(self, x): | |
w = self.weight * self.weight_gain | |
b = self.bias | |
if b is not None: | |
if self.bias_gain != 1: | |
b = b * self.bias_gain | |
x = F.linear(x, w) | |
x = bias_act.bias_act(x, b, act=self.activation) | |
return x | |
def extra_repr(self) -> str: | |
return ", ".join([f"{key}={item}" for key, item in self.repr.items()]) | |
def checkpoint_fn(fn, *args, **kwargs):
    # Silence checkpoint warnings locally instead of mutating the global warning filter.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
class Conv2d(torch.nn.Module): | |
def __init__( | |
self, | |
in_channels, | |
out_channels, | |
kernel_size=3, | |
activation='lrelu', | |
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. | |
bias=True, | |
norm=None, | |
lr_multiplier=1, | |
bias_init=0, | |
w_dim=None, | |
gradient_checkpoint_norm=False, | |
gain=1, | |
): | |
super().__init__() | |
self.fused_modconv = False | |
if norm == torch.nn.InstanceNorm2d: | |
self.norm = torch.nn.InstanceNorm2d(None) | |
elif isinstance(norm, torch.nn.Module): | |
self.norm = norm | |
elif norm == "fused_modconv": | |
self.fused_modconv = True | |
elif norm: | |
self.norm = torch.nn.InstanceNorm2d(None) | |
elif norm is not None: | |
raise ValueError(f"norm not supported: {norm}") | |
self.activation = activation | |
self.conv_clamp = conv_clamp | |
self.out_channels = out_channels | |
self.in_channels = in_channels | |
self.padding = kernel_size // 2 | |
self.repr = dict( | |
in_channels=in_channels, out_channels=out_channels, | |
kernel_size=kernel_size, | |
activation=activation, conv_clamp=conv_clamp, bias=bias, | |
fused_modconv=self.fused_modconv | |
) | |
self.act_gain = bias_act.activation_funcs[activation].def_gain * gain | |
self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2)) | |
self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size])) | |
self.bias = torch.nn.Parameter(torch.zeros([out_channels])+bias_init) if bias else None | |
self.bias_gain = lr_multiplier | |
if w_dim is not None: | |
if self.fused_modconv: | |
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) | |
else: | |
self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1) | |
self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0) | |
self.gradient_checkpoint_norm = gradient_checkpoint_norm | |
def forward(self, x, w=None, gain=1, **kwargs): | |
if self.fused_modconv: | |
styles = self.affine(w) | |
with torch.cuda.amp.autocast(enabled=False): | |
x = modulated_conv2d(x=x.half(), weight=self.weight.half(), styles=styles.half(), noise=None, | |
padding=self.padding, flip_weight=True, fused_modconv=False).to(x.dtype) | |
else: | |
if hasattr(self, "affine"): | |
gamma = self.affine(w).view(-1, self.in_channels, 1, 1) | |
beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1) | |
                x = fma.fma(x, gamma, beta)
w = self.weight * self.weight_gain | |
x = F.conv2d(input=x, weight=w, padding=self.padding,) | |
if hasattr(self, "norm"): | |
if self.gradient_checkpoint_norm: | |
x = checkpoint_fn(self.norm, x) | |
else: | |
x = self.norm(x) | |
act_gain = self.act_gain * gain | |
act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None | |
b = self.bias * self.bias_gain if self.bias is not None else None | |
x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp) | |
return x | |
def extra_repr(self) -> str: | |
return ", ".join([f"{key}={item}" for key, item in self.repr.items()]) | |
class CrossAttention(nn.Module): | |
def __init__( | |
self, | |
dim, | |
context_dim, | |
dim_head=64, | |
heads=8, | |
norm_context=False, | |
): | |
super().__init__() | |
self.scale = dim_head ** -0.5 | |
self.heads = heads | |
inner_dim = dim_head * heads | |
self.norm = nn.InstanceNorm1d(dim) | |
self.norm_context = nn.InstanceNorm1d(None) if norm_context else Identity() | |
self.to_q = nn.Linear(dim, inner_dim, bias=False) | |
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) | |
self.to_out = nn.Sequential( | |
nn.Linear(inner_dim, dim, bias=False), | |
nn.InstanceNorm1d(None) | |
) | |
def forward(self, x, w): | |
x = self.norm(x) | |
w = self.norm_context(w) | |
q, k, v = (self.to_q(x), *self.to_kv(w).chunk(2, dim = -1)) | |
q = rearrange(q, "b n (h d) -> b h n d", h = self.heads) | |
k = rearrange(k, "b n (h d) -> b h n d", h = self.heads) | |
v = rearrange(v, "b n (h d) -> b h n d", h = self.heads) | |
q = q * self.scale | |
# similarities | |
sim = einsum('b h i d, b h j d -> b h i j', q, k) | |
attn = sim.softmax(dim = -1, dtype = torch.float32) | |
out = einsum('b h i j, b h j d -> b h i d', attn, v) | |
out = rearrange(out, 'b h n d -> b n (h d)') | |
return self.to_out(out) | |
class SG2ResidualBlock(torch.nn.Module): | |
def __init__( | |
self, | |
in_channels, # Number of input channels, 0 = first block. | |
out_channels, # Number of output channels. | |
conv_clamp=None, # Clamp the output of convolution layers to +-X, None = disable clamping. | |
skip_gain=np.sqrt(.5), | |
cross_attention: bool = False, | |
cross_attention_len: int = None, | |
use_adain: bool = True, | |
**layer_kwargs, # Arguments for conv layer. | |
): | |
super().__init__() | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
w_dim = layer_kwargs.pop("w_dim") if "w_dim" in layer_kwargs else None | |
if use_adain: | |
layer_kwargs["w_dim"] = w_dim | |
self.conv0 = Conv2d(in_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs) | |
self.conv1 = Conv2d(out_channels, out_channels, conv_clamp=conv_clamp, **layer_kwargs, gain=skip_gain) | |
self.skip = Conv2d(in_channels, out_channels, kernel_size=1, bias=False, gain=skip_gain) | |
if cross_attention and w_dim is not None: | |
self.cross_attention_len = cross_attention_len | |
            # CrossAttention does not take a gain argument; the residual gain is
            # already applied through conv1 and the skip connection above.
            self.cross_attn = CrossAttention(
                dim=out_channels, context_dim=w_dim // self.cross_attention_len)
def forward(self, x, w=None, **layer_kwargs): | |
y = self.skip(x) | |
x = self.conv0(x, w, **layer_kwargs) | |
x = self.conv1(x, w, **layer_kwargs) | |
if hasattr(self, "cross_attn"): | |
h = x.shape[-2] | |
x = rearrange(x, "b c h w -> b (h w) c") | |
w = rearrange(w, "b (n c) -> b n c", n=self.cross_attention_len) | |
x = self.cross_attn(x, w=w) + x | |
x = rearrange(x, "b (h w) c -> b c h w", h=h) | |
return y + x | |
def default(val, d): | |
if val is not None: | |
return val | |
return d() if callable(d) else d | |
def cast_tuple(val, length=None): | |
if isinstance(val, Iterable) and not isinstance(val, str): | |
val = tuple(val) | |
output = val if isinstance(val, tuple) else ((val,) * default(length, 1)) | |
if length is not None: | |
assert len(output) == length, (output, length) | |
return output | |
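# Examples (illustrative): cast_tuple broadcasts scalars and validates per-layer settings:
#   cast_tuple(True, 3)    # -> (True, True, True)
#   cast_tuple([1, 2], 2)  # -> (1, 2)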
class Attention(nn.Module):
    # A version of Multi-Query Attention: queries are computed per head, while a
    # single key/value head is shared across all heads.
    # Introduced in "Fast Transformer Decoding: One Write-Head is All You Need".
    # Ablated in: https://arxiv.org/pdf/2203.07814.pdf
    # and https://arxiv.org/pdf/2204.02311.pdf
def __init__(self, dim, norm, attn_fix_gain, gradient_checkpoint, dim_head=64, heads=8, cosine_sim_attn=False, fix_attention_again=False, gain=None): | |
super().__init__() | |
self.scale = dim_head**-0.5 if not cosine_sim_attn else 1.0 | |
self.cosine_sim_attn = cosine_sim_attn | |
self.cosine_sim_scale = 16 if cosine_sim_attn else 1 | |
self.gradient_checkpoint = gradient_checkpoint | |
self.heads = heads | |
self.dim = dim | |
self.fix_attention_again = fix_attention_again | |
inner_dim = dim_head * heads | |
if norm == "LN": | |
self.norm = LayerNorm(dim) | |
elif norm == "IN": | |
self.norm = nn.InstanceNorm1d(dim) | |
elif norm is None: | |
self.norm = nn.Identity() | |
else: | |
raise ValueError(f"Norm not supported: {norm}") | |
self.to_q = FullyConnectedLayer(dim, inner_dim, bias=False) | |
self.to_kv = FullyConnectedLayer(dim, dim_head*2, bias=False) | |
self.to_out = nn.Sequential( | |
FullyConnectedLayer(inner_dim, dim, bias=False), | |
LayerNorm(dim) if norm == "LN" else nn.InstanceNorm1d(dim) | |
) | |
if fix_attention_again: | |
assert gain is not None | |
self.gain = gain | |
else: | |
self.gain = np.sqrt(.5) if attn_fix_gain else 1 | |
def run_function(self, x, attn_bias): | |
b, c, h, w = x.shape | |
x = rearrange(x, "b c h w -> b (h w) c") | |
in_ = x | |
b, n, device = *x.shape[:2], x.device | |
x = self.norm(x) | |
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) | |
q = rearrange(q, "b n (h d) -> b h n d", h=self.heads) | |
q = q * self.scale | |
# calculate query / key similarities | |
sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale | |
        if attn_bias is not None:
            attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)")
            sim = sim + attn_bias
attn = sim.softmax(dim=-1) | |
out = einsum("b h i j, b j d -> b h i d", attn, v) | |
out = rearrange(out, "b h n d -> b n (h d)") | |
if self.fix_attention_again: | |
out = self.to_out(out)*self.gain + in_ | |
else: | |
out = (self.to_out(out) + in_) * self.gain | |
out = rearrange(out, "b (h w) c -> b c h w", h=h) | |
return out | |
def forward(self, x, *args, attn_bias=None, **kwargs): | |
if self.gradient_checkpoint: | |
return checkpoint_fn(self.run_function, x, attn_bias) | |
return self.run_function(x, attn_bias) | |
def get_attention(self, x, attn_bias=None): | |
b, c, h, w = x.shape | |
x = rearrange(x, "b c h w -> b (h w) c") | |
in_ = x | |
b, n, device = *x.shape[:2], x.device | |
x = self.norm(x) | |
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) | |
q = rearrange(q, "b n (h d) -> b h n d", h=self.heads) | |
q = q * self.scale | |
# calculate query / key similarities | |
sim = einsum("b h i d, b j d -> b h i j", q, k) * self.cosine_sim_scale | |
        if attn_bias is not None:
            attn_bias = rearrange(attn_bias, "n c h w -> n c 1 (h w)")
            sim = sim + attn_bias
attn = sim.softmax(dim=-1) | |
return attn, None | |
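# Shape sketch (illustrative, assumed arguments): with dim=64, heads=8, dim_head=64,
# queries are per-head while a single key/value pair is shared by every head:
#   attn = Attention(dim=64, norm="LN", attn_fix_gain=True, gradient_checkpoint=False)
#   x = torch.randn(2, 64, 8, 8)
#   attn(x).shape   # -> torch.Size([2, 64, 8, 8]); internally q is [2, 8, 64, 64],
#                   #    while the shared k and v are [2, 64, 64] each.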
class BiasedAttention(Attention): | |
def __init__(self, *args, head_wise: bool=True, **kwargs): | |
super().__init__(*args, **kwargs) | |
out_ch = self.heads if head_wise else 1 | |
self.conv = Conv2d(self.dim+2, out_ch, activation="linear", kernel_size=3, bias_init=0) | |
nn.init.zeros_(self.conv.weight.data) | |
def forward(self, x, mask): | |
mask = resize(mask, size=x.shape[-2:]) | |
bias = self.conv(torch.cat((x, mask, 1-mask), dim=1)) | |
return super().forward(x=x, attn_bias=bias) | |
def get_attention(self, x, mask): | |
mask = resize(mask, size=x.shape[-2:]) | |
bias = self.conv(torch.cat((x, mask, 1-mask), dim=1)) | |
return super().get_attention(x, bias)[0], bias | |
class UNet(BaseGenerator): | |
def __init__( | |
self, | |
im_channels: int, | |
dim: int, | |
dim_mults: tuple, | |
num_resnet_blocks, # Number of resnet blocks per resolution | |
n_middle_blocks: int, | |
z_channels: int, | |
conv_clamp: int, | |
layer_attn, | |
w_dim: int, | |
norm_enc: bool, | |
norm_dec: str, | |
stylenet: nn.Module, | |
enc_style: bool, # Toggle style injection in encoder | |
use_maskrcnn_mask: bool, | |
skip_all_unets: bool, | |
fix_resize:bool, | |
comodulate: bool, | |
comod_net: nn.Module, | |
lr_comod: float, | |
dec_style: bool, | |
input_keypoints: bool, | |
n_keypoints: int, | |
input_keypoint_indices: Tuple[int], | |
use_adain: bool, | |
cross_attention: bool, | |
cross_attention_len: int, | |
gradient_checkpoint_norm: bool, | |
attn_cls: partial, | |
mask_out_train: bool, | |
fix_gain_again: bool, | |
) -> None: | |
super().__init__(z_channels) | |
self.enc_style = enc_style | |
self.n_keypoints = n_keypoints | |
self.input_keypoint_indices = list(input_keypoint_indices) | |
self.input_keypoints = input_keypoints | |
self.mask_out_train = mask_out_train | |
n_layers = len(dim_mults) | |
self.n_layers = n_layers | |
layer_attn = cast_tuple(layer_attn, n_layers) | |
num_resnet_blocks = cast_tuple(num_resnet_blocks, n_layers) | |
self._cnum = dim | |
self._image_channels = im_channels | |
self._z_channels = z_channels | |
encoder_layers = [] | |
condition_ch = im_channels | |
self.from_rgb = Conv2d( | |
            condition_ch + 2 + 2*int(use_maskrcnn_mask) + self.input_keypoints*len(input_keypoint_indices),
            dim, 7)
self.use_maskrcnn_mask = use_maskrcnn_mask | |
self.skip_all_unets = skip_all_unets | |
dims = [dim*m for m in dim_mults] | |
enc_blk = partial( | |
SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_enc, | |
use_adain=use_adain and self.enc_style, | |
w_dim=w_dim, | |
cross_attention=cross_attention, | |
cross_attention_len=cross_attention_len, | |
gradient_checkpoint_norm=gradient_checkpoint_norm | |
) | |
dec_blk = partial( | |
SG2ResidualBlock, conv_clamp=conv_clamp, norm=norm_dec, | |
use_adain=use_adain and dec_style, | |
w_dim=w_dim, | |
cross_attention=cross_attention, | |
cross_attention_len=cross_attention_len, | |
gradient_checkpoint_norm=gradient_checkpoint_norm | |
) | |
        # Up/down sampling is currently done with low-pass FIR resampling (Upfirdn2d).
        # This could be simplified by replacing it with strided (transposed) convolutions.
self.encoder_attns = nn.ModuleList() | |
for lidx in range(n_layers): | |
gain = np.sqrt(1/3) if layer_attn[lidx] and fix_gain_again else np.sqrt(.5) | |
dim_in = dims[lidx] | |
dim_out = dims[min(lidx+1, n_layers-1)] | |
res_blocks = nn.ModuleList() | |
for i in range(num_resnet_blocks[lidx]): | |
is_last = num_resnet_blocks[lidx] - 1 == i | |
cur_dim = dim_out if is_last else dim_in | |
block = enc_blk(dim_in, cur_dim, skip_gain=gain) | |
res_blocks.append(block) | |
if layer_attn[lidx]: | |
self.encoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain)) | |
else: | |
self.encoder_attns.append(Identity()) | |
encoder_layers.append(res_blocks) | |
self.encoder = torch.nn.ModuleList(encoder_layers) | |
# initialize decoder | |
decoder_layers = [] | |
self.unet_layers = torch.nn.ModuleList() | |
self.decoder_attns = torch.nn.ModuleList() | |
for lidx in range(n_layers): | |
dim_in = dims[min(-lidx, -1)] | |
dim_out = dims[-1-lidx] | |
res_blocks = nn.ModuleList() | |
unet_skips = nn.ModuleList() | |
for i in range(num_resnet_blocks[-lidx-1]): | |
is_first = i == 0 | |
has_unet = is_first or skip_all_unets | |
is_last = i == num_resnet_blocks[-lidx-1] - 1 | |
cur_dim = dim_in if is_first else dim_out | |
if has_unet and is_last and layer_attn[-lidx-1] and fix_gain_again: # x + residual + unet + layer attn | |
gain = np.sqrt(1/4) | |
elif has_unet: # x + residual + unet | |
gain = np.sqrt(1/3) | |
elif layer_attn[-lidx-1] and fix_gain_again: # x + residual + attention | |
gain = np.sqrt(1/3) | |
else: # x + residual | |
gain = np.sqrt(1/2) # Only residual block | |
block = dec_blk(cur_dim, dim_out, skip_gain=gain) | |
res_blocks.append(block) | |
if has_unet: | |
unet_block = Conv2d( | |
cur_dim, cur_dim, kernel_size=1, conv_clamp=conv_clamp, | |
norm=nn.InstanceNorm2d(None), | |
gradient_checkpoint_norm=gradient_checkpoint_norm, | |
gain=gain) | |
unet_skips.append(unet_block) | |
else: | |
unet_skips.append(torch.nn.Identity()) | |
if layer_attn[-lidx-1]: | |
self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=fix_gain_again, gain=gain)) | |
else: | |
self.decoder_attns.append(Identity()) | |
decoder_layers.append(res_blocks) | |
self.unet_layers.append(unet_skips) | |
middle_blocks = [] | |
for i in range(n_middle_blocks): | |
block = dec_blk(dims[-1], dims[-1]) | |
middle_blocks.append(block) | |
if n_middle_blocks != 0: | |
self.middle_blocks = Sequential(*middle_blocks) | |
self.decoder = torch.nn.ModuleList(decoder_layers) | |
self.to_rgb = Conv2d(dim, im_channels, 1, activation="linear", conv_clamp=conv_clamp) | |
self.stylenet = stylenet | |
self.downsample = Upfirdn2d(down=2, fix_gain=fix_resize) | |
self.upsample = Upfirdn2d(up=2, fix_gain=fix_resize) | |
self.comodulate = comodulate | |
if comodulate: | |
assert not self.enc_style | |
self.to_y = nn.Sequential( | |
Conv2d(dims[-1], dims[-1], lr_multiplier=lr_comod, gradient_checkpoint_norm=gradient_checkpoint_norm), | |
nn.AdaptiveAvgPool2d(1), | |
nn.Flatten(), | |
FullyConnectedLayer(dims[-1], 512, activation="lrelu", lr_multiplier=lr_comod) | |
) | |
self.comod_net = comod_net | |
def forward(self, condition, mask, maskrcnn_mask=None, z=None, w=None, update_emas=False, keypoints=None, return_decoder_features=False, **kwargs): | |
if z is None: | |
z = self.get_z(condition) | |
if w is None: | |
w = self.stylenet(z, update_emas=update_emas) | |
if self.use_maskrcnn_mask: | |
x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1) | |
else: | |
x = torch.cat((condition, mask, 1-mask), dim=1) | |
if self.input_keypoints: | |
keypoints = keypoints[:, self.input_keypoint_indices] | |
one_hot_pose = spatial_embed_keypoints(keypoints, x) | |
x = torch.cat((x, one_hot_pose), dim=1) | |
x = self.from_rgb(x) | |
x, unet_features = self.forward_enc(x, mask, w) | |
x, decoder_features = self.forward_dec(x, mask, w, unet_features) | |
x = self.to_rgb(x) | |
unmasked = x | |
if self.mask_out_train: | |
x = mask * condition + (1-mask) * x | |
out = dict(img=x, unmasked=unmasked) | |
if return_decoder_features: | |
out["decoder_features"] = decoder_features | |
return out | |
def forward_enc(self, x, mask, w): | |
unet_features = [] | |
for i, res_blocks in enumerate(self.encoder): | |
is_last = i == len(self.encoder) - 1 | |
for block in res_blocks: | |
x = block(x, w=w) | |
unet_features.append(x) | |
x = self.encoder_attns[i](x, mask=mask) | |
if not is_last: | |
x = self.downsample(x) | |
if self.comodulate: | |
y = self.to_y(x) | |
y = torch.cat((w, y), dim=-1) | |
w = self.comod_net(y) | |
return x, unet_features | |
def forward_dec(self, x, mask, w, unet_features): | |
if hasattr(self, "middle_blocks"): | |
x = self.middle_blocks(x, w=w) | |
features = [] | |
unet_features = iter(reversed(unet_features)) | |
for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)): | |
is_last = i == len(self.decoder) - 1 | |
for skip, block in zip(unet_skip, res_blocks): | |
skip_x = next(unet_features) | |
if not isinstance(skip, torch.nn.Identity): | |
skip_x = skip(skip_x) | |
x = x + skip_x | |
x = block(x, w=w) | |
x = self.decoder_attns[i](x, mask=mask) | |
features.append(x) | |
if not is_last: | |
x = self.upsample(x) | |
return x, features | |
def get_w(self, z, update_emas): | |
return self.stylenet(z, update_emas=update_emas) | |
def sample(self, truncation_value, **kwargs): | |
if truncation_value is None: | |
return self.forward(**kwargs) | |
truncation_value = max(0, truncation_value) | |
truncation_value = min(truncation_value, 1) | |
w = self.get_w(self.get_z(kwargs["condition"]), False) | |
w = self.stylenet.w_avg.to(w.dtype).lerp(w, truncation_value) | |
return self.forward(**kwargs, w=w) | |
def update_w(self, *args, **kwargs): | |
self.style_net.update_w(*args, **kwargs) | |
    @property
    def style_net(self):
        return self.stylenet
def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs): | |
if truncation_value is None: | |
return self.forward(**kwargs) | |
truncation_value = max(0, truncation_value) | |
truncation_value = min(truncation_value, 1) | |
w = self.get_w(self.get_z(kwargs["condition"]), False) | |
if w_indices is None: | |
w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w))) | |
w_centers = self.style_net.w_centers[w_indices].to(w.device) | |
w = w_centers.to(w.dtype).lerp(w, truncation_value) | |
return self.forward(**kwargs, w=w) | |
def get_stem_unet_kwargs(cfg): | |
if "stem_cfg" in cfg.generator: # If the stem has another stem, recursively apply get_stem_unet_kwargs | |
return get_stem_unet_kwargs(cfg.generator.stem_cfg) | |
return dict(cfg.generator) | |
class GrowingUnet(BaseGenerator): | |
def __init__( | |
self, | |
coarse_stem_cfg: str, # This can be a coarse generator or None | |
sr_cfg: str, # Can be a previous progressive u-net, Unet or None | |
residual: bool, | |
new_dataset: bool, # The "new dataset" creates condition first -> resizes | |
**unet_kwargs): | |
kwargs = dict() | |
if coarse_stem_cfg is not None: | |
coarse_stem_cfg = utils.load_config(coarse_stem_cfg) | |
kwargs = get_stem_unet_kwargs(coarse_stem_cfg) | |
if sr_cfg is not None: | |
sr_cfg = utils.load_config(sr_cfg) | |
sr_stem_unet_kwargs = get_stem_unet_kwargs(sr_cfg) | |
kwargs.update(sr_stem_unet_kwargs) | |
kwargs.update(unet_kwargs) | |
kwargs["stylenet"] = None | |
kwargs.pop("_target_") | |
if "sr_cfg" in kwargs: # Unet kwargs are inherited, do not pass this to the new u-net | |
del kwargs["sr_cfg"] | |
if "coarse_stem_cfg" in kwargs: | |
del kwargs["coarse_stem_cfg"] | |
super().__init__(z_channels=kwargs["z_channels"]) | |
if coarse_stem_cfg is not None: | |
z_channels = coarse_stem_cfg.generator.z_channels | |
super().__init__(z_channels) | |
self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval() | |
self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize) | |
utils.set_requires_grad(self.coarse_stem, False) | |
else: | |
assert not residual | |
if sr_cfg is not None: | |
self.sr_stem = infer.build_trained_generator(sr_cfg, map_location="cpu").eval() | |
del self.sr_stem.from_rgb | |
del self.sr_stem.to_rgb | |
if hasattr(self.sr_stem, "coarse_stem"): | |
del self.sr_stem.coarse_stem | |
if isinstance(self.sr_stem, UNet): | |
del self.sr_stem.encoder[0][0] # Delete first residual block | |
del self.sr_stem.decoder[-1][-1] # Delete last residual block | |
else: | |
assert isinstance(self.sr_stem, GrowingUnet) | |
del self.sr_stem.unet.encoder[0][0] # Delete first residual block | |
del self.sr_stem.unet.decoder[-1][-1] # Delete last residual block | |
utils.set_requires_grad(self.sr_stem, False) | |
args = kwargs.pop("_args_") | |
if hasattr(self, "sr_stem"): # Growing the SR stem - Add a new layer to match sr | |
n_layers = len(kwargs["dim_mults"]) | |
dim_mult = sr_stem_unet_kwargs["dim"] / (kwargs["dim"] * max(kwargs["dim_mults"])) | |
kwargs["dim_mults"] = [*kwargs["dim_mults"], int(dim_mult)] | |
kwargs["layer_attn"] = [*cast_tuple(kwargs["layer_attn"], n_layers), False] | |
kwargs["num_resnet_blocks"] = [*cast_tuple(kwargs["num_resnet_blocks"], n_layers), 1] | |
self.unet = UNet( | |
*args, | |
**kwargs | |
) | |
self.from_rgb = self.unet.from_rgb | |
self.to_rgb = self.unet.to_rgb | |
self.residual = residual | |
self.new_dataset = new_dataset | |
if residual: | |
nn.init.zeros_(self.to_rgb.weight.data) | |
del self.unet.from_rgb, self.unet.to_rgb | |
def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, **kwargs): | |
# Downsample for stem | |
if z is None: | |
z = self.get_z(img) | |
if w is None: | |
w = self.style_net(z) | |
if hasattr(self, "coarse_stem"): | |
with torch.no_grad(): | |
if self.new_dataset: | |
img_stem = utils.denormalize_img(img)*255 | |
condition_stem = img_stem * mask + (1-mask)*127 | |
condition_stem = condition_stem.round() | |
condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True) | |
condition_stem = condition_stem / 255 *2 - 1 | |
mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float() | |
maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float() | |
else: | |
mask_stem = (resize(mask, self.coarse_stem.imsize, antialias=True) > .99).float() | |
maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, antialias=True) > .5).float() | |
img_stem = utils.denormalize_img(img)*255 | |
img_stem = resize(img_stem, self.coarse_stem.imsize, antialias=True).round() | |
img_stem = img_stem / 255 * 2 - 1 | |
condition_stem = img_stem * mask_stem | |
stem_out = self.coarse_stem( | |
condition=condition_stem, mask=mask_stem, | |
maskrcnn_mask=maskrcnn_stem, w=w, | |
keypoints=keypoints) | |
x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True) | |
condition = condition*mask + (1-mask) * x_lr | |
if self.unet.use_maskrcnn_mask: | |
x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1) | |
else: | |
x = torch.cat((condition, mask, 1-mask), dim=1) | |
if self.unet.input_keypoints: | |
keypoints = keypoints[:, self.unet.input_keypoint_indices] | |
one_hot_pose = spatial_embed_keypoints(keypoints, x) | |
x = torch.cat((x, one_hot_pose), dim=1) | |
x = self.from_rgb(x) | |
x, unet_features = self.forward_enc(x, mask, w) | |
x = self.forward_dec(x, mask, w, unet_features) | |
if self.residual: | |
x = self.to_rgb(x) + condition | |
else: | |
x = self.to_rgb(x) | |
return dict( | |
img=condition * mask + (1-mask) * x, | |
unmasked=x, | |
x_lowres=[condition] | |
) | |
def forward_enc(self, x, mask, w): | |
x, unet_features = self.unet.forward_enc(x, mask, w) | |
if hasattr(self, "sr_stem"): | |
x, unet_features_stem = self.sr_stem.forward_enc(x, mask, w) | |
else: | |
unet_features_stem = None | |
return x, [unet_features, unet_features_stem] | |
def forward_dec(self, x, mask, w, unet_features): | |
unet_features, unet_features_stem = unet_features | |
if hasattr(self, "sr_stem"): | |
x = self.sr_stem.forward_dec(x, mask, w, unet_features_stem) | |
x, unet_features = self.unet.forward_dec(x, mask, w, unet_features) | |
return x | |
def get_z(self, *args, **kwargs): | |
if hasattr(self, "coarse_stem"): | |
return self.coarse_stem.get_z(*args, **kwargs) | |
if hasattr(self, "sr_stem"): | |
return self.sr_stem.get_z(*args, **kwargs) | |
raise AttributeError() | |
    @property
    def style_net(self):
if hasattr(self, "coarse_stem"): | |
return self.coarse_stem.style_net | |
if hasattr(self, "sr_stem"): | |
return self.sr_stem.style_net | |
raise AttributeError() | |
def update_w(self, *args, **kwargs): | |
self.style_net.update_w(*args, **kwargs) | |
def get_w(self, z, update_emas): | |
return self.style_net(z, update_emas=update_emas) | |
def sample(self, truncation_value, **kwargs): | |
if truncation_value is None: | |
return self.forward(**kwargs) | |
truncation_value = max(0, truncation_value) | |
truncation_value = min(truncation_value, 1) | |
w = self.get_w(self.get_z(kwargs["condition"]), False) | |
w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value) | |
return self.forward(**kwargs, w=w) | |
def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs): | |
if truncation_value is None: | |
return self.forward(**kwargs) | |
truncation_value = max(0, truncation_value) | |
truncation_value = min(truncation_value, 1) | |
w = self.get_w(self.get_z(kwargs["condition"]), False) | |
if w_indices is None: | |
w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w))) | |
w_centers = self.style_net.w_centers[w_indices].to(w.device) | |
w = w_centers.to(w.dtype).lerp(w, truncation_value) | |
return self.forward(**kwargs, w=w) | |
class CascadedUnet(BaseGenerator): | |
def __init__( | |
self, | |
coarse_stem_cfg: str, # This can be a coarse generator or None | |
residual: bool, | |
new_dataset: bool, # The "new dataset" creates condition first -> resizes | |
imsize: tuple, | |
cascade:bool, | |
**unet_kwargs): | |
kwargs = dict() | |
coarse_stem_cfg = utils.load_config(coarse_stem_cfg) | |
kwargs = get_stem_unet_kwargs(coarse_stem_cfg) | |
kwargs.update(unet_kwargs) | |
super().__init__(z_channels=kwargs["z_channels"]) | |
self.input_keypoints = kwargs["input_keypoints"] | |
self.input_keypoint_indices = kwargs["input_keypoint_indices"] | |
self.use_maskrcnn_mask = kwargs["use_maskrcnn_mask"] | |
self.imsize = imsize | |
self.residual = residual | |
self.new_dataset = new_dataset | |
# Setup coarse stem | |
stem_dims = [m*coarse_stem_cfg.generator.dim for m in coarse_stem_cfg.generator.dim_mults] | |
self.coarse_stem = infer.build_trained_generator(coarse_stem_cfg, map_location="cpu").eval() | |
self.coarse_stem.imsize = tuple(coarse_stem_cfg.data.imsize) | |
utils.set_requires_grad(self.coarse_stem, False) | |
        self.stem_res_to_layer_idx = {
            self.coarse_stem.imsize[0] // 2**i: i
            for i in range(len(stem_dims))
        }
dim = kwargs["dim"] | |
dim_mults = kwargs["dim_mults"] | |
n_layers = len(dim_mults) | |
dims = [dim*s for s in dim_mults] | |
layer_attn = cast_tuple(kwargs["layer_attn"], n_layers) | |
num_resnet_blocks = cast_tuple(kwargs["num_resnet_blocks"], n_layers) | |
attn_cls = kwargs["attn_cls"] | |
if not isinstance(attn_cls, partial): | |
attn_cls = instantiate(attn_cls) | |
dec_blk = partial( | |
SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_dec"], | |
use_adain=kwargs["use_adain"] and kwargs["dec_style"], | |
w_dim=kwargs["w_dim"], | |
cross_attention=kwargs["cross_attention"], | |
cross_attention_len=kwargs["cross_attention_len"], | |
gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"] | |
) | |
enc_blk = partial( | |
SG2ResidualBlock, conv_clamp=kwargs["conv_clamp"], norm=kwargs["norm_enc"], | |
use_adain=kwargs["use_adain"] and kwargs["enc_style"], | |
w_dim=kwargs["w_dim"], | |
cross_attention=kwargs["cross_attention"], | |
cross_attention_len=kwargs["cross_attention_len"], | |
gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"] | |
) | |
        # Up/down sampling is currently done with low-pass FIR resampling (Upfirdn2d).
        # This could be simplified by replacing it with strided (transposed) convolutions.
self.encoder_attns = nn.ModuleList() | |
self.encoder_unet_skips = nn.ModuleDict() | |
self.encoder = nn.ModuleList() | |
for lidx in range(n_layers): | |
            has_stem_feature = imsize[0] // 2**lidx in self.stem_res_to_layer_idx and cascade
            next_layer_has_stem_features = lidx + 1 < n_layers and imsize[0] // 2**(lidx+1) in self.stem_res_to_layer_idx and cascade
dim_in = dims[lidx] | |
dim_out = dims[min(lidx+1, n_layers-1)] | |
res_blocks = nn.ModuleList() | |
if has_stem_feature: | |
prev_layer_has_attention = lidx != 0 and layer_attn[lidx-1] | |
                stem_lidx = self.stem_res_to_layer_idx[imsize[0] // 2**lidx]
                self.encoder_unet_skips.add_module(
                    str(imsize[0] // 2**lidx),
Conv2d( | |
stem_dims[stem_lidx], dim_in, kernel_size=1, | |
conv_clamp=kwargs["conv_clamp"], | |
norm=nn.InstanceNorm2d(None), | |
gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"], | |
gain=np.sqrt(1/4) if prev_layer_has_attention else np.sqrt(1/3) # This + previous residual + attention | |
) | |
) | |
for i in range(num_resnet_blocks[lidx]): | |
is_last = num_resnet_blocks[lidx] - 1 == i | |
cur_dim = dim_out if is_last else dim_in | |
if not is_last: | |
gain = np.sqrt(.5) | |
elif next_layer_has_stem_features and layer_attn[lidx]: | |
gain = np.sqrt(1/4) | |
elif layer_attn[lidx] or next_layer_has_stem_features: | |
gain = np.sqrt(1/3) | |
else: | |
gain = np.sqrt(.5) | |
block = enc_blk(dim_in, cur_dim, skip_gain=gain) | |
res_blocks.append(block) | |
if layer_attn[lidx]: | |
self.encoder_attns.append(attn_cls(dim=dim_out, gain=gain, fix_attention_again=True)) | |
else: | |
self.encoder_attns.append(Identity()) | |
self.encoder.append(res_blocks) | |
# initialize decoder | |
self.decoder = torch.nn.ModuleList() | |
self.unet_layers = torch.nn.ModuleList() | |
self.decoder_attns = torch.nn.ModuleList() | |
for lidx in range(n_layers): | |
dim_in = dims[min(-lidx, -1)] | |
dim_out = dims[-1-lidx] | |
res_blocks = nn.ModuleList() | |
unet_skips = nn.ModuleList() | |
for i in range(num_resnet_blocks[-lidx-1]): | |
is_first = i == 0 | |
has_unet = is_first or kwargs["skip_all_unets"] | |
is_last = i == num_resnet_blocks[-lidx-1] - 1 | |
cur_dim = dim_in if is_first else dim_out | |
if has_unet and is_last and layer_attn[-lidx-1]: # x + residual + unet + layer attn | |
gain = np.sqrt(1/4) | |
elif has_unet: # x + residual + unet | |
gain = np.sqrt(1/3) | |
elif layer_attn[-lidx-1]: # x + residual + attention | |
gain = np.sqrt(1/3) | |
else: # x + residual | |
gain = np.sqrt(1/2) # Only residual block | |
block = dec_blk(cur_dim, dim_out, skip_gain=gain) | |
res_blocks.append(block) | |
if kwargs["skip_all_unets"] or is_first: | |
unet_block = Conv2d( | |
cur_dim, cur_dim, kernel_size=1, conv_clamp=kwargs["conv_clamp"], | |
norm=nn.InstanceNorm2d(None), | |
gradient_checkpoint_norm=kwargs["gradient_checkpoint_norm"], | |
gain=gain) | |
unet_skips.append(unet_block) | |
else: | |
unet_skips.append(torch.nn.Identity()) | |
if layer_attn[-lidx-1]: | |
self.decoder_attns.append(attn_cls(dim=dim_out, fix_attention_again=True, gain=gain)) | |
else: | |
self.decoder_attns.append(Identity()) | |
self.decoder.append(res_blocks) | |
self.unet_layers.append(unet_skips) | |
        self.from_rgb = Conv2d(
            3 + 2 + 2*int(kwargs["use_maskrcnn_mask"]) + self.input_keypoints*len(kwargs["input_keypoint_indices"]),
            dim, 7)
self.to_rgb = Conv2d(dim, 3, 1, activation="linear", conv_clamp=kwargs["conv_clamp"]) | |
self.downsample = Upfirdn2d(down=2, fix_gain=True) | |
self.upsample = Upfirdn2d(up=2, fix_gain=True) | |
self.cascade = cascade | |
if residual: | |
nn.init.zeros_(self.to_rgb.weight.data) | |
def forward(self, condition, img, mask, maskrcnn_mask=None, z=None, w=None, keypoints=None, return_decoder_features=False, **kwargs): | |
# Downsample for stem | |
if z is None: | |
z = self.get_z(img) | |
with torch.no_grad(): # Forward pass stem | |
if w is None: | |
w = self.style_net(z) | |
img_stem = utils.denormalize_img(img)*255 | |
condition_stem = img_stem * mask + (1-mask)*127 | |
condition_stem = condition_stem.round() | |
condition_stem = resize(condition_stem, self.coarse_stem.imsize, antialias=True) | |
condition_stem = condition_stem / 255 *2 - 1 | |
mask_stem = (torch.nn.functional.adaptive_max_pool2d(1 - mask, output_size=self.coarse_stem.imsize) > 0).logical_not().float() | |
maskrcnn_stem = (resize(maskrcnn_mask, self.coarse_stem.imsize, interpolation=InterpolationMode.NEAREST) > 0).float() | |
stem_out = self.coarse_stem( | |
condition=condition_stem, mask=mask_stem, | |
maskrcnn_mask=maskrcnn_stem, w=w, | |
keypoints=keypoints, | |
return_decoder_features=True) | |
stem_features = stem_out["decoder_features"] | |
x_lr = resize(stem_out["img"], condition.shape[-2:], antialias=True) | |
condition = condition*mask + (1-mask) * x_lr | |
if self.use_maskrcnn_mask: | |
x = torch.cat((condition, mask, 1-mask, maskrcnn_mask, 1-maskrcnn_mask), dim=1) | |
else: | |
x = torch.cat((condition, mask, 1-mask), dim=1) | |
if self.input_keypoints: | |
keypoints = keypoints[:, self.input_keypoint_indices] | |
one_hot_pose = spatial_embed_keypoints(keypoints, x) | |
x = torch.cat((x, one_hot_pose), dim=1) | |
x = self.from_rgb(x) | |
x, unet_features = self.forward_enc(x, mask, w, stem_features) | |
x, decoder_features = self.forward_dec(x, mask, w, unet_features) | |
if self.residual: | |
x = self.to_rgb(x) + condition | |
else: | |
x = self.to_rgb(x) | |
        out = dict(
img=condition * mask + (1-mask) * x, # TODO: Probably do not want masked here... or ?? | |
unmasked=x, | |
x_lowres=[condition] | |
) | |
if return_decoder_features: | |
out["decoder_features"] = decoder_features | |
return out | |
def forward_enc(self, x, mask, w, stem_features: List[torch.Tensor]): | |
unet_features = [] | |
stem_features.reverse() | |
for i, res_blocks in enumerate(self.encoder): | |
is_last = i == len(self.encoder) - 1 | |
            res = self.imsize[0] // 2**i
            if str(res) in self.encoder_unet_skips.keys() and self.cascade:
                y = stem_features[self.stem_res_to_layer_idx[res]]
                y = self.encoder_unet_skips[str(res)](y)
x = y + x | |
for block in res_blocks: | |
x = block(x, w=w) | |
unet_features.append(x) | |
x = self.encoder_attns[i](x, mask) | |
if not is_last: | |
x = self.downsample(x) | |
return x, unet_features | |
def forward_dec(self, x, mask, w, unet_features): | |
features = [] | |
unet_features = iter(reversed(unet_features)) | |
for i, (unet_skip, res_blocks) in enumerate(zip(self.unet_layers, self.decoder)): | |
is_last = i == len(self.decoder) - 1 | |
for skip, block in zip(unet_skip, res_blocks): | |
skip_x = next(unet_features) | |
if not isinstance(skip, torch.nn.Identity): | |
skip_x = skip(skip_x) | |
x = x + skip_x | |
x = block(x, w=w) | |
x = self.decoder_attns[i](x, mask) | |
features.append(x) | |
if not is_last: | |
x = self.upsample(x) | |
return x, features | |
def get_z(self, *args, **kwargs): | |
return self.coarse_stem.get_z(*args, **kwargs) | |
    @property
    def style_net(self):
return self.coarse_stem.style_net | |
def update_w(self, *args, **kwargs): | |
self.style_net.update_w(*args, **kwargs) | |
def get_w(self, z, update_emas): | |
return self.style_net(z, update_emas=update_emas) | |
def sample(self, truncation_value, **kwargs): | |
if truncation_value is None: | |
return self.forward(**kwargs) | |
truncation_value = max(0, truncation_value) | |
truncation_value = min(truncation_value, 1) | |
w = self.get_w(self.get_z(kwargs["condition"]), False) | |
w = self.style_net.w_avg.to(w.dtype).lerp(w, truncation_value) | |
return self.forward(**kwargs, w=w) | |
def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs): | |
if truncation_value is None: | |
return self.forward(**kwargs) | |
truncation_value = max(0, truncation_value) | |
truncation_value = min(truncation_value, 1) | |
w = self.get_w(self.get_z(kwargs["condition"]), False) | |
if w_indices is None: | |
w_indices = np.random.randint(0, len(self.style_net.w_centers), size=(len(w))) | |
w_centers = self.style_net.w_centers[w_indices].to(w.device) | |
w = w_centers.to(w.dtype).lerp(w, truncation_value) | |
return self.forward(**kwargs, w=w) | |