Spaces:
Runtime error
Runtime error
from functools import partial | |
from typing import Any, Dict, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
from shap_e.models.nn.checkpoint import checkpoint | |
from shap_e.models.nn.encoding import encode_position, maybe_encode_direction | |
from shap_e.models.nn.meta import MetaModule, subdict | |
from shap_e.models.nn.ops import MetaLinear, get_act, mlp_init | |
from shap_e.models.query import Query | |
from shap_e.util.collections import AttrDict | |
from .base import Model | |
class MLPModel(MetaModule, Model):
    """
    MLP evaluated on positionally-encoded query positions.

    The network is a stack of ``n_hidden_layers + 1`` linear layers. A
    contiguous range of those layers may be ``MetaLinear`` layers, which
    receive their parameters at call time via ``params``; the remaining layers
    are ordinary ``nn.Linear`` layers. An encoding of the query direction can
    optionally be concatenated to the input of one chosen layer.
    """

    def __init__(
        self,
        n_output: int,
        output_activation: str,
        # Positional encoding parameters
        posenc_version: str = "v1",
        # Direction related channel prediction
        insert_direction_at: Optional[int] = None,
        # MLP parameters
        d_hidden: int = 256,
        n_hidden_layers: int = 4,
        activation: str = "relu",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        meta_proj: bool = True,
        meta_bias: bool = True,
        meta_start: int = 0,
        meta_stop: Optional[int] = None,
        n_meta_layers: Optional[int] = None,
        register_freqs: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        """
        :param n_output: output width of the final linear layer
        :param output_activation: activation name (resolved with get_act)
            applied to the final layer's output in forward()
        :param posenc_version: version string forwarded to encode_position and
            maybe_encode_direction
        :param insert_direction_at: index of the layer whose input is widened
            by the direction encoding; None disables direction input
        :param d_hidden: width of every hidden layer
        :param n_hidden_layers: number of hidden layers
        :param activation: activation name (resolved with get_act) applied
            after every layer except the last
        :param init: forwarded to mlp_init
        :param init_scale: forwarded to mlp_init
        :param meta_parameters: if False, every layer is a plain nn.Linear
        :param trainable_meta: forwarded to MetaLinear
        :param meta_proj: forwarded to MetaLinear
        :param meta_bias: forwarded to MetaLinear
        :param meta_start: index of the first MetaLinear layer (inclusive)
        :param meta_stop: index of the last MetaLinear layer (inclusive); when
            None it is derived from n_meta_layers, or defaults to the final
            layer
        :param n_meta_layers: number of MetaLinear layers, only consulted when
            meta_stop is None; must be positive
        :param register_freqs: if True, register a "freqs" buffer holding
            2 ** arange(10) reshaped to [1 x 10]
        :param device: device on which the layers are created and to which the
            module is moved
        """
        super().__init__()

        if register_freqs:
            self.register_buffer("freqs", 2.0 ** torch.arange(10, device=device).view(1, 10))

        # Positional encoding: probe the encoders with a dummy input to learn
        # how wide the encoded position / direction features are.
        self.posenc_version = posenc_version
        dummy = torch.eye(1, 3)
        d_posenc_pos = encode_position(posenc_version, position=dummy).shape[-1]
        d_posenc_dir = maybe_encode_direction(posenc_version, position=dummy).shape[-1]

        # Instantiate the MLP
        mlp_widths = [d_hidden] * n_hidden_layers
        input_widths = [d_posenc_pos, *mlp_widths]
        output_widths = mlp_widths + [n_output]

        self.meta_parameters = meta_parameters

        # When this model is used jointly to express NeRF, it may have to
        # process directions as well in which case we simply concatenate
        # the direction representation at the specified layer.
        self.insert_direction_at = insert_direction_at
        if insert_direction_at is not None:
            input_widths[self.insert_direction_at] += d_posenc_dir

        def linear_cls(meta: bool):
            """Return the layer constructor for a meta or a regular layer."""
            if not meta:
                return nn.Linear
            return partial(
                MetaLinear,
                meta_scale=False,
                meta_shift=False,
                meta_proj=meta_proj,
                meta_bias=meta_bias,
                trainable_meta=trainable_meta,
            )

        if meta_stop is None:
            if n_meta_layers is not None:
                assert n_meta_layers > 0, "n_meta_layers must be positive"
                meta_stop = meta_start + n_meta_layers - 1
            else:
                meta_stop = n_hidden_layers

        # One flag per linear layer; the [meta_start, meta_stop] range is
        # inclusive on both ends.
        if meta_parameters:
            metas = [meta_start <= layer <= meta_stop for layer in range(n_hidden_layers + 1)]
        else:
            metas = [False] * (n_hidden_layers + 1)

        self.mlp = nn.ModuleList(
            [
                linear_cls(meta)(d_in, d_out, device=device)
                for meta, d_in, d_out in zip(metas, input_widths, output_widths)
            ]
        )

        mlp_init(self.mlp, init=init, init_scale=init_scale)

        self.activation = get_act(activation)
        self.output_activation = get_act(output_activation)

        self.device = device
        self.to(device)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        :param query: the query; its position and direction are fed to the MLP
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: the output activation applied to the final layer's output
        """
        # query.direction is None typically for SDF models and training
        h_final, _h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        return self.output_activation(h_final)

    def _run_mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor],
        params: AttrDict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param position: query positions, encoded with encode_position
        :param direction: query directions (may be None), encoded with
            maybe_encode_direction at layer insert_direction_at
        :param params: meta parameters; the entries under "mlp.<i>" are handed
            to the i-th layer when it is a MetaLinear
        :return: the final and directionless activations at the given query
        """
        h_preact = h = encode_position(self.posenc_version, position=position)
        h_directionless = None
        n_layers = len(self.mlp)
        for i, layer in enumerate(self.mlp):
            if i == self.insert_direction_at:
                # Remember the last pre-activation that has not seen the
                # direction, then append the direction encoding to the input.
                h_directionless = h_preact
                h_direction = maybe_encode_direction(
                    self.posenc_version, position=position, direction=direction
                )
                h = torch.cat([h, h_direction], dim=-1)
            if isinstance(layer, MetaLinear):
                h = layer(h, params=subdict(params, f"mlp.{i}"))
            else:
                h = layer(h)
            h_preact = h
            # The hidden activation is skipped on the last layer.
            if i < n_layers - 1:
                h = self.activation(h)
        h_final = h
        if h_directionless is None:
            # The direction was never inserted, so the final pre-activation is
            # itself directionless.
            h_directionless = h_preact
        return h_final, h_directionless

    def _mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor] = None,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param position: [batch_size x ... x 3]
        :param direction: optional query directions
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: the final and directionless activations at the given query
        """
        params = self.update(params)
        options = AttrDict() if options is None else AttrDict(options)

        mlp = partial(self._run_mlp, direction=direction, params=params)

        # Gather every tensor the MLP depends on: meta layers take theirs from
        # params, regular layers own theirs.
        parameters = []
        for i, layer in enumerate(self.mlp):
            if isinstance(layer, MetaLinear):
                parameters.extend(subdict(params, f"mlp.{i}").values())
            else:
                parameters.extend(layer.parameters())

        # NOTE(review): options.checkpoint_stf_model is forwarded as the last
        # argument of checkpoint(); presumably its enable flag -- confirm
        # against shap_e.models.nn.checkpoint.
        h_final, h_directionless = checkpoint(
            mlp, (position,), parameters, options.checkpoint_stf_model
        )

        return h_final, h_directionless
class MLPSDFModel(MLPModel):
    """
    MLPModel with a single output channel and an identity output activation,
    whose forward pass reports the result under the "signed_distance" key.
    """

    def __init__(self, initial_bias: float = -0.1, **kwargs):
        """
        :param initial_bias: value written into every entry of the final
            layer's bias after the base model has been built
        :param kwargs: remaining MLPModel constructor arguments
        """
        super().__init__(output_activation="identity", n_output=1, **kwargs)
        last_layer = self.mlp[-1]
        last_layer.bias.data.fill_(initial_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        :param query: the query forwarded to MLPModel.forward
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: an AttrDict with the model output under "signed_distance"
        """
        distances = super().forward(query=query, params=params, options=options)
        return AttrDict(signed_distance=distances)
class MLPTextureFieldModel(MLPModel):
    """
    MLPModel with n_channels outputs and a sigmoid output activation, whose
    forward pass reports the result under the "channels" key.
    """

    def __init__(self, n_channels: int = 3, **kwargs):
        """
        :param n_channels: number of output channels of the final layer
        :param kwargs: remaining MLPModel constructor arguments
        """
        super().__init__(output_activation="sigmoid", n_output=n_channels, **kwargs)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        :param query: the query forwarded to MLPModel.forward
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: an AttrDict with the model output under "channels"
        """
        predicted = super().forward(query=query, params=params, options=options)
        return AttrDict(channels=predicted)