jammmmm's picture
Add spar3d demo files
38dbec8
raw
history blame
3.5 kB
"""Siren MLP https://www.vincentsitzmann.com/siren/"""
from typing import Optional
import numpy as np
import torch
from torch import nn
class SineLayer(nn.Module):
"""
Sine layer for the SIREN network.
"""
def __init__(
self, in_features, out_features, bias=True, is_first=False, omega_0=30.0
):
super().__init__()
self.omega_0 = omega_0
self.is_first = is_first
self.in_features = in_features
self.linear = nn.Linear(in_features, out_features, bias=bias)
self.init_weights()
def init_weights(self):
with torch.no_grad():
if self.is_first:
self.linear.weight.uniform_(-1 / self.in_features, 1 / self.in_features)
else:
self.linear.weight.uniform_(
-np.sqrt(6 / self.in_features) / self.omega_0,
np.sqrt(6 / self.in_features) / self.omega_0,
)
def forward(self, x):
return torch.sin(self.omega_0 * self.linear(x))
class Siren(nn.Module):
"""Siren network.
Args:
in_dim: Input layer dimension
num_layers: Number of network layers
layer_width: Width of each MLP layer
out_dim: Output layer dimension. Uses layer_width if None.
activation: intermediate layer activation function.
out_activation: output activation function.
"""
def __init__(
self,
in_dim: int,
hidden_layers: int,
hidden_features: int,
out_dim: Optional[int] = None,
outermost_linear: bool = False,
first_omega_0: float = 30,
hidden_omega_0: float = 30,
out_activation: Optional[nn.Module] = None,
) -> None:
super().__init__()
self.in_dim = in_dim
assert self.in_dim > 0
self.out_dim = out_dim if out_dim is not None else hidden_features
self.outermost_linear = outermost_linear
self.first_omega_0 = first_omega_0
self.hidden_omega_0 = hidden_omega_0
self.hidden_layers = hidden_layers
self.layer_width = hidden_features
self.out_activation = out_activation
self.net = []
self.net.append(
SineLayer(in_dim, hidden_features, is_first=True, omega_0=first_omega_0)
)
for _ in range(hidden_layers):
self.net.append(
SineLayer(
hidden_features,
hidden_features,
is_first=False,
omega_0=hidden_omega_0,
)
)
if outermost_linear:
final_layer = nn.Linear(hidden_features, self.out_dim)
with torch.no_grad():
final_layer.weight.uniform_(
-np.sqrt(6 / hidden_features) / hidden_omega_0,
np.sqrt(6 / hidden_features) / hidden_omega_0,
)
self.net.append(final_layer)
else:
self.net.append(
SineLayer(
hidden_features,
self.out_dim,
is_first=False,
omega_0=hidden_omega_0,
)
)
if self.out_activation is not None:
self.net.append(self.out_activation)
self.net = nn.Sequential(*self.net)
def forward(self, model_input):
"""Forward pass through the network"""
output = self.net(model_input)
return output