Spaces:
Running
on
L40S
Running
on
L40S
"""FiLM Siren MLP as per https://marcoamonteiro.github.io/pi-GAN-website/.""" | |
from typing import Optional | |
import numpy as np | |
import torch | |
from torch import nn | |
def kaiming_leaky_init(m):
    """Re-initialise the weight of Linear-like modules with Kaiming-normal.

    Intended for use with ``nn.Module.apply``. A module is selected when its
    class name contains the substring ``"Linear"``; any other module is left
    untouched. Only ``m.weight`` is resampled (fan-in mode, leaky-ReLU gain
    with negative slope 0.2).
    """
    if "Linear" not in m.__class__.__name__:
        return
    torch.nn.init.kaiming_normal_(
        m.weight, a=0.2, mode="fan_in", nonlinearity="leaky_relu"
    )
def frequency_init(freq):
    """Build an initialiser that resamples Linear weights in a freq-scaled range.

    The returned callable is suitable for ``nn.Module.apply``. For every
    ``nn.Linear`` it is given, the weight is redrawn in place from
    ``U(-b, b)`` with ``b = sqrt(6 / fan_in) / freq``. Biases and all other
    module types are left unchanged.
    """

    def init(m):
        if not isinstance(m, nn.Linear):
            return
        fan_in = m.weight.size(-1)
        bound = np.sqrt(6 / fan_in) / freq
        # In-place resampling must not be recorded by autograd.
        with torch.no_grad():
            m.weight.uniform_(-bound, bound)

    return init
def first_layer_film_sine_init(m):
    """Resample an ``nn.Linear`` weight from ``U(-1 / fan_in, 1 / fan_in)``.

    Intended for use with ``nn.Module.apply``; modules that are not
    ``nn.Linear`` are ignored and the bias is never modified.
    """
    if not isinstance(m, nn.Linear):
        return
    fan_in = m.weight.size(-1)
    limit = 1 / fan_in
    # In-place resampling must not be recorded by autograd.
    with torch.no_grad():
        m.weight.uniform_(-limit, limit)
class CustomMappingNetwork(nn.Module):
    """MLP that turns a conditioning vector into FiLM frequencies and phase shifts.

    The network is ``map_hidden_layers`` blocks of ``Linear -> LeakyReLU(0.2)``
    followed by one output ``Linear`` producing ``map_output_dim`` features.
    All Linear weights are initialised with ``kaiming_leaky_init`` and the
    output layer's weight is then scaled by 0.25.

    Args:
        in_features: Size of the last dimension of the conditioning input.
        map_hidden_layers: Number of hidden ``Linear -> LeakyReLU`` blocks
            (may be 0, in which case the output layer reads the input directly).
        map_hidden_dim: Width of every hidden block.
        map_output_dim: Width of the output; split in two halves by ``forward``.
    """

    def __init__(self, in_features, map_hidden_layers, map_hidden_dim, map_output_dim):
        super().__init__()
        layers = []
        for _ in range(map_hidden_layers):
            layers.append(nn.Linear(in_features, map_hidden_dim))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_features = map_hidden_dim
        # `in_features` now holds the width feeding the output layer: it equals
        # `map_hidden_dim` after at least one hidden block, and the original
        # input width when there are none (previously a shape mismatch).
        layers.append(nn.Linear(in_features, map_output_dim))

        self.network = nn.Sequential(*layers)
        self.network.apply(kaiming_leaky_init)

        # Shrink the output layer's weights after the Kaiming initialisation.
        with torch.no_grad():
            self.network[-1].weight *= 0.25

    def forward(self, z):
        """Map ``z`` to a ``(frequencies, phase_shifts)`` pair.

        The network output is split along its last dimension at
        ``size // 2``: the leading part is returned as the frequencies and the
        remainder as the phase shifts.
        """
        frequencies_offsets = self.network(z)
        # Plain integer floor division; no tensor is needed for a slice index.
        half = frequencies_offsets.shape[-1] // 2
        frequencies = frequencies_offsets[..., :half]
        phase_shifts = frequencies_offsets[..., half:]
        return frequencies, phase_shifts
