from typing import Optional, List

import torch
import torch.nn as nn

from models.utils import LinearELR, Conv2dELR, Downsample2d


class Encoder(torch.nn.Module):
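    """Convolutional image encoder.

    Maps ``ninputs`` RGB views to a 256-dimensional Gaussian latent code
    (mu, logstd) and, optionally, the corresponding KL-divergence loss.
    """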

    def __init__(self, ninputs, size, nlayers=7, conv=Conv2dELR, lin=LinearELR):
        super(Encoder, self).__init__()

        self.ninputs = ninputs
        height, width = size
        self.nlayers = nlayers

        # zero-pad the input so that height and width are divisible by 2 ** nlayers
        ypad = ((height + 2 ** nlayers - 1) // 2 ** nlayers) * 2 ** nlayers - height
        xpad = ((width + 2 ** nlayers - 1) // 2 ** nlayers) * 2 ** nlayers - width
        self.pad = nn.ZeroPad2d((xpad // 2, xpad - xpad // 2, ypad // 2, ypad - ypad // 2))

        # spatial size of the feature maps after nlayers stride-2 convolutions
        self.downwidth = (width + 2 ** nlayers - 1) // 2 ** nlayers
        self.downheight = (height + 2 ** nlayers - 1) // 2 ** nlayers

        # downsampling convolution stack: channels go from 3 to 64 and then double
        # every other layer, capped at 256
        layers = []
        inch, outch = 3, 64
        for i in range(nlayers):
            layers.append(conv(inch, outch, 4, 2, 1, norm="demod", act=nn.LeakyReLU(0.2)))

            if inch == outch:
                outch = inch * 2
            else:
                inch = outch
            if outch > 256:
                outch = 256

        # the same convolution stack instance is reused for every input view, so the
        # per-view branches share their weights
        self.down1 = nn.ModuleList([nn.Sequential(*layers) for i in range(self.ninputs)])
        self.down2 = lin(256 * self.ninputs * self.downwidth * self.downheight, 512,
                         norm="demod", act=nn.LeakyReLU(0.2))
        self.mu = lin(512, 256)
        self.logstd = lin(512, 256)

    def forward(self, x, losslist: Optional[List[str]] = None):
        assert losslist is not None

        # pad, run each 3-channel view through its convolution branch, and flatten
        # the resulting feature maps
        x = self.pad(x)
        x = [self.down1[i](x[:, i*3:(i+1)*3, :, :]).view(x.size(0), 256 * self.downwidth * self.downheight)
             for i in range(self.ninputs)]
        x = torch.cat(x, dim=1)
        x = self.down2(x)

        # predict the latent Gaussian; sample with the reparameterization trick during
        # training, use the mean at evaluation time
        mu, logstd = self.mu(x) * 0.1, self.logstd(x) * 0.01
        if self.training:
            z = mu + torch.exp(logstd) * torch.randn(*logstd.size(), device=logstd.device)
        else:
            z = mu

        losses = {}
        if "kldiv" in losslist:
            # KL(N(mu, exp(logstd)^2) || N(0, 1)), averaged over the latent dimensions
            losses["kldiv"] = torch.mean(-0.5 - logstd + 0.5 * mu ** 2 + 0.5 * torch.exp(2 * logstd), dim=-1)

        return {"encoding": z}, losses
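

# A minimal usage sketch, not part of the original module: it assumes the ELR layers
# from models.utils behave like standard conv/linear layers, and the input size
# (512, 384) is made up purely so the spatial dimensions divide 2 ** nlayers evenly.
if __name__ == "__main__":
    encoder = Encoder(ninputs=2, size=(512, 384))
    views = torch.randn(4, 2 * 3, 512, 384)            # batch of 4, two RGB views each
    out, losses = encoder(views, losslist=["kldiv"])
    print(out["encoding"].shape)                        # torch.Size([4, 256])
    print(losses["kldiv"].shape)                        # torch.Size([4])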