# Spaces: Running on Zero
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

from lam.models.rendering.utils.typing import *
def get_activation(name):
    """Return an element-wise activation callable for ``name``.

    Args:
        name: Activation name, matched case-insensitively, or ``None``.
            ``None`` and ``"none"`` give the identity. Besides the explicitly
            handled names below, any attribute of ``torch.nn.functional`` is
            accepted (e.g. ``"relu"``, ``"softplus"``).

    Returns:
        A callable that maps a tensor to a tensor.

    Raises:
        ValueError: If ``name`` is neither a handled name nor an attribute of
            ``torch.nn.functional``.
    """
    if name is None:
        return lambda x: x
    name = name.lower()
    if name == "none":
        return lambda x: x
    if name == "lin2srgb":
        # Linear -> sRGB transfer curve, clamped to [0, 1]. torch.where
        # evaluates both branches, so the inner clamp keeps pow() from ever
        # taking a fractional power of a negative value (which yields NaN).
        return lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0)
    if name == "exp":
        return torch.exp
    if name == "shifted_exp":
        return lambda x: torch.exp(x - 1.0)
    if name == "trunc_exp":
        return trunc_exp
    if name == "shifted_trunc_exp":
        return lambda x: trunc_exp(x - 1.0)
    if name == "sigmoid":
        return torch.sigmoid
    if name == "tanh":
        return torch.tanh
    if name == "shifted_softplus":
        return lambda x: F.softplus(x - 1.0)
    if name == "scale_-11_01":
        # Affine map of [-1, 1] onto [0, 1].
        return lambda x: x * 0.5 + 0.5
    # Fall back to any function exposed by torch.nn.functional.
    try:
        return getattr(F, name)
    except AttributeError as err:
        raise ValueError(f"Unknown activation function: {name}") from err
class MLP(nn.Module):
    """Plain fully connected network built as an ``nn.Sequential``.

    Layout: ``Linear(dim_in, n_neurons)`` + activation, then
    ``n_hidden_layers - 1`` repetitions of ``Linear(n_neurons, n_neurons)`` +
    activation, then ``Linear(n_neurons, dim_out)``. The result is passed
    through ``output_activation``, which is resolved with ``get_activation``.

    Note: because the first hidden layer is always created, any
    ``n_hidden_layers`` value below 1 still yields one hidden layer.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        n_neurons: int,
        n_hidden_layers: int,
        activation: str = "relu",
        output_activation: Optional[str] = None,
        bias: bool = True,
    ):
        """Build the layer stack.

        Args:
            dim_in: Input feature size.
            dim_out: Output feature size.
            n_neurons: Width of every hidden layer.
            n_hidden_layers: Number of hidden Linear+activation pairs.
            activation: Hidden activation, ``"relu"`` or ``"silu"``.
            output_activation: Name passed to ``get_activation`` for the
                output (``None`` means identity).
            bias: Whether the Linear layers use a bias term.
        """
        super().__init__()
        layers = [
            self.make_linear(
                dim_in, n_neurons, is_first=True, is_last=False, bias=bias
            ),
            self.make_activation(activation),
        ]
        for _ in range(n_hidden_layers - 1):
            layers += [
                self.make_linear(
                    n_neurons, n_neurons, is_first=False, is_last=False, bias=bias
                ),
                self.make_activation(activation),
            ]
        layers.append(
            self.make_linear(
                n_neurons, dim_out, is_first=False, is_last=True, bias=bias
            )
        )
        self.layers = nn.Sequential(*layers)
        self.output_activation = get_activation(output_activation)

    def forward(self, x):
        """Apply the layer stack, then the output activation.

        ``nn.Linear`` acts on the last dimension, so ``x`` is expected to end
        in ``dim_in``.
        """
        x = self.layers(x)
        x = self.output_activation(x)
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True):
        """Create one ``nn.Linear`` layer.

        ``is_first`` / ``is_last`` are accepted but unused here; presumably
        they are a hook for subclasses that treat boundary layers differently.
        """
        return nn.Linear(dim_in, dim_out, bias=bias)

    def make_activation(self, activation):
        """Return an in-place hidden activation module for ``activation``.

        Raises:
            NotImplementedError: If ``activation`` is not ``"relu"`` or
                ``"silu"``.
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        if activation == "silu":
            return nn.SiLU(inplace=True)
        raise NotImplementedError(
            f"Unsupported hidden activation {activation!r}; "
            "expected 'relu' or 'silu'"
        )
class _TruncExp(Function): # pylint: disable=abstract-method | |
# Implementation from torch-ngp: | |
# https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py | |
def forward(ctx, x): # pylint: disable=arguments-differ | |
ctx.save_for_backward(x) | |
return torch.exp(x) | |
def backward(ctx, g): # pylint: disable=arguments-differ | |
x = ctx.saved_tensors[0] | |
return g * torch.exp(torch.clamp(x, max=15)) | |
def trunc_exp(x):
    """Return ``exp(x)``; its gradient is computed from ``exp(min(x, 15))``.

    Thin functional front-end for the ``_TruncExp`` autograd function.
    """
    return _TruncExp.apply(x)