Spaces:
Running
on
L40S
Running
on
L40S
"""Siren MLP https://www.vincentsitzmann.com/siren/""" | |
from typing import Optional | |
import numpy as np | |
import torch | |
from torch import nn | |
class SineLayer(nn.Module):
    """Linear layer followed by a scaled sine activation (SIREN building block).

    The forward pass computes ``sin(omega_0 * linear(x))``.

    Args:
        in_features: Number of input features of the linear map.
        out_features: Number of output features of the linear map.
        bias: Whether the linear map has a bias term.
        is_first: Whether this is the first layer of the network; selects the
            weight initialization range used by ``init_weights``.
        omega_0: Factor multiplied into the linear output before the sine.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        is_first: bool = False,
        omega_0: float = 30.0,
    ) -> None:
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # Overwrite nn.Linear's default weight init; the bias keeps its default.
        self.init_weights()

    def init_weights(self) -> None:
        """Re-draw the linear weights uniformly in ``[-bound, bound]``.

        ``bound`` is ``1 / in_features`` for the first layer and
        ``sqrt(6 / in_features) / omega_0`` for every other layer.
        """
        if self.is_first:
            bound = 1 / self.in_features
        else:
            # Dividing by omega_0 compensates for the omega_0 factor applied
            # to the pre-activation in forward().
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        # In-place initialization must not be recorded by autograd.
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sin(omega_0 * linear(x))``."""
        return torch.sin(self.omega_0 * self.linear(x))
class Siren(nn.Module):
    """SIREN network: an MLP built from sine-activated layers.

    The network consists of, in order: a first ``SineLayer`` mapping
    ``in_dim`` to ``hidden_features``; ``hidden_layers`` further
    ``SineLayer`` blocks of width ``hidden_features``; a final layer mapping
    to ``out_dim``; and, if given, ``out_activation``.

    Args:
        in_dim: Input layer dimension. Must be positive.
        hidden_layers: Number of hidden sine layers between the first and
            the final layer.
        hidden_features: Width of each hidden layer.
        out_dim: Output layer dimension. Uses ``hidden_features`` if None.
        outermost_linear: If True, the final layer is a plain ``nn.Linear``
            (no sine); otherwise it is another ``SineLayer``.
        first_omega_0: ``omega_0`` passed to the first sine layer.
        hidden_omega_0: ``omega_0`` passed to every later sine layer; also
            scales the weight initialization range of the linear final layer.
        out_activation: Optional module appended after the final layer.

    Raises:
        ValueError: If ``in_dim`` is not positive.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_layers: int,
        hidden_features: int,
        out_dim: Optional[int] = None,
        outermost_linear: bool = False,
        first_omega_0: float = 30,
        hidden_omega_0: float = 30,
        out_activation: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under -O.
        if in_dim <= 0:
            raise ValueError(f"in_dim must be positive, got {in_dim}")

        self.in_dim = in_dim
        self.out_dim = out_dim if out_dim is not None else hidden_features
        self.outermost_linear = outermost_linear
        self.first_omega_0 = first_omega_0
        self.hidden_omega_0 = hidden_omega_0
        self.hidden_layers = hidden_layers
        self.layer_width = hidden_features
        self.out_activation = out_activation

        # Collect layers in a local list so that `self.net` is only ever
        # bound to an nn.Module.
        layers = [
            SineLayer(in_dim, hidden_features, is_first=True, omega_0=first_omega_0)
        ]
        for _ in range(hidden_layers):
            layers.append(
                SineLayer(
                    hidden_features,
                    hidden_features,
                    is_first=False,
                    omega_0=hidden_omega_0,
                )
            )

        if outermost_linear:
            final_layer = nn.Linear(hidden_features, self.out_dim)
            # Same range as the non-first sine layers; bias keeps the
            # nn.Linear default initialization.
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                final_layer.weight.uniform_(-bound, bound)
            layers.append(final_layer)
        else:
            layers.append(
                SineLayer(
                    hidden_features,
                    self.out_dim,
                    is_first=False,
                    omega_0=hidden_omega_0,
                )
            )

        if self.out_activation is not None:
            layers.append(self.out_activation)

        self.net = nn.Sequential(*layers)

    def forward(self, model_input: torch.Tensor) -> torch.Tensor:
        """Run ``model_input`` through the layer stack and return the result."""
        return self.net(model_input)