# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Discriminator architectures from the paper
"Efficient Geometry-aware 3D Generative Adversarial Networks"."""
import numpy as np
import torch
from torch_utils import persistence, misc
from torch_utils.ops import upfirdn2d
from recon.models.stylegannext3D.networks_stylegan2 import DiscriminatorBlock, MappingNetwork, DiscriminatorEpilogue
from einops import rearrange
class SingleDiscriminator(torch.nn.Module):
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
        conv_clamp          = 256,      # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        sr_upsample_factor  = 1,        # Ignored for SingleDiscriminator.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
    def forward(self, img, c, update_emas=False, **block_kwargs):
        img = img['image']

        _ = update_emas # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x

    def extra_repr(self):
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
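
# Minimal usage sketch (not part of the original file; shapes and values are
# illustrative assumptions): SingleDiscriminator consumes only the 'image'
# entry of the input dict.
#
#   D = SingleDiscriminator(c_dim=0, img_resolution=64, img_channels=3)
#   logits = D({'image': torch.randn(4, 3, 64, 64)}, c=None)  # -> [4, 1]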
#----------------------------------------------------------------------------

def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
    if filter_mode == 'antialiased':
        ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
    elif filter_mode == 'classic':
        ada_filtered_64 = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
        ada_filtered_64 = torch.nn.functional.interpolate(ada_filtered_64, size=(size * 2 + 2, size * 2 + 2), mode='bilinear', align_corners=False)
        ada_filtered_64 = upfirdn2d.downsample2d(ada_filtered_64, f, down=2, flip_filter=True, padding=-1)
    elif filter_mode == 'none':
        ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
    elif isinstance(filter_mode, float):
        assert 0 < filter_mode < 1
        filtered = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
        aliased = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=False)
        ada_filtered_64 = (1 - filter_mode) * aliased + filter_mode * filtered
    else:
        raise ValueError(f'Unknown filter_mode: {filter_mode!r}')
    return ada_filtered_64
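
# Illustrative sketch (assumed inputs): a float filter_mode linearly blends the
# antialiased and aliased bilinear resizes, e.g. 0.8 keeps 80% of the filtered
# result and 20% of the aliased one.
#
#   f = upfirdn2d.setup_filter([1, 3, 3, 1])
#   x = torch.randn(2, 3, 128, 128)
#   y = filtered_resizing(x, size=64, f=f, filter_mode=0.8)  # -> [2, 3, 64, 64]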
#----------------------------------------------------------------------------

class DualDiscriminator(torch.nn.Module):
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
        conv_clamp          = 256,      # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        disc_c_noise        = 0,        # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        # img_channels *= 3 # additionally use lms_contour as a condition
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
        self.disc_c_noise = disc_c_noise
    def forward(self, img, c, update_emas=False, out_layers=None, **block_kwargs):
        image_raw = filtered_resizing(img['image'], size=img['image_sr'].shape[-1], f=self.resample_filter)
        img = torch.cat([img['image_sr'], image_raw], 1)

        _ = update_emas # unused
        x = None
        output, count = [], 0
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)
            if out_layers is not None and count in out_layers:
                output.append(x)
            count += 1
        if c is None:
            return output

        cmap = None
        if self.c_dim > 0:
            c = c[:, :self.c_dim]
            if self.disc_c_noise > 0:
                c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x if out_layers is None else [x, output]

    def extra_repr(self):
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
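
# Minimal usage sketch (an assumption, not part of the original file): forward()
# stacks 'image_sr' with the upsampled 'image' along channels, so the caller is
# presumably expected to pass img_channels already doubled (6 for RGB).
#
#   D = DualDiscriminator(c_dim=25, img_resolution=128, img_channels=6)
#   batch = {'image': torch.randn(4, 3, 64, 64),        # raw neural-rendering output
#            'image_sr': torch.randn(4, 3, 128, 128)}   # super-resolved output
#   logits = D(batch, c=torch.randn(4, 25))             # -> [4, 1]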
#----------------------------------------------------------------------------

class DummyDualDiscriminator(torch.nn.Module):
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
        conv_clamp          = 256,      # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        img_channels *= 2

        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
        self.raw_fade = 1
    def forward(self, img, c, update_emas=False, **block_kwargs):
        self.raw_fade = max(0, self.raw_fade - 1 / (500000 / 32))

        image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) * self.raw_fade
        img = torch.cat([img['image'], image_raw], 1)

        _ = update_emas # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x

    def extra_repr(self):
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
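
# Fade-schedule arithmetic (derived from the constants above): raw_fade starts
# at 1 and shrinks by 1 / (500000 / 32) = 6.4e-5 per forward call, reaching 0
# after 500000 / 32 = 15625 calls, i.e. after 500k images at batch size 32.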
#----------------------------------------------------------------------------

class VideoDiscriminator(torch.nn.Module):
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
        conv_clamp          = 256,      # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        disc_c_noise        = 0,        # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        img_channels = img_channels * 2 + 1
        # c_dim = c_dim * 2
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim * 2, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
        self.disc_c_noise = disc_c_noise
    def forward(self, img, Ts, c, return_input=False, update_emas=False, **block_kwargs):
        # image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
        # NOTE: concatenate the frames and the timestamp difference here.
        # Ts:  [num_vid, 2] or [num_vid, 2, C]
        # img: [num_vid*2, C, H, W]
        # c:   [num_vid, 2, C]
        timesteps = Ts.shape[1]
        _, _, h, w = img['image'].shape
        img = torch.cat([rearrange(img['image'], "(b t) c h w -> b (t c) h w", t=timesteps),
                         (Ts[:, 1] - Ts[:, 0]).view(-1, 1, 1, 1).repeat(1, 1, h, w)], 1)
        # img = rearrange(img['image'], "(b t) c h w -> b (t c) h w", t=timesteps)
        if self.img_resolution != h:
            img = torch.nn.functional.interpolate(img, size=(self.img_resolution, self.img_resolution), mode='bilinear', align_corners=False, antialias=True)
        if return_input:
            inp = img
        c = c[..., :self.c_dim]
        # NOTE: reshape cameras.
        # c = rearrange(c, "(b t) c -> b (t c)", t=timesteps)
        c = rearrange(c, "b t c -> b (t c)", t=timesteps)
        # c = torch.cat([c, Ts[:, 1] - Ts[:, 0]], dim=-1)

        _ = update_emas # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            if self.disc_c_noise > 0:
                c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        if return_input:
            return x, inp
        return x, None

    def extra_repr(self):
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
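
# Minimal usage sketch (shapes are illustrative assumptions): the two frames of
# each video are stacked channel-wise together with a constant plane holding
# the timestamp difference, giving [num_vid, 2*C + 1, H, W] input.
#
#   D = VideoDiscriminator(c_dim=25, img_resolution=64, img_channels=3)
#   img = {'image': torch.randn(8, 3, 64, 64)}  # 4 videos x 2 frames
#   Ts = torch.tensor([[0., 5.], [3., 10.], [1., 2.], [0., 31.]])
#   c = torch.randn(4, 2, 25)                   # per-frame camera labels
#   logits, _ = D(img, Ts, c)                   # -> [4, 1], None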
#----------------------------------------------------------------------------

class FusionVideoDiscriminator(torch.nn.Module):
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
        conv_clamp          = 256,      # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        disc_c_noise        = 0,        # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        # img_channels = img_channels * 2 + 1
        # c_dim = c_dim * 2
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        self.concat_res = 16
        self.num_frames_div_factor = 2
        self.num_frames_per_video = 2
        self.time_encoder = TemporalDifferenceEncoder(max_num_frames=32, num_frames_per_video=self.num_frames_per_video)

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            if res // 2 == self.concat_res:
                out_channels = out_channels // self.num_frames_div_factor
            if res == self.concat_res:
                in_channels = (in_channels // self.num_frames_div_factor) * self.num_frames_per_video
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim * self.num_frames_per_video + self.time_encoder.get_dim(), w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
        self.disc_c_noise = disc_c_noise
    def forward(self, img, Ts, c, return_input=False, update_emas=False, **block_kwargs):
        # image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
        # NOTE: frames stay separate here; they are fused at concat_res below.
        # Ts:  [num_vid, 2] or [num_vid, 2, C]
        # img: [num_vid*2, C, H, W]
        # c:   [num_vid, 2, C]
        timesteps = Ts.shape[1]
        _, _, h, w = img['image'].shape
        assert len(img['image']) == Ts.shape[0] * Ts.shape[1]
        img = img['image']
        # img = torch.cat([rearrange(img['image'], "(b t) c h w -> b (t c) h w", t=timesteps), (Ts[:, 1] - Ts[:, 0]).view(-1, 1, 1, 1).repeat(1, 1, h, w)], 1)
        # img = rearrange(img['image'], "(b t) c h w -> b (t c) h w", t=timesteps)
        if self.img_resolution != h:
            img = torch.nn.functional.interpolate(img, size=(self.img_resolution, self.img_resolution), mode='bilinear', align_corners=False, antialias=True)
        if return_input:
            inp = img
        c = c[..., :self.c_dim]
        # NOTE: reshape cameras.
        c = rearrange(c, "b t c -> b (t c)", t=timesteps)
        # c = torch.cat([c, Ts[:, 1] - Ts[:, 0]], dim=-1)

        # Encode the time distances and append them to the conditioning.
        t_embs = self.time_encoder(Ts) # [batch_size, t_dim]
        c = torch.cat([c, t_embs], dim=1) # [batch_size, c_dim + t_dim]

        _ = update_emas # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            if res == self.concat_res:
                # Concatenate the per-frame features along the channel dimension.
                x = rearrange(x, "(b t) c h w -> b t c h w", t=timesteps)
                # x = x.view(-1, self.num_frames_per_video, *x.shape[1:]) # [batch_size, num_frames, c, h, w]
                x = x.view(x.shape[0], -1, *x.shape[3:]) # [batch_size, num_frames * c, h, w]
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            if self.disc_c_noise > 0:
                c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        if return_input:
            return x, inp
        return x, None

    def extra_repr(self):
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
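
# Minimal usage sketch (shapes are illustrative assumptions; Ts must hold
# sorted frame indices below max_num_frames=32, since their differences index
# an embedding table): frames run through the early blocks independently and
# are fused by channel concatenation at concat_res (16x16), while the encoded
# time differences are appended to the camera conditioning.
#
#   D = FusionVideoDiscriminator(c_dim=25, img_resolution=64, img_channels=3)
#   img = {'image': torch.randn(8, 3, 64, 64)}  # 4 videos x 2 frames
#   Ts = torch.tensor([[0, 5], [3, 10], [1, 2], [0, 31]])
#   c = torch.randn(4, 2, 25)                   # per-frame camera labels
#   logits, _ = D(img, Ts, c)                   # -> [4, 1], None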
#----------------------------------------------------------------------------

class FixedTimeEncoder(torch.nn.Module):
    def __init__(self,
        max_num_frames: int,            # Maximum T size.
        skip_small_t_freqs: int = 0,    # How many high frequencies we should skip.
    ):
        super().__init__()
        assert max_num_frames >= 1, f"Wrong max_num_frames: {max_num_frames}"
        fourier_coefs = construct_log_spaced_freqs(max_num_frames, skip_small_t_freqs=skip_small_t_freqs)
        self.register_buffer('fourier_coefs', fourier_coefs) # [1, num_fourier_feats]

    def get_dim(self) -> int:
        return self.fourier_coefs.shape[1] * 2

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        assert t.ndim == 2, f"Wrong shape: {t.shape}"

        t = t.view(-1).float() # [batch_size * num_frames]
        fourier_raw_embs = self.fourier_coefs * t.unsqueeze(1) # [bf, num_fourier_feats]
        fourier_embs = torch.cat([
            fourier_raw_embs.sin(),
            fourier_raw_embs.cos(),
        ], dim=1) # [bf, num_fourier_feats * 2]
        return fourier_embs
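
# Worked example (derived from the definitions above): for max_num_frames=32,
# construct_log_spaced_freqs yields log2(32) = 5 frequencies, so get_dim()
# returns 10 (sin and cos per frequency).
#
#   enc = FixedTimeEncoder(max_num_frames=32)
#   embs = enc(torch.tensor([[0., 5.]]))  # -> [2, 10], one row per timestamp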
#----------------------------------------------------------------------------

class TemporalDifferenceEncoder(torch.nn.Module):
    def __init__(self, max_num_frames, num_frames_per_video, sampling_type='random'):
        super().__init__()
        self.num_frames_per_video = num_frames_per_video
        self.sampling_type = sampling_type
        if self.num_frames_per_video > 1:
            self.d = 256
            self.const_embed = torch.nn.Embedding(max_num_frames, self.d)
            self.time_encoder = FixedTimeEncoder(
                max_num_frames,
                skip_small_t_freqs=0)

    def get_dim(self) -> int:
        if self.num_frames_per_video == 1:
            return 1
        else:
            if self.sampling_type == 'uniform':
                return self.d + self.time_encoder.get_dim()
            else:
                return (self.d + self.time_encoder.get_dim()) * (self.num_frames_per_video - 1)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        misc.assert_shape(t, [None, self.num_frames_per_video])

        batch_size = t.shape[0]
        if self.num_frames_per_video == 1:
            out = torch.zeros(len(t), 1, device=t.device)
        else:
            if self.sampling_type == 'uniform':
                num_diffs_to_use = 1
                t_diffs = t[:, 1] - t[:, 0] # [batch_size]
            else:
                num_diffs_to_use = self.num_frames_per_video - 1
                t_diffs = (t[:, 1:] - t[:, :-1]).view(-1) # [batch_size * (num_frames - 1)]
            # Note: float => round => long is necessary when it's originally long.
            const_embs = self.const_embed(t_diffs.float().round().long()) # [batch_size * num_diffs_to_use, d]
            fourier_embs = self.time_encoder(t_diffs.unsqueeze(1)) # [batch_size * num_diffs_to_use, num_fourier_feats]
            out = torch.cat([const_embs, fourier_embs], dim=1) # [batch_size * num_diffs_to_use, d + num_fourier_feats]
            out = out.view(batch_size, num_diffs_to_use, -1).view(batch_size, -1) # [batch_size, num_diffs_to_use * (d + num_fourier_feats)]
        return out
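
# Minimal usage sketch (assumes sorted integer frame indices, since the
# embedding table is indexed by frame differences): with the defaults above
# (d=256 plus 10 Fourier features), each frame pair yields a 266-dim code.
#
#   enc = TemporalDifferenceEncoder(max_num_frames=32, num_frames_per_video=2)
#   out = enc(torch.tensor([[0, 5], [2, 9]]))  # -> [2, 266]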
#----------------------------------------------------------------------------

def construct_log_spaced_freqs(max_num_frames: int, skip_small_t_freqs: int = 0):
    time_resolution = 2 ** np.ceil(np.log2(max_num_frames))
    num_fourier_feats = np.ceil(np.log2(time_resolution)).astype(int)
    powers = torch.tensor([2]).repeat(num_fourier_feats).pow(torch.arange(num_fourier_feats)) # [num_fourier_feats]
    powers = powers[:len(powers) - skip_small_t_freqs] # [num_fourier_feats]
    fourier_coefs = powers.unsqueeze(0).float() * np.pi # [1, num_fourier_feats]

    return fourier_coefs / time_resolution
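
# Worked example (derived from the function above): for max_num_frames=32,
# time_resolution = 32 and num_fourier_feats = 5, so the result is
# pi * [1, 2, 4, 8, 16] / 32 with shape [1, 5]: periods spanning the full
# clip length down to a couple of frames.
#
#   coefs = construct_log_spaced_freqs(32)  # -> tensor([[0.0982, ..., 1.5708]])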