|
|
|
|
|
|
|
|
|
|
|
"""Neural Volumes decoder """ |
|
import math |
|
from typing import Optional, Dict, List |
|
|
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
import models.utils |
|
from models.utils import LinearELR, ConvTranspose2dELR, ConvTranspose3dELR |
|
|
|
class Reshape(nn.Module):
    """Module wrapper that reshapes incoming tensors to a fixed target shape.

    The target shape is given at construction time and applied with
    ``Tensor.view`` (so it may contain a single ``-1`` wildcard dimension).
    """

    def __init__(self, *args):
        super().__init__()
        # target dimensions, stored exactly as passed (e.g. (-1, C, D, H, W))
        self.shape = args

    def forward(self, x):
        """Return a view of ``x`` with the configured shape."""
        return x.view(self.shape)
|
|
|
class ContentDecoder(nn.Module):
    """Decodes a latent vector into a dense voxel payload (e.g. RGB or alpha).

    A linear layer maps the encoding to a small ``hstart^3`` feature volume,
    which a stack of strided 3d transposed convolutions then upsamples,
    doubling the spatial resolution per layer until ``primsize`` is reached.
    """

    def __init__(self, primsize, inch, outch, chstart=256, hstart=4,
            texwarp=False, elr=True, norm=None, mod=False, ub=True, upconv=None,
            penultch=None):
        """
        Parameters
        ----------
        primsize : int
            target voxel resolution (cube side length) of the decoded volume
        inch : int
            number of input (encoding) channels
        outch : int
            number of output channels of the final layer
        chstart : int
            channel count at the initial hstart^3 resolution
        hstart : int
            initial spatial resolution before upsampling
        texwarp : bool
            unsupported legacy option; must be False
        elr, norm, mod, ub :
            options forwarded to the ELR layers (``mod``/``elr`` are accepted
            for interface compatibility; only ``norm``/``ub`` are used here)
        upconv :
            unsupported legacy option; must be None
        penultch : int, optional
            channel override for the penultimate upconv layer

        Raises
        ------
        ValueError
            if an unsupported legacy option is requested
        """
        super(ContentDecoder, self).__init__()

        # validate with real exceptions; asserts are stripped under `python -O`
        if texwarp:
            raise ValueError("texwarp is not supported by ContentDecoder")
        if upconv is not None:
            raise ValueError("upconv is not supported by ContentDecoder")

        self.primsize = primsize

        # number of 2x upsampling steps to go from hstart to primsize
        nlayers = int(math.log2(self.primsize / hstart))

        lastch = chstart
        dims = (hstart, hstart, hstart)

        layers = []
        layers.append(LinearELR(inch, chstart*dims[0]*dims[1]*dims[2], act=nn.LeakyReLU(0.2)))
        layers.append(Reshape(-1, chstart, dims[0], dims[1], dims[2]))

        for i in range(nlayers):
            # halve the channel count on every other layer
            nextch = lastch if i % 2 == 0 else lastch // 2

            # optionally force the channel count of the penultimate layer
            if i == nlayers - 2 and penultch is not None:
                nextch = penultch

            layers.append(ConvTranspose3dELR(
                lastch,
                (outch if i == nlayers - 1 else nextch),
                4, 2, 1,
                ub=(dims[0]*2, dims[1]*2, dims[2]*2) if ub else None,
                # final layer is linear: no normalization, no activation
                norm=None if i == nlayers - 1 else norm,
                act=None if i == nlayers - 1 else nn.LeakyReLU(0.2)
            ))

            lastch = nextch
            dims = (dims[0] * 2, dims[1] * 2, dims[2] * 2)

        self.mod = nn.Sequential(*layers)

    def forward(self, enc, renderoptions : Dict[str, str], trainiter : Optional[int]=None):
        """Decode ``enc`` into a volume with a singleton primitive axis.

        Returns [B, 1, C, D, H, W], or channels-last [B, 1, D, H, W, C] when
        renderoptions["chlast"] is set. ``trainiter`` is accepted for
        interface compatibility and unused here.
        """
        x = self.mod(enc)

        chlast = renderoptions.get("chlast")

        # NOTE(review): chlast is a string option and bool("False") is True,
        # so ANY non-empty value enables the channels-last path — confirm
        # callers only pass it when they mean it.
        if chlast is not None and bool(chlast):
            # move channels last, then insert the primitive axis
            x = x.permute(0, 2, 3, 4, 1)[:, None, :, :, :, :].contiguous()
        else:
            x = x[:, None, :, :, :, :].contiguous()

        return x
|
|
|
def get_dec(dectype, **kwargs):
    """Construct a content decoder of the requested type.

    Parameters
    ----------
    dectype : str
        decoder type; only "conv" (ContentDecoder) is currently supported
    **kwargs
        forwarded to the decoder constructor

    Raises
    ------
    ValueError
        if ``dectype`` is not recognized. (The previous bare ``raise`` here
        had no active exception and itself failed with
        "RuntimeError: No active exception to re-raise".)
    """
    if dectype == "conv":
        return ContentDecoder(**kwargs)
    raise ValueError(f"unknown decoder type: {dectype!r}")
