|
import numpy as np |
|
from functools import partial |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.utils.checkpoint import checkpoint |
|
|
|
from utils.typing import * |
|
from .attention import MemEffAttention |
|
|
|
class VolumeAttention(nn.Module): |
|
def __init__( |
|
self, |
|
dim: int, |
|
num_heads: int = 8, |
|
qkv_bias: bool = False, |
|
proj_bias: bool = True, |
|
attn_drop: float = 0.0, |
|
proj_drop: float = 0.0, |
|
groups: int = 32, |
|
eps: float = 1e-5, |
|
residual: bool = True, |
|
skip_scale: float = 1, |
|
): |
|
super().__init__() |
|
|
|
self.residual = residual |
|
self.skip_scale = skip_scale |
|
|
|
self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True) |
|
self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop) |
|
|
|
def forward(self, x): |
|
|
|
B, C, H, W, D = x.shape |
|
|
|
res = x |
|
x = self.norm(x) |
|
|
|
x = x.permute(0, 2, 3, 4, 1).reshape(B, -1, C) |
|
x = self.attn(x) |
|
x = x.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3).reshape(B, C, H, W, D) |
|
|
|
if self.residual: |
|
x = (x + res) * self.skip_scale |
|
|
|
return x |
|
|
|
class DiagonalGaussianDistribution: |
|
def __init__(self, parameters, deterministic=False): |
|
|
|
self.parameters = parameters |
|
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) |
|
self.logvar = torch.clamp(self.logvar, -30.0, 20.0) |
|
self.deterministic = deterministic |
|
self.std = torch.exp(0.5 * self.logvar) |
|
self.var = torch.exp(self.logvar) |
|
if self.deterministic: |
|
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype) |
|
|
|
def sample(self): |
|
sample = torch.randn(self.mean.shape, device=self.parameters.device, dtype=self.parameters.dtype) |
|
x = self.mean + self.std * sample |
|
return x |
|
|
|
def kl(self, other=None, dims=[1, 2, 3, 4]): |
|
if self.deterministic: |
|
return torch.Tensor([0.0]) |
|
else: |
|
if other is None: |
|
return 0.5 * torch.mean(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims) |
|
else: |
|
return 0.5 * torch.mean( |
|
torch.pow(self.mean - other.mean, 2) / other.var |
|
+ self.var / other.var |
|
- 1.0 |
|
- self.logvar |
|
+ other.logvar, |
|
dim=dims, |
|
) |
|
|
|
def nll(self, sample, dims=[1, 2, 3, 4]): |
|
if self.deterministic: |
|
return torch.Tensor([0.0]) |
|
logtwopi = np.log(2.0 * np.pi) |
|
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) |
|
|
|
def mode(self): |
|
return self.mean |
|
|
|
|
|
class ResnetBlock(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
resample: Literal['default', 'up', 'down'] = 'default', |
|
groups: int = 32, |
|
eps: float = 1e-5, |
|
skip_scale: float = 1, |
|
): |
|
super().__init__() |
|
|
|
self.in_channels = in_channels |
|
self.out_channels = out_channels |
|
self.skip_scale = skip_scale |
|
|
|
self.norm1 = nn.GroupNorm(num_groups=min(groups, in_channels), num_channels=in_channels, eps=eps, affine=True) |
|
self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) |
|
|
|
self.norm2 = nn.GroupNorm(num_groups=min(groups, out_channels), num_channels=out_channels, eps=eps, affine=True) |
|
self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) |
|
|
|
self.act = F.silu |
|
|
|
self.resample = None |
|
if resample == 'up': |
|
self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest") |
|
elif resample == 'down': |
|
self.resample = nn.AvgPool3d(kernel_size=2, stride=2) |
|
|
|
self.shortcut = nn.Identity() |
|
if self.in_channels != self.out_channels: |
|
self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=True) |
|
|
|
|
|
def forward(self, x): |
|
res = x |
|
|
|
x = self.norm1(x) |
|
x = self.act(x) |
|
|
|
if self.resample: |
|
res = self.resample(res) |
|
x = self.resample(x) |
|
|
|
x = self.conv1(x) |
|
x = self.norm2(x) |
|
x = self.act(x) |
|
x = self.conv2(x) |
|
|
|
x = (x + self.shortcut(res)) * self.skip_scale |
|
|
|
return x |
|
|
|
class DownBlock(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
num_layers: int = 1, |
|
downsample: bool = True, |
|
skip_scale: float = 1, |
|
gradient_checkpointing: bool = False, |
|
): |
|
super().__init__() |
|
self.gradient_checkpointing = gradient_checkpointing |
|
|
|
nets = [] |
|
for i in range(num_layers): |
|
cin = in_channels if i == 0 else out_channels |
|
nets.append(ResnetBlock(cin, out_channels, skip_scale=skip_scale)) |
|
self.nets = nn.ModuleList(nets) |
|
|
|
self.downsample = None |
|
if downsample: |
|
self.downsample = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=2, padding=1) |
|
|
|
def forward(self, x): |
|
if self.training and self.gradient_checkpointing: |
|
return checkpoint(self._forward, x, use_reentrant=False) |
|
else: |
|
return self._forward(x) |
|
|
|
def _forward(self, x): |
|
|
|
for net in self.nets: |
|
x = net(x) |
|
|
|
if self.downsample: |
|
x = self.downsample(x) |
|
|
|
return x |
|
|
|
|
|
class MidBlock(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
num_layers: int = 1, |
|
attention: bool = True, |
|
attention_heads: int = 8, |
|
skip_scale: float = 1, |
|
gradient_checkpointing: bool = False, |
|
): |
|
super().__init__() |
|
self.gradient_checkpointing = gradient_checkpointing |
|
|
|
nets = [] |
|
attns = [] |
|
|
|
nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) |
|
|
|
for i in range(num_layers): |
|
nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) |
|
if attention: |
|
attns.append(VolumeAttention(in_channels, attention_heads, skip_scale=skip_scale)) |
|
else: |
|
attns.append(None) |
|
self.nets = nn.ModuleList(nets) |
|
self.attns = nn.ModuleList(attns) |
|
|
|
def forward(self, x): |
|
if self.training and self.gradient_checkpointing: |
|
return checkpoint(self._forward, x, use_reentrant=False) |
|
else: |
|
return self._forward(x) |
|
|
|
def _forward(self, x): |
|
x = self.nets[0](x) |
|
for attn, net in zip(self.attns, self.nets[1:]): |
|
if attn: |
|
x = attn(x) |
|
x = net(x) |
|
return x |
|
|
|
|
|
class UpBlock(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
num_layers: int = 1, |
|
upsample: bool = True, |
|
skip_scale: float = 1, |
|
gradient_checkpointing: bool = False, |
|
): |
|
super().__init__() |
|
self.gradient_checkpointing = gradient_checkpointing |
|
|
|
nets = [] |
|
for i in range(num_layers): |
|
cin = in_channels if i == 0 else out_channels |
|
nets.append(ResnetBlock(cin, out_channels, skip_scale=skip_scale)) |
|
|
|
self.nets = nn.ModuleList(nets) |
|
|
|
self.upsample = None |
|
if upsample: |
|
self.upsample = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2) |
|
|
|
def forward(self, x): |
|
if self.training and self.gradient_checkpointing: |
|
return checkpoint(self._forward, x, use_reentrant=False) |
|
else: |
|
return self._forward(x) |
|
|
|
def _forward(self, x): |
|
|
|
for net in self.nets: |
|
x = net(x) |
|
|
|
if self.upsample: |
|
x = self.upsample(x) |
|
|
|
return x |
|
|
|
|
|
class Encoder(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int = 1, |
|
out_channels: int = 2 * 16, |
|
down_channels: Tuple[int, ...] = (8, 16, 32, 64), |
|
mid_attention: bool = True, |
|
layers_per_block: int = 2, |
|
skip_scale: float = np.sqrt(0.5), |
|
gradient_checkpointing: bool = False, |
|
): |
|
super().__init__() |
|
|
|
|
|
self.conv_in = nn.Conv3d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1) |
|
|
|
|
|
down_blocks = [] |
|
cout = down_channels[0] |
|
for i in range(len(down_channels)): |
|
cin = cout |
|
cout = down_channels[i] |
|
|
|
down_blocks.append(DownBlock( |
|
cin, cout, |
|
num_layers=layers_per_block, |
|
downsample=(i != len(down_channels) - 1), |
|
skip_scale=skip_scale, |
|
gradient_checkpointing=gradient_checkpointing, |
|
)) |
|
self.down_blocks = nn.ModuleList(down_blocks) |
|
|
|
|
|
self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale) |
|
|
|
|
|
self.norm_out = nn.GroupNorm(num_channels=down_channels[-1], num_groups=32, eps=1e-5) |
|
self.conv_out = nn.Conv3d(down_channels[-1], out_channels, kernel_size=3, stride=1, padding=1) |
|
|
|
def forward(self, x): |
|
|
|
|
|
|
|
x = self.conv_in(x) |
|
|
|
|
|
for block in self.down_blocks: |
|
x = block(x) |
|
|
|
|
|
x = self.mid_block(x) |
|
|
|
|
|
x = self.norm_out(x) |
|
x = F.silu(x) |
|
x = self.conv_out(x) |
|
|
|
return x |
|
|
|
|
|
class Decoder(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int = 16, |
|
out_channels: int = 1, |
|
up_channels: Tuple[int, ...] = (64, 32, 16, 8), |
|
mid_attention: bool = True, |
|
layers_per_block: int = 2, |
|
skip_scale: float = np.sqrt(0.5), |
|
gradient_checkpointing: bool = False, |
|
): |
|
super().__init__() |
|
|
|
|
|
self.conv_in = nn.Conv3d(in_channels, up_channels[0], kernel_size=3, stride=1, padding=1) |
|
|
|
|
|
self.mid_block = MidBlock(up_channels[0], attention=mid_attention, skip_scale=skip_scale) |
|
|
|
|
|
up_blocks = [] |
|
cout = up_channels[0] |
|
for i in range(len(up_channels)): |
|
cin = cout |
|
cout = up_channels[i] |
|
|
|
up_blocks.append(UpBlock( |
|
cin, cout, |
|
num_layers=layers_per_block, |
|
upsample=(i != len(up_channels) - 1), |
|
skip_scale=skip_scale, |
|
gradient_checkpointing=gradient_checkpointing, |
|
)) |
|
self.up_blocks = nn.ModuleList(up_blocks) |
|
|
|
|
|
self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=min(32, up_channels[-1]), eps=1e-5) |
|
self.conv_out = nn.ConvTranspose3d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1) |
|
|
|
def forward(self, x): |
|
|
|
|
|
|
|
x = self.conv_in(x) |
|
|
|
|
|
x = self.mid_block(x) |
|
|
|
|
|
for block in self.up_blocks: |
|
x = block(x) |
|
|
|
|
|
x = self.norm_out(x) |
|
x = F.silu(x) |
|
x = self.conv_out(x) |
|
|
|
return x |
|
|
|
|
|
class VAE(nn.Module): |
|
def __init__( |
|
self, |
|
in_channels: int = 1, |
|
latent_channels: int = 16, |
|
out_channels: int = 1, |
|
down_channels: Tuple[int, ...] = (16, 32, 64, 128, 256), |
|
mid_attention: bool = True, |
|
up_channels: Tuple[int, ...] = (256, 128, 64, 32, 16), |
|
layers_per_block: int = 2, |
|
skip_scale: float = np.sqrt(0.5), |
|
gradient_checkpointing: bool = False, |
|
): |
|
super().__init__() |
|
|
|
|
|
self.encoder = Encoder( |
|
in_channels=in_channels, |
|
out_channels=2 * latent_channels, |
|
down_channels=down_channels, |
|
mid_attention=mid_attention, |
|
layers_per_block=layers_per_block, |
|
skip_scale=skip_scale, |
|
gradient_checkpointing=gradient_checkpointing, |
|
) |
|
|
|
|
|
self.decoder = Decoder( |
|
in_channels=latent_channels, |
|
out_channels=out_channels, |
|
up_channels=up_channels, |
|
mid_attention=mid_attention, |
|
layers_per_block=layers_per_block, |
|
skip_scale=skip_scale, |
|
gradient_checkpointing=gradient_checkpointing, |
|
) |
|
|
|
|
|
self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) |
|
self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) |
|
|
|
def encode(self, x): |
|
x = self.encoder(x) |
|
x = self.quant_conv(x) |
|
posterior = DiagonalGaussianDistribution(x) |
|
return posterior |
|
|
|
def decode(self, x): |
|
x = self.post_quant_conv(x) |
|
x = self.decoder(x) |
|
return x |
|
|
|
def forward(self, x, sample=True): |
|
|
|
|
|
p = self.encode(x) |
|
|
|
if sample: |
|
z = p.sample() |
|
else: |
|
z = p.mode() |
|
|
|
x = self.decode(z) |
|
|
|
return x, p |