# TATS
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
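    """Move dimension src_dim of x to position dest_dim, preserving the order of the remaining dims."""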
    n_dims = len(x.shape)
    if src_dim < 0:
        src_dim = n_dims + src_dim
    if dest_dim < 0:
        dest_dim = n_dims + dest_dim
    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims

    dims = list(range(n_dims))
    del dims[src_dim]

    permutation = []
    ctr = 0
    for i in range(n_dims):
        if i == dest_dim:
            permutation.append(src_dim)
        else:
            permutation.append(dims[ctr])
            ctr += 1
    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x


def silu(x):
    return x * torch.sigmoid(x)


class SiLU(nn.Module):
    def __init__(self):
        super(SiLU, self).__init__()

    def forward(self, x):
        return silu(x)


def hinge_d_loss(logits_real, logits_fake):
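    """Hinge discriminator loss: penalizes real logits below +1 and fake logits above -1."""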
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
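    """Vanilla (non-saturating) discriminator loss via softplus:
    0.5 * (mean(-log sigmoid(logits_real)) + mean(-log(1 - sigmoid(logits_fake))))."""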
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real)) +
        torch.mean(torch.nn.functional.softplus(logits_fake)))
    return d_loss


def Normalize(in_channels, norm_type='group'):
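    """Return a normalization layer: GroupNorm with 32 groups, or SyncBatchNorm (intended for distributed training)."""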
    assert norm_type in ['group', 'batch']
    if norm_type == 'group':
        return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    elif norm_type == 'batch':
        return torch.nn.SyncBatchNorm(in_channels)


class ResBlock(nn.Module):
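    """Pre-activation 3D residual block: (norm -> SiLU -> conv) twice, with a conv shortcut when the
    channel count changes. Note that the dropout module is constructed but not applied in forward."""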
    def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group',
                 padding_type='replicate'):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels, norm_type)
        self.conv1 = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type)
        self.dropout = torch.nn.Dropout(dropout)
        # norm2 is applied to the output of conv1, so it must normalize out_channels
        self.norm2 = Normalize(out_channels, norm_type)
        self.conv2 = SamePadConv3d(out_channels, out_channels, kernel_size=3, padding_type=padding_type)
        if self.in_channels != self.out_channels:
            self.conv_shortcut = SamePadConv3d(in_channels, out_channels, kernel_size=3, padding_type=padding_type)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = silu(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = silu(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.conv_shortcut(x)

        return x + h


# Does not support dilation
class SamePadConv3d(nn.Module):
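    """Conv3d with manual 'same' padding: the input is padded with F.pad (using padding_type, asymmetric
    when needed) so that the output size equals the input size divided by the stride
    (assumes the input size is divisible by the stride)."""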
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        # assumes that the input shape is divisible by stride
        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input
        self.padding_type = padding_type

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=0, bias=bias)
        self.weight = self.conv.weight

    def forward(self, x):
        return self.conv(F.pad(x, self.pad_input, mode=self.padding_type))


class SamePadConvTranspose3d(nn.Module):
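    """Transposed counterpart of SamePadConv3d: pads the input with F.pad and uses padding=k-1 inside
    ConvTranspose3d so that the output size equals the input size multiplied by the stride."""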
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
        pad_input = []
        for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
            pad_input.append((p // 2 + p % 2, p // 2))
        pad_input = sum(pad_input, tuple())
        self.pad_input = pad_input
        self.padding_type = padding_type

        self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
                                        stride=stride, bias=bias,
                                        padding=tuple([k - 1 for k in kernel_size]))

    def forward(self, x):
        return self.convt(F.pad(x, self.pad_input, mode=self.padding_type))


class Encoder(nn.Module):
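    """3D encoder. `downsample` holds one power-of-two factor per (t, h, w) dimension; the number of stages
    is log2 of the largest factor, and each stage doubles the channel count. The temporal stride is forced
    to 1 inside the loop, so only spatial downsampling is applied. The final block projects to z_channels."""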
    def __init__(self, n_hiddens, downsample, z_channels, image_channel=3, norm_type='group',
                 padding_type='replicate', res_num=1):
        super().__init__()
        n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
        self.conv_blocks = nn.ModuleList()
        max_ds = n_times_downsample.max()

        self.conv_first = SamePadConv3d(image_channel, n_hiddens, kernel_size=3, padding_type=padding_type)

        for i in range(max_ds):
            block = nn.Module()
            in_channels = n_hiddens * 2 ** i
            out_channels = n_hiddens * 2 ** (i + 1)
            stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
            stride = list(stride)
            stride[0] = 1
            stride = tuple(stride)
            block.down = SamePadConv3d(in_channels, out_channels, 4, stride=stride, padding_type=padding_type)
            block.res = ResBlock(out_channels, out_channels, norm_type=norm_type)
            self.conv_blocks.append(block)
            n_times_downsample -= 1

        self.final_block = nn.Sequential(
            Normalize(out_channels, norm_type),
            SiLU(),
            SamePadConv3d(out_channels, z_channels,
                          kernel_size=3,
                          stride=1,
                          padding_type=padding_type)
        )

        self.out_channels = out_channels

    def forward(self, x):
        h = self.conv_first(x)
        for block in self.conv_blocks:
            h = block.down(h)
            h = block.res(h)
        h = self.final_block(h)
        return h


class Decoder(nn.Module):
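    """3D decoder mirroring the Encoder. Each stage upsamples spatially by 2 with SamePadConvTranspose3d
    (the temporal stride is forced to 1) followed by two ResBlocks; channels go from z_channels up to
    n_hiddens * 2**max_us in the first stage and are halved at each later stage, and conv_out maps back
    to image_channel."""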
    def __init__(self, n_hiddens, upsample, z_channels, image_channel, norm_type='group'):
        super().__init__()

        n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
        max_us = n_times_upsample.max()

        in_channels = z_channels
        self.conv_blocks = nn.ModuleList()
        for i in range(max_us):
            block = nn.Module()
            in_channels = in_channels if i == 0 else n_hiddens * 2 ** (max_us - i + 1)
            out_channels = n_hiddens * 2 ** (max_us - i)
            us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
            us = list(us)
            us[0] = 1
            us = tuple(us)
            block.up = SamePadConvTranspose3d(in_channels, out_channels, 4, stride=us)
            block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type)
            self.conv_blocks.append(block)
            n_times_upsample -= 1

        self.conv_out = SamePadConv3d(out_channels, image_channel, kernel_size=3)

    def forward(self, x):
        h = x
        for i, block in enumerate(self.conv_blocks):
            h = block.up(h)
            h = block.res1(h)
            h = block.res2(h)
        h = self.conv_out(h)
        return h


# unit test
if __name__ == '__main__':
    # smoke test: one power-of-two downsample factor per (t, h, w) dimension
    encoder = Encoder(n_hiddens=320, downsample=[1, 2, 2], z_channels=8, image_channel=96,
                      norm_type='group', padding_type='replicate')
    encoder = encoder.cuda()

    en_input = torch.rand(1, 96, 3, 256, 256).cuda()
    out = encoder(en_input)
    print(out.shape)
    # mean, logvar = torch.chunk(out, 2, dim=1)
    # print(mean.shape)
    # decoder = DecoderRe(n_hiddens=320, upsample=[2, 2, 2, 1], z_channels=8, image_channel=96,
    #                     norm_type='group')
    # decoder = decoder.cuda()
    # out = decoder(mean)
    # print(out.shape)
    # logvar = nn.Parameter(torch.ones(size=()) * 0.0)
    # print(logvar)
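
    # Minimal decoder round-trip sketch, assuming upsample factors that mirror downsample=[1, 2, 2] above.
    with torch.no_grad():
        decoder = Decoder(n_hiddens=320, upsample=[1, 2, 2], z_channels=8, image_channel=96,
                          norm_type='group').cuda()
        rec = decoder(out)
        print(rec.shape)  # expected: torch.Size([1, 96, 3, 256, 256])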