class FiLMLayer(nn.Module):
    """Linear layer followed by a FiLM-modulated sine activation.

    Computes ``sin(freq * linear(x) + phase_shift)``, where ``freq`` and
    ``phase_shift`` are expanded to the shape of the linear output.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.layer = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, freq, phase_shift):
        """Apply the linear map, then the frequency/phase-modulated sine."""
        projected = self.layer(x)
        # Expand the conditioning tensors to match the projected activations.
        scale = freq.expand_as(projected)
        offset = phase_shift.expand_as(projected)
        return torch.sin(scale * projected + offset)
class FiLMSiren(nn.Module):
    """FiLM Conditioned Siren network.

    A stack of ``FiLMLayer`` blocks whose per-layer frequencies and phase
    shifts are produced from a conditioning input by a
    ``CustomMappingNetwork``.

    Args:
        in_dim: Size of the last dimension of the network input; must be > 0.
        hidden_layers: The FiLM stack holds one input layer plus
            ``hidden_layers - 1`` hidden-to-hidden layers.
        hidden_features: Width of the hidden FiLM layers.
        mapping_network_in_dim: Size of the conditioning input.
        mapping_network_layers: Number of hidden blocks in the mapping network.
        mapping_network_features: Hidden width of the mapping network.
        out_dim: Output width; ``None`` falls back to ``hidden_features``.
        outermost_linear: If True the output is produced by a plain
            ``nn.Linear``; otherwise by an additional ``FiLMLayer``.
        out_activation: Optional module applied to the final output.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_layers: int,
        hidden_features: int,
        mapping_network_in_dim: int,
        mapping_network_layers: int,
        mapping_network_features: int,
        out_dim: Optional[int],
        outermost_linear: bool = False,
        out_activation: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        self.in_dim = in_dim
        assert self.in_dim > 0
        self.out_dim = out_dim if out_dim is not None else hidden_features
        self.hidden_layers = hidden_layers
        self.hidden_features = hidden_features
        self.mapping_network_in_dim = mapping_network_in_dim
        self.mapping_network_layers = mapping_network_layers
        self.mapping_network_features = mapping_network_features
        self.outermost_linear = outermost_linear
        self.out_activation = out_activation

        self.net = nn.ModuleList()
        self.net.append(FiLMLayer(self.in_dim, self.hidden_features))
        for _ in range(self.hidden_layers - 1):
            self.net.append(FiLMLayer(self.hidden_features, self.hidden_features))

        self.final_layer = None
        if self.outermost_linear:
            self.final_layer = nn.Linear(self.hidden_features, self.out_dim)
            self.final_layer.apply(frequency_init(25))
        else:
            self.net.append(FiLMLayer(self.hidden_features, self.out_dim))

        # Each FiLM layer needs one frequency and one phase shift per output
        # feature. Summing the real layer widths (instead of assuming
        # `hidden_features` everywhere) keeps a FiLM output layer with
        # `out_dim != hidden_features` usable; when all widths are equal this
        # is the same as `len(self.net) * hidden_features * 2`.
        film_width_total = sum(layer.layer.out_features for layer in self.net)
        self.mapping_network = CustomMappingNetwork(
            in_features=self.mapping_network_in_dim,
            map_hidden_layers=self.mapping_network_layers,
            map_hidden_dim=self.mapping_network_features,
            map_output_dim=film_width_total * 2,
        )

        # Frequency-scaled init for every FiLM layer, then override the first
        # layer with its dedicated initialisation.
        self.net.apply(frequency_init(25))
        self.net[0].apply(first_layer_film_sine_init)

    def forward_with_frequencies_phase_shifts(self, x, frequencies, phase_shifts):
        """Run the FiLM stack on ``x`` with explicit frequencies and phase shifts.

        ``frequencies`` and ``phase_shifts`` hold the values for all FiLM
        layers concatenated along the last dimension, in layer order; each
        layer consumes a slice as wide as its own output. The raw frequencies
        are rescaled as ``frequencies * 15 + 30`` before use. The optional
        final linear layer and output activation are applied afterwards.
        """
        frequencies = frequencies * 15 + 30
        start = 0
        for layer in self.net:
            # Slice by the layer's actual output width so a FiLM output layer
            # narrower or wider than `hidden_features` gets matching tensors.
            end = start + layer.layer.out_features
            x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])
            start = end
        if self.final_layer is not None:
            x = self.final_layer(x)
        if self.out_activation is not None:
            x = self.out_activation(x)
        return x

    def forward(self, x, conditioning_input):
        """Forward pass: derive FiLM parameters from the conditioning input, then run the stack."""
        frequencies, phase_shifts = self.mapping_network(conditioning_input)
        return self.forward_with_frequencies_phase_shifts(x, frequencies, phase_shifts)