# TATS
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    # Permute x so that dimension src_dim moves to position dest_dim,
    # keeping the relative order of all other dimensions.
    n_dims = len(x.shape)
    if src_dim < 0:
        src_dim = n_dims + src_dim
    if dest_dim < 0:
        dest_dim = n_dims + dest_dim
    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims

    dims = list(range(n_dims))
    del dims[src_dim]

    permutation = []
    ctr = 0
    for i in range(n_dims):
        if i == dest_dim:
            permutation.append(src_dim)
        else:
            permutation.append(dims[ctr])
            ctr += 1
    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x


def silu(x):
    return x * torch.sigmoid(x)


class SiLU(nn.Module):
    def __init__(self):
        super(SiLU, self).__init__()

    def forward(self, x):
        return silu(x)


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(F.softplus(-logits_real)) +
        torch.mean(F.softplus(logits_fake)))
    return d_loss


def Normalize(in_channels, norm_type='group'):
    assert norm_type in ['group', 'batch']
    if norm_type == 'group':
        return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    elif norm_type == 'batch':
        return torch.nn.SyncBatchNorm(in_channels)


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0,
                 norm_type='group', padding_type='replicate'):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, norm_type)
        self.conv1 = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type)
        self.dropout = torch.nn.Dropout(dropout)
        # norm2 acts on the output of conv1, which has out_channels channels
        self.norm2 = Normalize(out_channels, norm_type)
        self.conv2 = SamePadConv3d(out_channels, out_channels, kernel_size=3, padding_type=padding_type)
        if self.in_channels != self.out_channels:
            self.conv_shortcut = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = silu(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = silu(h)
        h = self.dropout(h)  # no-op at the default rate of 0.0
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x)

        return x + h


# Does not support dilation
class SamePadConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        # assumes that the input shape is divisible by stride
        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input
        self.padding_type = padding_type

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0, bias=bias)
        self.weight = self.conv.weight

    def forward(self, x):
        return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))
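
# Shape sanity checks for shift_dim and SamePadConv3d. A minimal sketch added
# for illustration (the helper name `_shape_sanity_checks` is ours, not part
# of the original API); it is only invoked from the unit test at the bottom
# of this file.
def _shape_sanity_checks():
    # shift_dim moves one dimension to a new position, e.g. channels-last
    # (B, T, H, W, C) -> channels-first (B, C, T, H, W).
    x = torch.rand(1, 3, 16, 16, 8)
    assert shift_dim(x, -1, 1).shape == (1, 8, 3, 16, 16)

    # With "same" padding and input sizes divisible by the stride,
    # SamePadConv3d gives out_size = in_size / stride along each axis.
    conv = SamePadConv3d(3, 8, kernel_size=4, stride=(1, 2, 2))
    y = conv(torch.rand(1, 3, 5, 16, 16))  # (B, C, T, H, W)
    assert y.shape == (1, 8, 5, 8, 8)      # T kept, H and W halved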
class SamePadConvTranspose3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input
        self.padding_type = padding_type

        self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
                                        bias=bias, padding=tuple([k - 1 for k in kernel_size]))

    def forward(self, x):
        return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))


class Encoder(nn.Module):
    # `downsample` lists the per-axis (T, H, W) downsampling factors,
    # each a power of two; note that the temporal axis is never strided below.
    def __init__(self, n_hiddens, downsample, z_channels, image_channel=3, norm_type='group',
                 padding_type='replicate', res_num=1):
        super().__init__()
        n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
        self.conv_blocks = nn.ModuleList()
        max_ds = n_times_downsample.max()

        self.conv_first = SamePadConv3d(image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)

        for i in range(max_ds):
            block = nn.Module()
            in_channels = n_hiddens * 2 ** i
            out_channels = n_hiddens * 2 ** (i + 1)
            stride = [2 if d > 0 else 1 for d in n_times_downsample]
            stride[0] = 1  # the temporal axis is never strided
            stride = tuple(stride)
            block.down = SamePadConv3d(in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
            block.res = ResBlock(out_channels, out_channels, norm_type=norm_type)
            self.conv_blocks.append(block)
            n_times_downsample -= 1

        self.final_block = nn.Sequential(
            Normalize(out_channels, norm_type),
            SiLU(),
            SamePadConv3d(out_channels, z_channels, kernel_size=3, stride=1, padding_type=padding_type)
        )

        self.out_channels = out_channels

    def forward(self, x):
        h = self.conv_first(x)
        for block in self.conv_blocks:
            h = block.down(h)
            h = block.res(h)
        h = self.final_block(h)
        return h


class Decoder(nn.Module):
    def __init__(self, n_hiddens, upsample, z_channels, image_channel, norm_type='group'):
        super().__init__()
        n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
        max_us = n_times_upsample.max()

        in_channels = z_channels
        self.conv_blocks = nn.ModuleList()
        for i in range(max_us):
            block = nn.Module()
            in_channels = in_channels if i == 0 else n_hiddens * 2 ** (max_us - i + 1)
            out_channels = n_hiddens * 2 ** (max_us - i)
            us = [2 if d > 0 else 1 for d in n_times_upsample]
            us[0] = 1  # the temporal axis is never upsampled, mirroring the encoder
            us = tuple(us)
            block.up = SamePadConvTranspose3d(in_channels, out_channels, 4, stride=us)
            block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            self.conv_blocks.append(block)
            n_times_upsample -= 1

        self.conv_out = SamePadConv3d(out_channels, image_channel, kernel_size=3)

    def forward(self, x):
        h = x
        for block in self.conv_blocks:
            h = block.up(h)
            h = block.res1(h)
            h = block.res2(h)
        h = self.conv_out(h)
        return h


# unit test
if __name__ == '__main__':
    encoder = Encoder(n_hiddens=320, downsample=[1, 4, 4], z_channels=8, image_channel=96,
                      norm_type='group', padding_type='replicate')
    encoder = encoder.cuda()
    en_input = torch.rand(1, 96, 3, 256, 256).cuda()
    out = encoder(en_input)
    print(out.shape)  # expected: (1, 8, 3, 64, 64)

    # mean, logvar = torch.chunk(out, 2, dim=1)
    # print(mean.shape)
    # decoder = DecoderRe(n_hiddens=320, upsample=[2, 2, 2, 1], z_channels=8, image_channel=96,
    #                     norm_type='group')
    # decoder = decoder.cuda()
    # out = decoder(mean)
    # print(out.shape)
    # logvar = nn.Parameter(torch.ones(size=()) * 0.0)
    # print(logvar)
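
    # A minimal roundtrip sketch added for illustration: decode the encoder
    # output with the Decoder defined above, assuming `upsample` mirrors the
    # encoder's `downsample`, then run the CPU-only shape checks defined
    # after SamePadConv3d.
    decoder = Decoder(n_hiddens=320, upsample=[1, 4, 4], z_channels=8, image_channel=96,
                      norm_type='group').cuda()
    recon = decoder(out)
    print(recon.shape)  # expected: (1, 96, 3, 256, 256)

    _shape_sanity_checks()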