|
|
|
class Decoder(nn.Module):
    def __init__(self,
            volradius,
            dectype="conv",
            primsize=128,
            chstart=256,
            penultch=None,
            condsize=0,
            warptype="conv",
            warpprimsize=32,
            sharedrgba=False,
            norm=None,
            mod=False,
            elr=True,
            notplateact=False,
            postrainstart=-1,
            alphatrainstart=-1,
            renderoptions=None,
            **kwargs):
        """Single-primitive Neural Volumes decoder (content + optional warp).

        Parameters
        ----------
        volradius : float
            radius of bounding volume of scene
        dectype : string
            type of content decoder; only "conv" is currently supported
        primsize : int
            voxel resolution (cube side) of the content primitive
        chstart : int
            channel count at the decoder's smallest spatial resolution
        penultch : int, optional
            channel override for the penultimate decoder layer
        condsize : int
            size of additional conditioning input appended to the encoding
        warptype : string
            warp decoder type ("conv"), or None to disable the warp field
        warpprimsize : int
            voxel resolution of the warp primitive
        sharedrgba : bool
            True to use 1 branch to output rgba, False to use 1 branch for rgb
            and 1 branch for alpha
        norm, mod, elr :
            options forwarded to the content/warp decoders
        notplateact : bool
            True to skip the ReLU applied to decoded rgb/alpha
        postrainstart : int
            training iterations to start learning position, rotation, and
            scaling (i.e., primitives stay frozen until this iteration number)
        alphatrainstart : int
            iterations before which decoded alpha is forced to 1
        renderoptions : dict, optional
            rendering options (e.g., "half", "chlastconv"); defaults to {}
        """
        super(Decoder, self).__init__()

        # guard against the shared mutable-default-argument pitfall
        # (previously `renderoptions={}` in the signature)
        if renderoptions is None:
            renderoptions = {}

        self.volradius = volradius
        self.postrainstart = postrainstart
        self.alphatrainstart = alphatrainstart

        self.primsize = primsize
        self.warpprimsize = warpprimsize

        self.notplateact = notplateact

        # maps (encoding [+ conditioning]) back to a 256-d code
        self.enc = LinearELR(256 + condsize, 256)

        if sharedrgba:
            # NOTE(review): this branch does not forward `chstart` to the
            # decoder (the rgb/alpha branch below does) — confirm intentional.
            self.rgbadec = get_dec(dectype, primsize=primsize,
                    inch=256+3, outch=4, norm=norm, mod=mod, elr=elr,
                    penultch=penultch, **kwargs)

            if renderoptions.get("half", False):
                self.rgbadec = self.rgbadec.half()

            if renderoptions.get("chlastconv", False):
                self.rgbadec = self.rgbadec.to(memory_format=torch.channels_last)
        else:
            # rgb branch is view-conditioned (+3 input channels for viewdirs)
            self.rgbdec = get_dec(dectype, primsize=primsize,
                    inch=256+3, outch=3, chstart=chstart, norm=norm, mod=mod,
                    elr=elr, penultch=penultch, **kwargs)
            self.alphadec = get_dec(dectype, primsize=primsize,
                    inch=256, outch=1, chstart=chstart, norm=norm, mod=mod,
                    elr=elr, penultch=penultch, **kwargs)
            # sentinel: forward() dispatches on rgbadec being None
            self.rgbadec = None

            if renderoptions.get("half", False):
                self.rgbdec = self.rgbdec.half()
                self.alphadec = self.alphadec.half()

            if renderoptions.get("chlastconv", False):
                self.rgbdec = self.rgbdec.to(memory_format=torch.channels_last)
                self.alphadec = self.alphadec.to(memory_format=torch.channels_last)

        if warptype is not None:
            self.warpdec = get_dec(warptype, primsize=warpprimsize,
                    inch=256, outch=3, chstart=chstart, norm=norm, mod=mod, elr=elr, **kwargs)
        else:
            self.warpdec = None

    def forward(self,
            encoding,
            viewpos,
            condinput : Optional[torch.Tensor]=None,
            renderoptions : Optional[Dict[str, str]]=None,
            trainiter : int=-1,
            evaliter : Optional[torch.Tensor]=None,
            losslist : Optional[List[str]]=None,
            modelmatrix : Optional[torch.Tensor]=None):
        """
        Parameters
        ----------
        encoding : torch.Tensor [B, 256]
            Encoding of current frame
        viewpos : torch.Tensor [B, 3]
            Viewing position of target camera view
        condinput : torch.Tensor [B, ?]
            Additional conditioning input (e.g., headpose)
        renderoptions : dict
            Options for rendering (e.g., rendering debug images); required
        trainiter : int,
            Current training iteration
        evaliter : torch.Tensor, optional
            accepted for interface compatibility; unused here
        losslist : list,
            List of losses to compute and return; required
        modelmatrix : torch.Tensor, optional
            accepted for interface compatibility; unused here

        Returns
        -------
        result : dict,
            Contains predicted primitive contents and locations, scaling,
            and orientation ("template", "primpos", "primrot", "primscale",
            and optionally "warp")
        losses : dict,
            requested losses keyed by name
        """
        assert renderoptions is not None
        assert losslist is not None

        # append optional conditioning (e.g., head pose) to the encoding
        if condinput is not None:
            encoding = torch.cat([encoding, condinput], dim=1)

        encoding = self.enc(encoding)

        # unit view direction, conditions the view-dependent rgb decoding
        viewdirs = F.normalize(viewpos, dim=1)

        # single primitive with identity placement: centered, unrotated, unit scale
        primpos = torch.zeros(encoding.size(0), 1, 3, device=encoding.device)
        primrot = torch.eye(3, device=encoding.device)[None, None, :, :].repeat(encoding.size(0), 1, 1, 1)
        primscale = torch.ones(encoding.size(0), 1, 3, device=encoding.device)

        # NOTE(review): these are string-valued options, so bool("False") is
        # True — any non-empty value counts as enabled.
        algo = renderoptions.get("algo")
        chlast = renderoptions.get("chlast")
        half = renderoptions.get("half")

        if self.rgbadec is not None:
            # shared rgba branch: rescale rgb to roughly [0, 255], alpha kept as-is
            scale = torch.tensor([25., 25., 25., 1.], device=encoding.device)
            bias = torch.tensor([100., 100., 100., 0.], device=encoding.device)
            if chlast is not None and bool(chlast):
                # broadcast over [B, prim, D, H, W, C]
                scale = scale[None, None, None, None, None, :]
                bias = bias[None, None, None, None, None, :]
            else:
                # broadcast over [B, prim, C, D, H, W]
                scale = scale[None, None, :, None, None, None]
                bias = bias[None, None, :, None, None, None]

            templatein = torch.cat([encoding, viewdirs], dim=1)
            if half is not None and bool(half):
                templatein = templatein.half()
            template = self.rgbadec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            template = bias + scale * template
            if not self.notplateact:
                template = F.relu(template)
            if half is not None and bool(half):
                template = template.float()
        else:
            # separate rgb branch (view-conditioned)
            templatein = torch.cat([encoding, viewdirs], dim=1)
            if half is not None and bool(half):
                templatein = templatein.half()
            primrgb = self.rgbdec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            primrgb = primrgb * 25. + 100.
            if not self.notplateact:
                primrgb = F.relu(primrgb)

            # separate alpha branch (view-independent)
            templatein = encoding
            if half is not None and bool(half):
                templatein = templatein.half()
            primalpha = self.alphadec(templatein, trainiter=trainiter, renderoptions=renderoptions)
            if not self.notplateact:
                primalpha = F.relu(primalpha)

            # force fully-opaque primitives early in training
            if trainiter <= self.alphatrainstart:
                primalpha = primalpha * 0. + 1.

            # concat channel axis: last dim for algo 4 / channels-last, dim 2 otherwise
            if algo is not None and int(algo) == 4:
                template = torch.cat([primrgb, primalpha], dim=-1)
            elif chlast is not None and bool(chlast):
                template = torch.cat([primrgb, primalpha], dim=-1)
            else:
                template = torch.cat([primrgb, primalpha], dim=2)
            if half is not None and bool(half):
                template = template.float()

        if self.warpdec is not None:
            # decoded warp is a small offset added to an identity sampling grid
            warp = self.warpdec(encoding, trainiter=trainiter, renderoptions=renderoptions) * 0.01
            # [::-1] reverses the meshgrid axes to x,y,z component order
            warp = warp + torch.stack(torch.meshgrid(
                torch.linspace(-1., 1., self.warpprimsize, device=encoding.device),
                torch.linspace(-1., 1., self.warpprimsize, device=encoding.device),
                torch.linspace(-1., 1., self.warpprimsize, device=encoding.device))[::-1],
                dim=-1 if chlast is not None and bool(chlast) else 0)[None, None, :, :, :, :]
            warp = warp.contiguous()
        else:
            warp = None

        losses = {}

        # penalize total primitive volume (here constant: primscale is ones)
        if "primvolsum" in losslist:
            losses["primvolsum"] = torch.sum(torch.prod(1. / primscale, dim=-1), dim=-1)

        # penalize variance of log primitive scales
        if "logprimscalevar" in losslist:
            logprimscale = torch.log(primscale)
            logprimscalemean = torch.mean(logprimscale, dim=1, keepdim=True)
            losses["logprimscalevar"] = torch.mean((logprimscale - logprimscalemean) ** 2)

        result = {
            "template": template,
            "primpos": primpos,
            "primrot": primrot,
            "primscale": primscale}
        if warp is not None:
            result["warp"] = warp
        return result, losses
|
|