# TATS
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move dimension `src_dim` of `x` to position `dest_dim`, keeping the other dims in order."""
    n_dims = len(x.shape)
    if src_dim < 0:
        src_dim = n_dims + src_dim
    if dest_dim < 0:
        dest_dim = n_dims + dest_dim
    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims

    # Build the permutation: drop src_dim, then re-insert it at dest_dim.
    dims = list(range(n_dims))
    del dims[src_dim]
    permutation = []
    ctr = 0
    for i in range(n_dims):
        if i == dest_dim:
            permutation.append(src_dim)
        else:
            permutation.append(dims[ctr])
            ctr += 1
    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x
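
# Illustrative example (not part of the original code): shift_dim moves a single axis,
# e.g. shift_dim(torch.randn(2, 3, 4), 1, -1) has shape (2, 4, 3).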


def silu(x):
    return x * torch.sigmoid(x)


class SiLU(nn.Module):
    def __init__(self):
        super(SiLU, self).__init__()

    def forward(self, x):
        return silu(x)


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real)) +
        torch.mean(torch.nn.functional.softplus(logits_fake)))
    return d_loss
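
# Quick sanity check with illustrative values (not from the original code):
# hinge_d_loss(torch.tensor([0.5]), torch.tensor([-0.5])) gives 0.5,
# since relu(1 - 0.5) = 0.5 and relu(1 + (-0.5)) = 0.5.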


def Normalize(in_channels, norm_type='group'):
    assert norm_type in ['group', 'batch']
    if norm_type == 'group':
        return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    elif norm_type == 'batch':
        return torch.nn.SyncBatchNorm(in_channels)


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group',
                 padding_type='replicate'):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, norm_type)
        self.conv1 = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type)
        self.dropout = torch.nn.Dropout(dropout)
        # norm2 sees the output of conv1, which has out_channels channels
        self.norm2 = Normalize(out_channels, norm_type)
        self.conv2 = SamePadConv3d(out_channels, out_channels, kernel_size=3, padding_type=padding_type)
        if self.in_channels != self.out_channels:
            self.conv_shortcut = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = silu(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = silu(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x)
        return x + h


# Does not support dilation
class SamePadConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        # assumes that the input shape is divisible by stride
        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input
        self.padding_type = padding_type

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=0, bias=bias)
        self.weight = self.conv.weight

    def forward(self, x):
        return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))


class SamePadConvTranspose3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input
        self.padding_type = padding_type

        # Pre-padding the input and setting padding=k-1 on the transposed conv
        # yields output size = stride * input size for the kernel/stride pairs
        # used below (e.g. kernel 4 with stride 1 or 2).
        self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
                                        stride=stride, bias=bias,
                                        padding=tuple([k - 1 for k in kernel_size]))

    def forward(self, x):
        return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))
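
# Shape sketch (illustrative, assuming sizes divisible by the stride):
#   SamePadConv3d(C, C', 4, stride=(1, 2, 2))         : (B, C, T, H, W) -> (B, C', T, H // 2, W // 2)
#   SamePadConvTranspose3d(C, C', 4, stride=(1, 2, 2)): (B, C, T, H, W) -> (B, C', T, 2 * H, 2 * W)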


class Encoder(nn.Module):
    def __init__(self, n_hiddens, downsample, z_channels, image_channel=3, norm_type='group',
                 padding_type='replicate', res_num=1):
        super().__init__()
        # downsample holds one factor per (T, H, W) dimension; each block applies
        # stride 2 along the dims that still need downsampling, and stride[0] below
        # is forced to 1 so the temporal resolution is never reduced.
        n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
        self.conv_blocks = nn.ModuleList()
        max_ds = n_times_downsample.max()

        self.conv_first = SamePadConv3d(image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)

        for i in range(max_ds):
            block = nn.Module()
            in_channels = n_hiddens * 2 ** i
            out_channels = n_hiddens * 2 ** (i + 1)
            stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
            stride = list(stride)
            stride[0] = 1
            stride = tuple(stride)
            block.down = SamePadConv3d(in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
            block.res = ResBlock(out_channels, out_channels, norm_type=norm_type)
            self.conv_blocks.append(block)
            n_times_downsample -= 1

        self.final_block = nn.Sequential(
            Normalize(out_channels, norm_type),
            SiLU(),
            SamePadConv3d(out_channels, z_channels,
                          kernel_size=3,
                          stride=1,
                          padding_type=padding_type)
        )

        self.out_channels = out_channels

    def forward(self, x):
        h = self.conv_first(x)
        for block in self.conv_blocks:
            h = block.down(h)
            h = block.res(h)
        h = self.final_block(h)
        return h


class Decoder(nn.Module):
    def __init__(self, n_hiddens, upsample, z_channels, image_channel, norm_type='group'):
        super().__init__()
        n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
        max_us = n_times_upsample.max()
        in_channels = z_channels
        self.conv_blocks = nn.ModuleList()

        for i in range(max_us):
            block = nn.Module()
            # Channel widths mirror the encoder: widest right after the latent,
            # halving with every upsampling stage.
            in_channels = in_channels if i == 0 else n_hiddens * 2 ** (max_us - i + 1)
            out_channels = n_hiddens * 2 ** (max_us - i)
            us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
            us = list(us)
            us[0] = 1
            us = tuple(us)
            block.up = SamePadConvTranspose3d(in_channels, out_channels, 4, stride=us)
            block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            self.conv_blocks.append(block)
            n_times_upsample -= 1

        self.conv_out = SamePadConv3d(out_channels, image_channel, kernel_size=3)

    def forward(self, x):
        h = x
        for block in self.conv_blocks:
            h = block.up(h)
            h = block.res1(h)
            h = block.res2(h)
        h = self.conv_out(h)
        return h


# unit test
if __name__ == '__main__':
    # one downsample factor per (T, H, W) dim; temporal resolution stays fixed
    encoder = Encoder(n_hiddens=320, downsample=[1, 2, 2], z_channels=8, image_channel=96,
                      norm_type='group', padding_type='replicate')
    encoder = encoder.cuda()

    en_input = torch.rand(1, 96, 3, 256, 256).cuda()
    out = encoder(en_input)
    print(out.shape)  # with this config: torch.Size([1, 8, 3, 128, 128])
    # mean, logvar = torch.chunk(out, 2, dim=1)
    # # print(mean.shape)
    # decoder = DecoderRe(n_hiddens=320, upsample=[2, 2, 2, 1], z_channels=8, image_channel=96,
    #                     norm_type='group')
    #
    # decoder = decoder.cuda()
    # out = decoder(mean)
    # print(out.shape)
    # # logvar = nn.Parameter(torch.ones(size=()) * 0.0)
    # # print(logvar)
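
    # A minimal decoder round-trip sketch (not part of the original test): the
    # hyperparameters below simply mirror the encoder call above and are
    # illustrative, not canonical; memory use is substantial at this resolution.
    # With upsample=[1, 2, 2] the Decoder maps the (1, 8, 3, 128, 128) latent
    # back to the input resolution.
    decoder = Decoder(n_hiddens=320, upsample=[1, 2, 2], z_channels=8, image_channel=96,
                      norm_type='group')
    decoder = decoder.cuda()
    recon = decoder(out)
    print(recon.shape)  # expected: torch.Size([1, 96, 3, 256, 256])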