import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

from lam.models.rendering.utils.typing import *


def get_activation(name):
    """Return an element-wise activation callable selected by ``name``.

    Args:
        name: Case-insensitive activation name, or ``None``. ``None`` and
            ``"none"`` give the identity. Recognized names are ``lin2srgb``,
            ``exp``, ``shifted_exp``, ``trunc_exp``, ``shifted_trunc_exp``,
            ``sigmoid``, ``tanh``, ``shifted_softplus`` and ``scale_-11_01``.
            Any other name is looked up as an attribute of
            ``torch.nn.functional`` (e.g. ``"relu"``).

    Returns:
        A callable mapping a tensor to a tensor.

    Raises:
        ValueError: If ``name`` is not recognized and is not an attribute of
            ``torch.nn.functional``.
    """
    if name is None:
        return lambda x: x
    name = name.lower()
    if name == "none":
        return lambda x: x
    elif name == "lin2srgb":
        # Linear -> sRGB transfer curve: power segment above 0.0031308 and a
        # linear segment (slope 12.92) below it. The clamp inside pow keeps
        # the unused branch of torch.where finite. The result is clipped to [0, 1].
        return lambda x: torch.where(
            x > 0.0031308,
            torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
            12.92 * x,
        ).clamp(0.0, 1.0)
    elif name == "exp":
        return lambda x: torch.exp(x)
    elif name == "shifted_exp":
        return lambda x: torch.exp(x - 1.0)
    elif name == "trunc_exp":
        return trunc_exp
    elif name == "shifted_trunc_exp":
        return lambda x: trunc_exp(x - 1.0)
    elif name == "sigmoid":
        return lambda x: torch.sigmoid(x)
    elif name == "tanh":
        return lambda x: torch.tanh(x)
    elif name == "shifted_softplus":
        return lambda x: F.softplus(x - 1.0)
    elif name == "scale_-11_01":
        # Affine map taking [-1, 1] onto [0, 1].
        return lambda x: x * 0.5 + 0.5
    else:
        try:
            return getattr(F, name)
        except AttributeError:
            # The AttributeError adds nothing beyond this message; hide it.
            raise ValueError(f"Unknown activation function: {name}") from None


class MLP(nn.Module):
    """Fully connected network built from ``nn.Linear`` layers.

    Layer layout:
        Linear(dim_in, n_neurons) -> act
        -> [Linear(n_neurons, n_neurons) -> act] * (n_hidden_layers - 1)
        -> Linear(n_neurons, dim_out) -> output_activation

    Note: when ``n_hidden_layers <= 1`` the repeated section is empty, so the
    network still contains two linear layers (input and output).
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        n_neurons: int,
        n_hidden_layers: int,
        activation: str = "relu",
        output_activation: Optional[str] = None,
        bias: bool = True,
    ):
        """Build the layer stack.

        Args:
            dim_in: Size of the last dimension of the input.
            dim_out: Size of the last dimension of the output.
            n_neurons: Width of every hidden layer.
            n_hidden_layers: Number of hidden layers (see the class docstring
                for how values <= 1 behave).
            activation: Hidden activation, ``"relu"`` or ``"silu"``.
            output_activation: Name passed to ``get_activation`` and applied
                after the final linear layer. ``None`` means identity.
            bias: Whether every linear layer has a bias term.
        """
        super().__init__()
        layers = [
            self.make_linear(
                dim_in, n_neurons, is_first=True, is_last=False, bias=bias
            ),
            self.make_activation(activation),
        ]
        for _ in range(n_hidden_layers - 1):
            layers += [
                self.make_linear(
                    n_neurons, n_neurons, is_first=False, is_last=False, bias=bias
                ),
                self.make_activation(activation),
            ]
        layers += [
            self.make_linear(
                n_neurons, dim_out, is_first=False, is_last=True, bias=bias
            )
        ]
        self.layers = nn.Sequential(*layers)
        self.output_activation = get_activation(output_activation)

    def forward(self, x):
        """Apply the layer stack, then the output activation."""
        x = self.layers(x)
        x = self.output_activation(x)
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last, bias=True):
        """Create one ``nn.Linear`` layer.

        ``is_first``/``is_last`` are accepted but unused here; presumably they
        are a hook for subclasses that treat the first/last layer differently.
        """
        layer = nn.Linear(dim_in, dim_out, bias=bias)
        return layer

    def make_activation(self, activation):
        """Create a hidden-layer activation module.

        Names are matched case-sensitively. Only ``"relu"`` and ``"silu"`` are
        supported, both built with ``inplace=True``.

        Raises:
            NotImplementedError: For any other ``activation`` value.
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(
                f"Unsupported MLP activation: {activation!r} "
                "(expected 'relu' or 'silu')"
            )


class _TruncExp(Function):  # pylint: disable=abstract-method
    """exp(x) whose backward pass uses exp(min(x, 15)) to bound the gradient.

    The forward value is the exact exp(x). Only the gradient is truncated.
    Under autocast, inputs are cast to float32 (``cast_inputs=torch.float32``).
    """

    # Implementation from torch-ngp:
    # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):  # pylint: disable=arguments-differ
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):  # pylint: disable=arguments-differ
        x = ctx.saved_tensors[0]
        # Clamp only from above so large inputs cannot blow up the gradient.
        return g * torch.exp(torch.clamp(x, max=15))


# Functional entry point: trunc_exp(x) -> tensor.
trunc_exp = _TruncExp.apply