from typing import List
import numpy as np
import torch
import tops
import torch.nn.functional as F
from sg3_torch_utils.ops import conv2d_resample
from sg3_torch_utils.ops import upfirdn2d
from sg3_torch_utils.ops import bias_act
from sg3_torch_utils.ops.fma import fma


class FullyConnectedLayer(torch.nn.Module):
    def __init__(self,
        in_features,            # Number of input features.
        out_features,           # Number of output features.
        bias=True,              # Apply additive bias before the activation function?
        activation='linear',    # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier=1,        # Learning rate multiplier.
        bias_init=0,            # Initial value for the additive bias.
    ):
        super().__init__()
        self.repr = dict(
            in_features=in_features, out_features=out_features, bias=bias,
            activation=activation, lr_multiplier=lr_multiplier, bias_init=bias_init)
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier
        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x):
        w = self.weight * self.weight_gain
        b = self.bias
        if b is not None and self.bias_gain != 1:
            b = b * self.bias_gain
        x = F.linear(x, w)
        x = bias_act.bias_act(x, b, act=self.activation)
        return x

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={item}" for key, item in self.repr.items()])


class Conv2d(torch.nn.Module):
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        kernel_size=3,                  # Convolution kernel size.
        up=1,                           # Integer upsampling factor.
        down=1,                         # Integer downsampling factor.
        activation='lrelu',             # Activation function: 'relu', 'lrelu', etc.
        resample_filter=[1, 3, 3, 1],   # Low-pass filter to apply when resampling activations.
        conv_clamp=None,                # Clamp the output of convolution layers to +-X, None = disable clamping.
        bias=True,
        norm=False,
        lr_multiplier=1,
        bias_init=0,
        w_dim=None,
        gain=1,
    ):
        super().__init__()
        if norm:
            self.norm = torch.nn.InstanceNorm2d(None)
        assert norm in [True, False]
        self.up = up
        self.down = down
        self.activation = activation
        self.conv_clamp = conv_clamp if conv_clamp is None else conv_clamp * gain
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.padding = kernel_size // 2
        self.repr = dict(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, up=up, down=down,
            activation=activation, resample_filter=resample_filter, conv_clamp=conv_clamp, bias=bias,
        )
        if self.up == 1 and self.down == 1:
            self.resample_filter = None
        else:
            self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.act_gain = bias_act.activation_funcs[activation].def_gain * gain
        self.weight_gain = lr_multiplier / np.sqrt(in_channels * (kernel_size ** 2))
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]) + bias_init) if bias else None
        self.bias_gain = lr_multiplier
        if w_dim is not None:
            self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
            self.affine_beta = FullyConnectedLayer(w_dim, in_channels, bias_init=0)

    def forward(self, x, w=None, s=None):
        tops.assert_shape(x, [None, self.weight.shape[1], None, None])
        if s is not None:
            s = s[..., :self.in_channels * 2]
            gamma, beta = s.view(-1, self.in_channels * 2, 1, 1).chunk(2, dim=1)
            x = fma(x, gamma, beta)
        elif hasattr(self, "affine"):
            gamma = self.affine(w).view(-1, self.in_channels, 1, 1)
            beta = self.affine_beta(w).view(-1, self.in_channels, 1, 1)
            x = fma(x, gamma, beta)
        w = self.weight * self.weight_gain
        # Removing flip weight is not safe.
        x = conv2d_resample.conv2d_resample(x, w, self.resample_filter, self.up,
                                            self.down, self.padding, flip_weight=self.up == 1)
        if hasattr(self, "norm"):
            x = self.norm(x)
        b = self.bias * self.bias_gain if self.bias is not None else None
        x = bias_act.bias_act(x, b, act=self.activation, gain=self.act_gain, clamp=self.conv_clamp)
        return x

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={item}" for key, item in self.repr.items()])


class Block(torch.nn.Module):
    def __init__(self,
        in_channels,            # Number of input channels, 0 = first block.
        out_channels,           # Number of output channels.
        conv_clamp=None,        # Clamp the output of convolution layers to +-X, None = disable clamping.
        up=1,
        down=1,
        **layer_kwargs,         # Arguments for SynthesisLayer.
    ):
        super().__init__()
        self.in_channels = in_channels
        self.down = down
        self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
        self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, **layer_kwargs)

    def forward(self, x, **layer_kwargs):
        x = self.conv0(x, **layer_kwargs)
        x = self.conv1(x, **layer_kwargs)
        return x
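
# Usage sketch (illustrative): Block stacks two Conv2d layers, e.g. an encoder stage that halves the
# resolution while widening the channel count.
#
#   block = Block(in_channels=64, out_channels=128, down=2)
#   y = block(torch.randn(4, 64, 32, 32))   # -> [4, 128, 16, 16]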


class ResidualBlock(torch.nn.Module):
    def __init__(self,
        in_channels,            # Number of input channels, 0 = first block.
        out_channels,           # Number of output channels.
        conv_clamp=None,        # Clamp the output of convolution layers to +-X, None = disable clamping.
        up=1,
        down=1,
        gain_out=np.sqrt(0.5),
        fix_residual: bool = False,
        **layer_kwargs,         # Arguments for conv layer.
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.down = down
        self.conv0 = Conv2d(in_channels, out_channels, down=down, conv_clamp=conv_clamp, **layer_kwargs)
        self.conv1 = Conv2d(out_channels, out_channels, up=up, conv_clamp=conv_clamp, gain=gain_out, **layer_kwargs)
        self.skip = Conv2d(
            in_channels, out_channels, kernel_size=1, bias=False, up=up, down=down,
            activation="linear" if fix_residual else "lrelu",
            gain=gain_out
        )
        self.gain_out = gain_out

    def forward(self, x, w=None, s=None, **layer_kwargs):
        y = self.skip(x)
        s_ = next(s) if s is not None else None
        x = self.conv0(x, w, s=s_, **layer_kwargs)
        s_ = next(s) if s is not None else None
        x = self.conv1(x, w, s=s_, **layer_kwargs)
        x = y + x
        return x
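
# Usage sketch (illustrative): the skip path and conv1 are both built with gain=sqrt(0.5), so the sum
# y + x keeps roughly unit variance. When per-layer styles are used, s must be an iterator, since the
# block consumes one style tensor per convolution via next(s).
#
#   block = ResidualBlock(in_channels=64, out_channels=128, down=2)
#   y = block(torch.randn(4, 64, 32, 32))   # -> [4, 128, 16, 16]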


class MinibatchStdLayer(torch.nn.Module):
    def __init__(self, group_size, num_channels=1):
        super().__init__()
        self.group_size = group_size
        self.num_channels = num_channels

    def forward(self, x):
        N, C, H, W = x.shape
        with tops.suppress_tracer_warnings():  # as_tensor results are registered as constants
            G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
        F = self.num_channels
        c = C // F
        # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
        y = x.reshape(G, -1, F, c, H, W)
        y = y - y.mean(dim=0)       # [GnFcHW] Subtract mean over group.
        y = y.square().mean(dim=0)  # [nFcHW]  Calc variance over group.
        y = (y + 1e-8).sqrt()       # [nFcHW]  Calc stddev over group.
        y = y.mean(dim=[2, 3, 4])   # [nF]     Take average over channels and pixels.
        y = y.reshape(-1, F, 1, 1)  # [nF11]   Add missing dimensions.
        y = y.repeat(G, 1, H, W)    # [NFHW]   Replicate over group and pixels.
        x = torch.cat([x, y], dim=1)  # [NCHW] Append to input as new channels.
        return x
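
# Shape sketch (illustrative): with group_size=4 and num_channels=1, an input of shape [8, 64, 4, 4]
# yields [8, 65, 4, 4]; the extra channel holds the per-group standard deviation, averaged over
# channels and pixels, and is constant within each group of 4 samples.
#
#   mbstd = MinibatchStdLayer(group_size=4, num_channels=1)
#   y = mbstd(torch.randn(8, 64, 4, 4))   # -> [8, 65, 4, 4]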


class DiscriminatorEpilogue(torch.nn.Module):
    def __init__(self,
        in_channels,                # Number of input channels.
        resolution: List[int],      # Resolution of this block.
        mbstd_group_size=4,         # Group size for the minibatch standard deviation layer, None = entire minibatch.
        mbstd_num_channels=1,       # Number of features for the minibatch standard deviation layer, 0 = disable.
        activation='lrelu',         # Activation function: 'relu', 'lrelu', etc.
        conv_clamp=None,            # Clamp the output of convolution layers to +-X, None = disable clamping.
    ):
        super().__init__()
        self.in_channels = in_channels
        self.resolution = resolution
        self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size,
                                       num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
        self.conv = Conv2d(
            in_channels + mbstd_num_channels, in_channels,
            kernel_size=3, activation=activation, conv_clamp=conv_clamp)
        self.fc = FullyConnectedLayer(in_channels * resolution[0] * resolution[1], in_channels, activation=activation)
        self.out = FullyConnectedLayer(in_channels, 1)

    def forward(self, x):
        tops.assert_shape(x, [None, self.in_channels, *self.resolution])  # [NCHW]
        # Main layers.
        if self.mbstd is not None:
            x = self.mbstd(x)
        x = self.conv(x)
        x = self.fc(x.flatten(1))
        x = self.out(x)
        return x
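

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only). It assumes the sg3_torch_utils ops fall back to
    # their reference PyTorch implementations when the custom CUDA kernels are unavailable; the
    # channel and resolution values below are arbitrary.
    epilogue = DiscriminatorEpilogue(in_channels=64, resolution=[4, 4])
    scores = epilogue(torch.randn(8, 64, 4, 4))
    print(scores.shape)  # expected: torch.Size([8, 1])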