|
|
|
|
|
|
|
|
|
|
|
"""PyTorch utilities""" |
|
import math
from typing import Optional
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
def xaviermultiplier(m, gain): |
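    """Return the Xavier/Glorot standard deviation for a conv or linear module.

    For convolutions this is gain * sqrt(2 / ((in_channels + out_channels) * receptive_field));
    transposed convolutions divide the receptive field by the stride to account for the
    upsampling. Returns None for unsupported module types.
    """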
|
if isinstance(m, nn.Conv1d): |
|
ksize = m.kernel_size[0] |
|
n1 = m.in_channels |
|
n2 = m.out_channels |
|
|
|
std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) |
|
elif isinstance(m, nn.ConvTranspose1d): |
|
ksize = m.kernel_size[0] // m.stride[0] |
|
n1 = m.in_channels |
|
n2 = m.out_channels |
|
|
|
std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) |
|
elif isinstance(m, nn.Conv2d): |
|
ksize = m.kernel_size[0] * m.kernel_size[1] |
|
n1 = m.in_channels |
|
n2 = m.out_channels |
|
|
|
std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) |
|
elif isinstance(m, nn.ConvTranspose2d): |
|
ksize = m.kernel_size[0] * m.kernel_size[1] // m.stride[0] // m.stride[1] |
|
n1 = m.in_channels |
|
n2 = m.out_channels |
|
|
|
std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) |
|
elif isinstance(m, nn.Conv3d): |
|
ksize = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] |
|
n1 = m.in_channels |
|
n2 = m.out_channels |
|
|
|
std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) |
|
elif isinstance(m, nn.ConvTranspose3d): |
|
ksize = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] // m.stride[0] // m.stride[1] // m.stride[2] |
|
n1 = m.in_channels |
|
n2 = m.out_channels |
|
|
|
std = gain * math.sqrt(2.0 / ((n1 + n2) * ksize)) |
|
elif isinstance(m, nn.Linear): |
|
n1 = m.in_features |
|
n2 = m.out_features |
|
|
|
std = gain * math.sqrt(2.0 / (n1 + n2)) |
|
else: |
|
return None |
|
|
|
return std |
|
|
|
|
|
def xavier_uniform_(m, gain): |
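    """Fill m.weight from U(-a, a) with a = sqrt(3) * std, so its standard deviation matches
    the Xavier value computed by xaviermultiplier(). Modules without a known multiplier are
    left untouched."""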
|
    std = xaviermultiplier(m, gain)
    if std is not None:
        m.weight.data.uniform_(-std * math.sqrt(3.0), std * math.sqrt(3.0))
|
|
|
def initmod(m, gain=1.0, weightinitfunc=xavier_uniform_): |
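    """Initialize a single module: Xavier weights scaled by gain, zeroed bias, tap
    duplication for transposed convolutions (which helps avoid checkerboard artifacts at
    initialization), and a reset of the weight-normalization gain g of the *WN layers
    defined below so that their effective weight is unchanged at initialization."""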
|
validclasses = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d] |
|
if any([isinstance(m, x) for x in validclasses]): |
|
weightinitfunc(m, gain) |
|
if hasattr(m, 'bias') and isinstance(m.bias, torch.Tensor): |
|
m.bias.data.zero_() |
|
|
|
|
|
if isinstance(m, nn.ConvTranspose2d): |
|
|
|
m.weight.data[:, :, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2] |
|
m.weight.data[:, :, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2] |
|
m.weight.data[:, :, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2] |
|
|
|
if isinstance(m, nn.ConvTranspose3d): |
|
|
|
m.weight.data[:, :, 0::2, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] |
|
m.weight.data[:, :, 0::2, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] |
|
m.weight.data[:, :, 0::2, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] |
|
m.weight.data[:, :, 1::2, 0::2, 0::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] |
|
m.weight.data[:, :, 1::2, 0::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] |
|
m.weight.data[:, :, 1::2, 1::2, 0::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] |
|
m.weight.data[:, :, 1::2, 1::2, 1::2] = m.weight.data[:, :, 0::2, 0::2, 0::2] |
|
|
|
if isinstance(m, Conv2dWNUB) or isinstance(m, Conv2dWN) or isinstance(m, ConvTranspose2dWN) or \ |
|
isinstance(m, ConvTranspose2dWNUB) or isinstance(m, LinearWN): |
|
        norm = torch.sqrt(torch.sum(m.weight.data ** 2))
|
m.g.data[:] = norm |
|
|
|
def initseq(s): |
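    """Initialize every layer of a Sequential, picking the gain from the activation that
    follows it.

    Example (illustrative sizes only):
        net = nn.Sequential(nn.Linear(256, 128), nn.LeakyReLU(0.2), nn.Linear(128, 3))
        initseq(net)
    """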
|
for a, b in zip(s[:-1], s[1:]): |
|
if isinstance(b, nn.ReLU): |
|
initmod(a, nn.init.calculate_gain('relu')) |
|
elif isinstance(b, nn.LeakyReLU): |
|
initmod(a, nn.init.calculate_gain('leaky_relu', b.negative_slope)) |
|
elif isinstance(b, nn.Sigmoid): |
|
initmod(a) |
|
elif isinstance(b, nn.Softplus): |
|
initmod(a) |
|
else: |
|
initmod(a) |
|
|
|
initmod(s[-1]) |
|
|
|
|
|
class LinearWN(nn.Linear): |
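    """nn.Linear with weight normalization: the weight is divided by its global L2 norm and
    scaled by a learnable per-output gain g; fuse() folds the gain into the stored weight
    for inference."""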
|
def __init__(self, in_features, out_features, bias=True): |
|
super(LinearWN, self).__init__(in_features, out_features, bias) |
|
self.g = nn.Parameter(torch.ones(out_features)) |
|
self.fused = False |
|
|
|
def fuse(self): |
|
wnorm = torch.sqrt(torch.sum(self.weight ** 2)) |
|
self.weight.data = self.weight.data * self.g.data[:, None] / wnorm |
|
self.fused = True |
|
|
|
def forward(self, input): |
|
if self.fused: |
|
return F.linear(input, self.weight, self.bias) |
|
else: |
|
wnorm = torch.sqrt(torch.sum(self.weight ** 2)) |
|
return F.linear(input, self.weight * self.g[:, None] / wnorm, self.bias) |
|
|
|
class LinearELR(nn.Module): |
|
"""Linear layer with equalized learning rate from stylegan2""" |
|
def __init__(self, inch, outch, lrmult=1., norm : Optional[str]=None, act=None): |
|
super(LinearELR, self).__init__() |
|
|
|
|
|
try: |
|
if isinstance(act, nn.LeakyReLU): |
|
actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope) |
|
elif isinstance(act, nn.ReLU): |
|
actgain = nn.init.calculate_gain("relu") |
|
else: |
|
actgain = nn.init.calculate_gain(act) |
|
        except (TypeError, ValueError):
            # act is None or has no gain registered in nn.init.calculate_gain
            actgain = 1.
|
|
|
initgain = 1. / math.sqrt(inch) |
|
|
|
self.weight = nn.Parameter(torch.randn(outch, inch) / lrmult) |
|
self.weightgain = actgain |
|
|
|
        if norm is None:
|
self.weightgain = self.weightgain * initgain * lrmult |
|
|
|
self.bias = nn.Parameter(torch.full([outch], 0.)) |
|
|
|
self.norm : Optional[str] = norm |
|
self.act = act |
|
|
|
self.fused = False |
|
|
|
def extra_repr(self): |
|
return 'inch={}, outch={}, norm={}, act={}'.format( |
|
self.weight.size(1), self.weight.size(0), self.norm, self.act |
|
) |
|
|
|
def getweight(self): |
|
if self.fused: |
|
return self.weight |
|
else: |
|
weight = self.weight |
|
if self.norm is not None: |
|
if self.norm == "demod": |
|
weight = F.normalize(weight, dim=1) |
|
return weight |
|
|
|
def fuse(self): |
|
if not self.fused: |
|
with torch.no_grad(): |
|
self.weight.data = self.getweight() * self.weightgain |
|
self.fused = True |
|
|
|
def forward(self, x): |
|
if self.fused: |
|
weight = self.getweight() |
|
|
|
out = torch.addmm(self.bias[None], x, weight.t()) |
|
if self.act is not None: |
|
out = self.act(out) |
|
return out |
|
else: |
|
weight = self.getweight() |
|
|
|
if self.act is None: |
|
out = torch.addmm(self.bias[None], x, weight.t(), alpha=self.weightgain) |
|
return out |
|
else: |
|
out = F.linear(x, weight * self.weightgain, bias=self.bias) |
|
out = self.act(out) |
|
return out |
|
|
|
class Downsample2d(nn.Module): |
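    """Anti-aliased depthwise downsampling: a fixed 7x7 binomial blur kernel applied per
    channel with the given stride; padding may be an int or the string "reflect"."""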
|
def __init__(self, nchannels, stride=1, padding=0): |
|
super(Downsample2d, self).__init__() |
|
|
|
self.nchannels = nchannels |
|
self.stride = stride |
|
self.padding = padding |
|
|
|
blurkernel = torch.tensor([1., 6., 15., 20., 15., 6., 1.]) |
|
blurkernel = (blurkernel[:, None] * blurkernel[None, :]) |
|
blurkernel = blurkernel / torch.sum(blurkernel) |
|
blurkernel = blurkernel[None, None, :, :].repeat(nchannels, 1, 1, 1) |
|
self.register_buffer('kernel', blurkernel) |
|
|
|
def forward(self, x): |
|
if self.padding == "reflect": |
|
x = F.pad(x, (3, 3, 3, 3), mode='reflect') |
|
return F.conv2d(x, weight=self.kernel, stride=self.stride, padding=0, groups=self.nchannels) |
|
else: |
|
return F.conv2d(x, weight=self.kernel, stride=self.stride, padding=self.padding, groups=self.nchannels) |
|
|
|
class Dilate2d(nn.Module): |
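    """Soft dilation of a [0, 1] mask: a depthwise box filter averages each
    kernelsize x kernelsize neighbourhood and the result is clamped to at most 1."""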
|
def __init__(self, nchannels, kernelsize, stride=1, padding=0): |
|
super(Dilate2d, self).__init__() |
|
|
|
self.nchannels = nchannels |
|
self.kernelsize = kernelsize |
|
self.stride = stride |
|
self.padding = padding |
|
|
|
blurkernel = torch.ones((self.kernelsize,)) |
|
blurkernel = (blurkernel[:, None] * blurkernel[None, :]) |
|
blurkernel = blurkernel / torch.sum(blurkernel) |
|
blurkernel = blurkernel[None, None, :, :].repeat(nchannels, 1, 1, 1) |
|
self.register_buffer('kernel', blurkernel) |
|
|
|
def forward(self, x): |
|
return F.conv2d(x, weight=self.kernel, stride=self.stride, padding=self.padding, groups=self.nchannels).clamp(max=1.) |
|
|
|
class Conv2dWN(nn.Conv2d): |
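    """Conv2d with weight normalization: the weight is divided by its global L2 norm and
    scaled by a learnable per-output-channel gain g."""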
|
def __init__(self, in_channels, out_channels, kernel_size, |
|
stride=1, padding=0, dilation=1, groups=1, bias=True): |
|
super(Conv2dWN, self).__init__(in_channels, out_channels, kernel_size, stride, |
|
padding, dilation, groups, True) |
|
self.g = nn.Parameter(torch.ones(out_channels)) |
|
|
|
def forward(self, x): |
|
wnorm = torch.sqrt(torch.sum(self.weight ** 2)) |
|
return F.conv2d(x, self.weight * self.g[:, None, None, None] / wnorm, |
|
bias=self.bias, stride=self.stride, padding=self.padding, |
|
dilation=self.dilation, groups=self.groups) |
|
|
|
class Conv2dUB(nn.Conv2d): |
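    """Conv2d with an untied bias: a learnable per-pixel bias of shape
    (out_channels, height, width) instead of one scalar per channel."""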
|
def __init__(self, in_channels, out_channels, height, width, kernel_size, |
|
stride=1, padding=0, dilation=1, groups=1, bias=False): |
|
super(Conv2dUB, self).__init__(in_channels, out_channels, kernel_size, stride, |
|
padding, dilation, groups, False) |
|
self.bias = nn.Parameter(torch.zeros(out_channels, height, width)) |
|
|
|
def forward(self, x): |
|
return F.conv2d(x, self.weight, |
|
bias=None, stride=self.stride, padding=self.padding, |
|
dilation=self.dilation, groups=self.groups) + self.bias[None, ...] |
|
|
|
class Conv2dWNUB(nn.Conv2d): |
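    """Conv2d combining weight normalization (as in Conv2dWN) with an untied per-pixel bias
    (as in Conv2dUB)."""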
|
def __init__(self, in_channels, out_channels, height, width, kernel_size, |
|
stride=1, padding=0, dilation=1, groups=1, bias=False): |
|
super(Conv2dWNUB, self).__init__(in_channels, out_channels, kernel_size, stride, |
|
padding, dilation, groups, False) |
|
self.g = nn.Parameter(torch.ones(out_channels)) |
|
self.bias = nn.Parameter(torch.zeros(out_channels, height, width)) |
|
|
|
def forward(self, x): |
|
wnorm = torch.sqrt(torch.sum(self.weight ** 2)) |
|
return F.conv2d(x, self.weight * self.g[:, None, None, None] / wnorm, |
|
bias=None, stride=self.stride, padding=self.padding, |
|
dilation=self.dilation, groups=self.groups) + self.bias[None, ...] |
|
|
|
def blockinit(k, stride): |
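    """Expand a transposed-convolution kernel by replicating every spatial tap into a
    stride-sized block along each spatial dimension, so the strided transposed convolution
    starts out with a piecewise-constant (nearest-neighbour-like) response, which helps
    avoid checkerboard artifacts at initialization."""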
|
dim = k.ndim - 2 |
|
return k \ |
|
.view(k.size(0), k.size(1), *(x for i in range(dim) for x in (k.size(i+2), 1))) \ |
|
.repeat(1, 1, *(x for i in range(dim) for x in (1, stride))) \ |
|
.view(k.size(0), k.size(1), *(k.size(i+2)*stride for i in range(dim))) |
|
|
|
class ConvTranspose1dELR(nn.Module): |
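    """Transposed 1D convolution with equalized learning rate.

    Supports StyleGAN2-style weight demodulation (norm="demod"), an untied per-element bias
    (ub=(length,)), and per-sample weight modulation from a conditioning vector w when
    wsize > 0, implemented as a grouped convolution over the batch."""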
|
def __init__(self, inch, outch, kernel_size, stride, padding, wsize=0, affinelrmult=1., norm=None, ub=None, act=None): |
|
super(ConvTranspose1dELR, self).__init__() |
|
|
|
self.inch = inch |
|
self.outch = outch |
|
self.kernel_size = kernel_size |
|
self.stride = stride |
|
self.padding = padding |
|
self.wsize = wsize |
|
self.norm = norm |
|
self.ub = ub |
|
self.act = act |
|
|
|
|
|
try: |
|
if isinstance(act, nn.LeakyReLU): |
|
actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope) |
|
elif isinstance(act, nn.ReLU): |
|
actgain = nn.init.calculate_gain("relu") |
|
else: |
|
actgain = nn.init.calculate_gain(act) |
|
        except (TypeError, ValueError):
            # act is None or has no gain registered in nn.init.calculate_gain
            actgain = 1.
|
|
|
fan_in = inch * (kernel_size / (stride)) |
|
|
|
initgain = stride ** 0.5 if norm == "demod" else 1. / math.sqrt(fan_in) |
|
|
|
self.weightgain = actgain * initgain |
|
|
|
self.weight = nn.Parameter(blockinit( |
|
torch.randn(inch, outch, kernel_size//self.stride), self.stride)) |
|
|
|
if ub is not None: |
|
self.bias = nn.Parameter(torch.zeros(outch, ub[0])) |
|
else: |
|
self.bias = nn.Parameter(torch.zeros(outch)) |
|
|
|
if wsize > 0: |
|
self.affine = LinearELR(wsize, inch, lrmult=affinelrmult) |
|
else: |
|
self.affine = None |
|
|
|
self.fused = False |
|
|
|
def extra_repr(self): |
|
return 'inch={}, outch={}, kernel_size={}, stride={}, padding={}, wsize={}, norm={}, ub={}, act={}'.format( |
|
self.inch, self.outch, self.kernel_size, self.stride, self.padding, self.wsize, self.norm, self.ub, self.act |
|
) |
|
|
|
def getweight(self, weight): |
|
if self.fused: |
|
return weight |
|
else: |
|
if self.norm is not None: |
|
if self.norm == "demod": |
|
                    if weight.ndim == 4:
                        # per-sample modulated weight: (B, inch, outch, ksize)
                        normdims = [1, 3]
                    else:
                        # shared weight: (inch, outch, ksize)
                        normdims = [0, 2]
|
|
|
if torch.jit.is_scripting(): |
|
|
|
weight = weight / torch.linalg.vector_norm(weight, dim=normdims, keepdim=True) |
|
else: |
|
weight = F.normalize(weight, dim=normdims) |
|
|
|
weight = weight * self.weightgain |
|
|
|
return weight |
|
|
|
def fuse(self): |
|
        if self.affine is None and not self.fused:
|
with torch.no_grad(): |
|
self.weight.data = self.getweight(self.weight) |
|
self.fused = True |
|
|
|
def forward(self, x, w : Optional[torch.Tensor]=None): |
|
b = x.size(0) |
|
|
|
if self.affine is not None and w is not None: |
|
|
|
affine = self.affine(w)[:, :, None, None] |
|
weight = self.weight * (affine * 0.1 + 1.) |
|
else: |
|
weight = self.weight |
|
|
|
weight = self.getweight(weight) |
|
|
|
if self.affine is not None and w is not None: |
|
x = x.view(1, b * self.inch, x.size(2)) |
|
weight = weight.view(b * self.inch, self.outch, self.kernel_size) |
|
groups = b |
|
else: |
|
groups = 1 |
|
|
|
out = F.conv_transpose1d(x, weight, None, |
|
stride=self.stride, padding=self.padding, dilation=1, groups=groups) |
|
|
|
if self.affine is not None and w is not None: |
|
out = out.view(b, self.outch, out.size(2)) |
|
|
|
if self.bias.ndim == 1: |
|
bias = self.bias[None, :, None] |
|
else: |
|
bias = self.bias[None, :, :] |
|
out = out + bias |
|
|
|
if self.act is not None: |
|
out = self.act(out) |
|
|
|
return out |
|
|
|
class ConvTranspose2dELR(nn.Module): |
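    """Transposed 2D convolution with equalized learning rate.

    Supports StyleGAN2-style weight demodulation (norm="demod"), an untied per-pixel bias
    (ub=(height, width)), and per-sample weight modulation from a conditioning vector w
    when wsize > 0, implemented as a grouped convolution over the batch."""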
|
def __init__(self, inch, outch, kernel_size, stride, padding, wsize=0, affinelrmult=1., norm=None, ub=None, act=None): |
|
super(ConvTranspose2dELR, self).__init__() |
|
|
|
self.inch = inch |
|
self.outch = outch |
|
self.kernel_size = kernel_size |
|
self.stride = stride |
|
self.padding = padding |
|
self.wsize = wsize |
|
self.norm = norm |
|
self.ub = ub |
|
self.act = act |
|
|
|
|
|
try: |
|
if isinstance(act, nn.LeakyReLU): |
|
actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope) |
|
elif isinstance(act, nn.ReLU): |
|
actgain = nn.init.calculate_gain("relu") |
|
else: |
|
actgain = nn.init.calculate_gain(act) |
|
        except (TypeError, ValueError):
            # act is None or has no gain registered in nn.init.calculate_gain
            actgain = 1.
|
|
|
fan_in = inch * (kernel_size ** 2 / (stride ** 2)) |
|
|
|
initgain = stride if norm == "demod" else 1. / math.sqrt(fan_in) |
|
|
|
self.weightgain = actgain * initgain |
|
|
|
self.weight = nn.Parameter(blockinit( |
|
torch.randn(inch, outch, kernel_size//self.stride, kernel_size//self.stride), self.stride)) |
|
|
|
if ub is not None: |
|
self.bias = nn.Parameter(torch.zeros(outch, ub[0], ub[1])) |
|
else: |
|
self.bias = nn.Parameter(torch.zeros(outch)) |
|
|
|
if wsize > 0: |
|
self.affine = LinearELR(wsize, inch, lrmult=affinelrmult) |
|
else: |
|
self.affine = None |
|
|
|
self.fused = False |
|
|
|
def extra_repr(self): |
|
return 'inch={}, outch={}, kernel_size={}, stride={}, padding={}, wsize={}, norm={}, ub={}, act={}'.format( |
|
self.inch, self.outch, self.kernel_size, self.stride, self.padding, self.wsize, self.norm, self.ub, self.act |
|
) |
|
|
|
def getweight(self, weight): |
|
if self.fused: |
|
return weight |
|
else: |
|
if self.norm is not None: |
|
if self.norm == "demod": |
|
if weight.ndim == 5: |
|
normdims = [1, 3, 4] |
|
else: |
|
normdims = [0, 2, 3] |
|
|
|
if torch.jit.is_scripting(): |
|
|
|
weight = weight / torch.linalg.vector_norm(weight, dim=normdims, keepdim=True) |
|
else: |
|
weight = F.normalize(weight, dim=normdims) |
|
|
|
weight = weight * self.weightgain |
|
|
|
return weight |
|
|
|
def fuse(self): |
|
        if self.affine is None and not self.fused:
|
with torch.no_grad(): |
|
self.weight.data = self.getweight(self.weight) |
|
self.fused = True |
|
|
|
def forward(self, x, w : Optional[torch.Tensor]=None): |
|
b = x.size(0) |
|
|
|
if self.affine is not None and w is not None: |
|
|
|
affine = self.affine(w)[:, :, None, None, None] |
|
weight = self.weight * (affine * 0.1 + 1.) |
|
else: |
|
weight = self.weight |
|
|
|
weight = self.getweight(weight) |
|
|
|
if self.affine is not None and w is not None: |
|
x = x.view(1, b * self.inch, x.size(2), x.size(3)) |
|
weight = weight.view(b * self.inch, self.outch, self.kernel_size, self.kernel_size) |
|
groups = b |
|
else: |
|
groups = 1 |
|
|
|
out = F.conv_transpose2d(x, weight, None, |
|
stride=self.stride, padding=self.padding, dilation=1, groups=groups) |
|
|
|
if self.affine is not None and w is not None: |
|
out = out.view(b, self.outch, out.size(2), out.size(3)) |
|
|
|
if self.bias.ndim == 1: |
|
bias = self.bias[None, :, None, None] |
|
else: |
|
bias = self.bias[None, :, :, :] |
|
out = out + bias |
|
|
|
if self.act is not None: |
|
out = self.act(out) |
|
|
|
return out |
|
|
|
class ConvTranspose3dELR(nn.Module): |
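    """Transposed 3D convolution with equalized learning rate.

    Supports StyleGAN2-style weight demodulation (norm="demod"), an untied per-voxel bias
    (ub=(depth, height, width)), and per-sample weight modulation from a conditioning
    vector w when wsize > 0, implemented as a grouped convolution over the batch."""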
|
def __init__(self, inch, outch, kernel_size, stride, padding, wsize=0, affinelrmult=1., norm=None, ub=None, act=None): |
|
super(ConvTranspose3dELR, self).__init__() |
|
|
|
self.inch = inch |
|
self.outch = outch |
|
self.kernel_size = kernel_size |
|
self.stride = stride |
|
self.padding = padding |
|
self.wsize = wsize |
|
self.norm = norm |
|
self.ub = ub |
|
self.act = act |
|
|
|
|
|
try: |
|
if isinstance(act, nn.LeakyReLU): |
|
actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope) |
|
elif isinstance(act, nn.ReLU): |
|
actgain = nn.init.calculate_gain("relu") |
|
else: |
|
actgain = nn.init.calculate_gain(act) |
|
        except (TypeError, ValueError):
            # act is None or has no gain registered in nn.init.calculate_gain
            actgain = 1.
|
|
|
fan_in = inch * (kernel_size ** 3 / (stride ** 3)) |
|
|
|
initgain = stride ** 1.5 if norm == "demod" else 1. / math.sqrt(fan_in) |
|
|
|
self.weightgain = actgain * initgain |
|
|
|
self.weight = nn.Parameter(blockinit( |
|
torch.randn(inch, outch, kernel_size//self.stride, kernel_size//self.stride, kernel_size//self.stride), self.stride)) |
|
|
|
if ub is not None: |
|
self.bias = nn.Parameter(torch.zeros(outch, ub[0], ub[1], ub[2])) |
|
else: |
|
self.bias = nn.Parameter(torch.zeros(outch)) |
|
|
|
if wsize > 0: |
|
self.affine = LinearELR(wsize, inch, lrmult=affinelrmult) |
|
else: |
|
self.affine = None |
|
|
|
self.fused = False |
|
|
|
def extra_repr(self): |
|
return 'inch={}, outch={}, kernel_size={}, stride={}, padding={}, wsize={}, norm={}, ub={}, act={}'.format( |
|
self.inch, self.outch, self.kernel_size, self.stride, self.padding, self.wsize, self.norm, self.ub, self.act |
|
) |
|
|
|
def getweight(self, weight): |
|
if self.fused: |
|
return weight |
|
else: |
|
if self.norm is not None: |
|
if self.norm == "demod": |
|
                    if weight.ndim == 6:
                        # per-sample modulated weight: (B, inch, outch, k, k, k)
                        normdims = [1, 3, 4, 5]
                    else:
                        # shared weight: (inch, outch, k, k, k)
                        normdims = [0, 2, 3, 4]
|
|
|
if torch.jit.is_scripting(): |
|
|
|
weight = weight / torch.linalg.vector_norm(weight, dim=normdims, keepdim=True) |
|
else: |
|
weight = F.normalize(weight, dim=normdims) |
|
|
|
weight = weight * self.weightgain |
|
|
|
return weight |
|
|
|
def fuse(self): |
|
        if self.affine is None and not self.fused:
|
with torch.no_grad(): |
|
self.weight.data = self.getweight(self.weight) |
|
self.fused = True |
|
|
|
def forward(self, x, w : Optional[torch.Tensor]=None): |
|
b = x.size(0) |
|
|
|
if self.affine is not None and w is not None: |
|
|
|
affine = self.affine(w)[:, :, None, None, None, None] |
|
weight = self.weight * (affine * 0.1 + 1.) |
|
else: |
|
weight = self.weight |
|
|
|
weight = self.getweight(weight) |
|
|
|
if self.affine is not None and w is not None: |
|
x = x.view(1, b * self.inch, x.size(2), x.size(3), x.size(4)) |
|
weight = weight.view(b * self.inch, self.outch, self.kernel_size, self.kernel_size, self.kernel_size) |
|
groups = b |
|
else: |
|
groups = 1 |
|
|
|
out = F.conv_transpose3d(x, weight, None, |
|
stride=self.stride, padding=self.padding, dilation=1, groups=groups) |
|
|
|
if self.affine is not None and w is not None: |
|
out = out.view(b, self.outch, out.size(2), out.size(3), out.size(4)) |
|
|
|
if self.bias.ndim == 1: |
|
bias = self.bias[None, :, None, None, None] |
|
else: |
|
bias = self.bias[None, :, :, :, :] |
|
out = out + bias |
|
|
|
if self.act is not None: |
|
out = self.act(out) |
|
|
|
return out |
|
|
|
class Conv2dELR(nn.Module): |
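    """2D convolution with equalized learning rate; the non-transposed counterpart of
    ConvTranspose2dELR, with the same demodulation, untied-bias, and per-sample
    weight-modulation options."""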
|
def __init__(self, inch, outch, kernel_size, stride, padding, wsize=0, affinelrmult=1., norm=None, ub=None, act=None): |
|
super(Conv2dELR, self).__init__() |
|
|
|
self.inch = inch |
|
self.outch = outch |
|
self.kernel_size = kernel_size |
|
self.stride = stride |
|
self.padding = padding |
|
self.wsize = wsize |
|
self.norm = norm |
|
self.ub = ub |
|
self.act = act |
|
|
|
|
|
try: |
|
if isinstance(act, nn.LeakyReLU): |
|
actgain = nn.init.calculate_gain("leaky_relu", act.negative_slope) |
|
elif isinstance(act, nn.ReLU): |
|
actgain = nn.init.calculate_gain("relu") |
|
else: |
|
actgain = nn.init.calculate_gain(act) |
|
        except (TypeError, ValueError):
            # act is None or has no gain registered in nn.init.calculate_gain
            actgain = 1.
|
|
|
fan_in = inch * (kernel_size ** 2) |
|
|
|
initgain = 1. if norm == "demod" else 1. / math.sqrt(fan_in) |
|
|
|
self.weightgain = actgain * initgain |
|
|
|
self.weight = nn.Parameter( |
|
torch.randn(outch, inch, kernel_size, kernel_size)) |
|
|
|
if ub is not None: |
|
self.bias = nn.Parameter(torch.zeros(outch, ub[0], ub[1])) |
|
else: |
|
self.bias = nn.Parameter(torch.zeros(outch)) |
|
|
|
if wsize > 0: |
|
self.affine = LinearELR(wsize, inch, lrmult=affinelrmult) |
|
else: |
|
self.affine = None |
|
|
|
self.fused = False |
|
|
|
def extra_repr(self): |
|
return 'inch={}, outch={}, kernel_size={}, stride={}, padding={}, wsize={}, norm={}, ub={}, act={}'.format( |
|
self.inch, self.outch, self.kernel_size, self.stride, self.padding, self.wsize, self.norm, self.ub, self.act |
|
) |
|
|
|
def getweight(self, weight): |
|
if self.fused: |
|
return weight |
|
else: |
|
if self.norm is not None: |
|
if self.norm == "demod": |
|
if weight.ndim == 5: |
|
normdims = [2, 3, 4] |
|
else: |
|
normdims = [1, 2, 3] |
|
|
|
if torch.jit.is_scripting(): |
|
|
|
weight = weight / torch.linalg.vector_norm(weight, dim=normdims, keepdim=True) |
|
else: |
|
weight = F.normalize(weight, dim=normdims) |
|
|
|
weight = weight * self.weightgain |
|
|
|
return weight |
|
|
|
def fuse(self): |
|
        if self.affine is None and not self.fused:
|
with torch.no_grad(): |
|
self.weight.data = self.getweight(self.weight) |
|
self.fused = True |
|
|
|
def forward(self, x, w : Optional[torch.Tensor]=None): |
|
b = x.size(0) |
|
|
|
if self.affine is not None and w is not None: |
|
|
|
affine = self.affine(w)[:, None, :, None, None] |
|
weight = self.weight * (affine * 0.1 + 1.) |
|
else: |
|
weight = self.weight |
|
|
|
weight = self.getweight(weight) |
|
|
|
if self.affine is not None and w is not None: |
|
x = x.view(1, b * self.inch, x.size(2), x.size(3)) |
|
weight = weight.view(b * self.outch, self.inch, self.kernel_size, self.kernel_size) |
|
groups = b |
|
else: |
|
groups = 1 |
|
|
|
out = F.conv2d(x, weight, None, |
|
stride=self.stride, padding=self.padding, dilation=1, groups=groups) |
|
|
|
if self.affine is not None and w is not None: |
|
out = out.view(b, self.outch, out.size(2), out.size(3)) |
|
|
|
if self.bias.ndim == 1: |
|
bias = self.bias[None, :, None, None] |
|
else: |
|
bias = self.bias[None, :, :, :] |
|
out = out + bias |
|
|
|
if self.act is not None: |
|
out = self.act(out) |
|
|
|
return out |
|
|
|
class ConvTranspose2dWN(nn.ConvTranspose2d): |
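    """ConvTranspose2d with weight normalization (global weight norm, learnable per-channel
    gain g); fuse() folds the gain into the stored weight."""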
|
def __init__(self, in_channels, out_channels, kernel_size, |
|
stride=1, padding=0, dilation=1, groups=1, bias=True): |
|
super(ConvTranspose2dWN, self).__init__(in_channels, out_channels, kernel_size, stride, |
|
padding, dilation, groups, True) |
|
self.g = nn.Parameter(torch.ones(out_channels)) |
|
self.fused = False |
|
|
|
def fuse(self): |
|
wnorm = torch.sqrt(torch.sum(self.weight ** 2)) |
|
self.weight.data = self.weight.data * self.g.data[None, :, None, None] / wnorm |
|
self.fused = True |
|
|
|
def forward(self, x): |
|
bias = self.bias |
|
assert bias is not None |
|
if self.fused: |
|
return F.conv_transpose2d(x, self.weight, |
|
bias=self.bias, stride=self.stride, padding=self.padding, |
|
dilation=self.dilation, groups=self.groups) |
|
else: |
|
wnorm = torch.sqrt(torch.sum(self.weight ** 2)) |
|
return F.conv_transpose2d(x, self.weight * self.g[None, :, None, None] / wnorm, |
|
bias=self.bias, stride=self.stride, padding=self.padding, |
|
dilation=self.dilation, groups=self.groups) |
|
|
|
class ConvTranspose2dUB(nn.ConvTranspose2d): |
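    """ConvTranspose2d with an untied per-pixel bias of shape (out_channels, height, width),
    stored in the bias_ attribute."""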
|
def __init__(self, width, height, in_channels, out_channels, kernel_size, |
|
stride=1, padding=0, dilation=1, groups=1, bias=False): |
|
super(ConvTranspose2dUB, self).__init__(in_channels, out_channels, kernel_size, stride, |
|
padding, dilation, groups, False) |
|
self.bias_ = nn.Parameter(torch.zeros(out_channels, height, width)) |
|
|
|
def forward(self, x): |
|
return F.conv_transpose2d(x, self.weight, |
|
bias=None, stride=self.stride, padding=self.padding, |
|
dilation=self.dilation, groups=self.groups) + self.bias_[None, ...] |
|
|
|
class ConvTranspose2dWNUB(nn.ConvTranspose2d): |
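    """ConvTranspose2d combining weight normalization with an untied per-pixel bias."""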
|
def __init__(self, in_channels, out_channels, height, width, kernel_size, |
|
stride=1, padding=0, dilation=1, groups=1, bias=False): |
|
super(ConvTranspose2dWNUB, self).__init__(in_channels, out_channels, kernel_size, stride, |
|
padding, dilation, groups, False) |
|
self.g = nn.Parameter(torch.ones(out_channels)) |
|
self.bias = nn.Parameter(torch.zeros(out_channels, height, width)) |
|
|
|
self.fused = False |
|
|
|
def fuse(self): |
|
wnorm = torch.sqrt(torch.sum(self.weight ** 2)) |
|
self.weight.data = self.weight.data * self.g.data[None, :, None, None] / wnorm |
|
self.fused = True |
|
|
|
def forward(self, x): |
|
bias = self.bias |
|
assert bias is not None |
|
if self.fused: |
|
return F.conv_transpose2d(x, self.weight, |
|
bias=None, stride=self.stride, padding=self.padding, |
|
dilation=self.dilation, groups=self.groups) + bias[None, ...] |
|
else: |
|
wnorm = torch.sqrt(torch.sum(self.weight ** 2)) |
|
return F.conv_transpose2d(x, self.weight * self.g[None, :, None, None] / wnorm, |
|
bias=None, stride=self.stride, padding=self.padding, |
|
dilation=self.dilation, groups=self.groups) + bias[None, ...] |
|
|
|
class Conv3dUB(nn.Conv3d): |
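    """Conv3d with an untied per-voxel bias of shape (out_channels, depth, height, width)."""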
|
def __init__(self, width, height, depth, in_channels, out_channels, kernel_size, |
|
stride=1, padding=0, dilation=1, groups=1, bias=True): |
|
super(Conv3dUB, self).__init__(in_channels, out_channels, kernel_size, stride, |
|
padding, dilation, groups, False) |
|
self.bias = nn.Parameter(torch.zeros(out_channels, depth, height, width)) |
|
|
|
def forward(self, x): |
|
return F.conv3d(x, self.weight, |
|
bias=None, stride=self.stride, padding=self.padding, |
|
dilation=self.dilation, groups=self.groups) + self.bias[None, ...] |
|
|
|
class ConvTranspose3dUB(nn.ConvTranspose3d): |
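    """ConvTranspose3d with an untied per-voxel bias of shape (out_channels, depth, height, width)."""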
|
def __init__(self, width, height, depth, in_channels, out_channels, kernel_size, |
|
stride=1, padding=0, dilation=1, groups=1, bias=True): |
|
super(ConvTranspose3dUB, self).__init__(in_channels, out_channels, kernel_size, stride, |
|
padding, dilation, groups, False) |
|
self.bias = nn.Parameter(torch.zeros(out_channels, depth, height, width)) |
|
|
|
def forward(self, x): |
|
return F.conv_transpose3d(x, self.weight, |
|
bias=None, stride=self.stride, padding=self.padding, |
|
dilation=self.dilation, groups=self.groups) + self.bias[None, ...] |
|
|
|
class Rodrigues(nn.Module): |
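    """Convert axis-angle vectors of shape (N, 3) to rotation matrices of shape (N, 3, 3)
    via the Rodrigues formula; the 1e-5 term keeps the norm (and its gradient) finite at
    zero rotation."""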
|
def __init__(self): |
|
super(Rodrigues, self).__init__() |
|
|
|
def forward(self, rvec): |
|
theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1)) |
|
rvec = rvec / theta[:, None] |
|
costh = torch.cos(theta) |
|
sinth = torch.sin(theta) |
|
return torch.stack(( |
|
rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh, |
|
rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth, |
|
rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth, |
|
|
|
rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth, |
|
rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh, |
|
rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth, |
|
|
|
rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth, |
|
rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth, |
|
rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3) |
|
|
|
class Quaternion(nn.Module): |
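    """Convert quaternions given as (x, y, z, w) 4-vectors to rotation matrices of shape
    (N, 3, 3); the input is normalized to unit length first."""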
|
def __init__(self): |
|
super(Quaternion, self).__init__() |
|
|
|
def forward(self, rvec): |
|
theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1)) |
|
rvec = rvec / theta[:, None] |
|
return torch.stack(( |
|
1. - 2. * rvec[:, 1] ** 2 - 2. * rvec[:, 2] ** 2, |
|
2. * (rvec[:, 0] * rvec[:, 1] - rvec[:, 2] * rvec[:, 3]), |
|
2. * (rvec[:, 0] * rvec[:, 2] + rvec[:, 1] * rvec[:, 3]), |
|
|
|
2. * (rvec[:, 0] * rvec[:, 1] + rvec[:, 2] * rvec[:, 3]), |
|
1. - 2. * rvec[:, 0] ** 2 - 2. * rvec[:, 2] ** 2, |
|
2. * (rvec[:, 1] * rvec[:, 2] - rvec[:, 0] * rvec[:, 3]), |
|
|
|
2. * (rvec[:, 0] * rvec[:, 2] - rvec[:, 1] * rvec[:, 3]), |
|
2. * (rvec[:, 0] * rvec[:, 3] + rvec[:, 1] * rvec[:, 2]), |
|
1. - 2. * rvec[:, 0] ** 2 - 2. * rvec[:, 1] ** 2 |
|
), dim=1).view(-1, 3, 3) |
|
|
|
class BufferDict(nn.Module): |
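    """Module that stores a dictionary of tensors as registered buffers, so they follow the
    parent module across .to() / .cuda() calls."""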
|
def __init__(self, d, persistent=False): |
|
super(BufferDict, self).__init__() |
|
|
|
for k in d: |
|
            self.register_buffer(k, d[k], persistent=persistent)
|
|
|
def __getitem__(self, key): |
|
return self._buffers[key] |
|
|
|
def __setitem__(self, key, parameter): |
|
self.register_buffer(key, parameter, persistent=False) |
|
|
|
def matrix_to_axisangle(r): |
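    """Convert rotation matrices of shape (..., 3, 3) to an (angle, unit axis) pair; the
    angle keeps a trailing singleton dimension. Singular at exactly zero rotation."""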
|
th = torch.arccos(0.5 * (r[..., 0, 0] + r[..., 1, 1] + r[..., 2, 2] - 1.))[..., None] |
|
vec = 0.5 * torch.stack([ |
|
r[..., 2, 1] - r[..., 1, 2], |
|
r[..., 0, 2] - r[..., 2, 0], |
|
r[..., 1, 0] - r[..., 0, 1]], dim=-1) / torch.sin(th) |
|
return th, vec |
|
|
|
@torch.jit.script |
|
def axisangle_to_matrix(rvec : torch.Tensor): |
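    """TorchScript-compatible Rodrigues formula for axis-angle vectors of shape (..., 3)."""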
|
theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=-1)) |
|
rvec = rvec / theta[..., None] |
|
costh = torch.cos(theta) |
|
sinth = torch.sin(theta) |
|
return torch.stack(( |
|
torch.stack((rvec[..., 0] ** 2 + (1. - rvec[..., 0] ** 2) * costh, |
|
rvec[..., 0] * rvec[..., 1] * (1. - costh) - rvec[..., 2] * sinth, |
|
rvec[..., 0] * rvec[..., 2] * (1. - costh) + rvec[..., 1] * sinth), dim=-1), |
|
|
|
torch.stack((rvec[..., 0] * rvec[..., 1] * (1. - costh) + rvec[..., 2] * sinth, |
|
rvec[..., 1] ** 2 + (1. - rvec[..., 1] ** 2) * costh, |
|
rvec[..., 1] * rvec[..., 2] * (1. - costh) - rvec[..., 0] * sinth), dim=-1), |
|
|
|
torch.stack((rvec[..., 0] * rvec[..., 2] * (1. - costh) - rvec[..., 1] * sinth, |
|
rvec[..., 1] * rvec[..., 2] * (1. - costh) + rvec[..., 0] * sinth, |
|
rvec[..., 2] ** 2 + (1. - rvec[..., 2] ** 2) * costh), dim=-1)), |
|
dim=-2) |
|
|
|
def rotation_interp(r0, r1, alpha): |
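    """Interpolate between rotation matrices r0 and r1 by the fraction alpha, by scaling the
    axis-angle representation of the relative rotation (a slerp for rotation matrices)."""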
|
r0a = r0.view(-1, 3, 3) |
|
r1a = r1.view(-1, 3, 3) |
|
r = torch.bmm(r0a.permute(0, 2, 1), r1a).view_as(r0) |
|
|
|
th, rvec = matrix_to_axisangle(r) |
|
rvec = rvec * (alpha * th) |
|
|
|
r = axisangle_to_matrix(rvec) |
|
return torch.bmm(r0a, r.view(-1, 3, 3)).view_as(r0) |
|
|
|
def fuse(trainiter=None, renderoptions={}): |
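    """Return a callable for nn.Module.apply() that invokes fuse() on every submodule that
    defines it, forwarding (trainiter, renderoptions) when the submodule's fuse() takes
    arguments."""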
|
def _fuse(m): |
|
if hasattr(m, "fuse") and isinstance(m, torch.nn.Module): |
|
if m.fuse.__code__.co_argcount > 1: |
|
m.fuse(trainiter, renderoptions) |
|
else: |
|
m.fuse() |
|
return _fuse |
|
|
|
def no_grad(m): |
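    """Freeze all parameters of m by setting requires_grad to False."""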
|
for p in m.parameters(): |
|
p.requires_grad = False |
|
|