Spaces:
Runtime error
Runtime error
from typing import Any, Dict, Optional, Tuple | |
import torch | |
from shap_e.models.nn.ops import get_act | |
from shap_e.models.query import Query | |
from shap_e.models.stf.mlp import MLPModel | |
from shap_e.util.collections import AttrDict | |
class MLPDensitySDFModel(MLPModel):
    """
    MLP field that predicts a density and a signed distance for each query.

    The underlying MLP emits two raw channels with an identity output
    activation. Channel 0 goes through ``sdf_activation`` and channel 1
    through ``density_activation``.
    """

    def __init__(
        self,
        initial_bias: float = -0.1,
        sdf_activation="tanh",
        density_activation="exp",
        **kwargs,
    ):
        """
        :param initial_bias: value written into the final layer's bias for
            output channel 0 (the signed-distance channel).
        :param sdf_activation: activation name passed to ``get_act`` for the
            signed-distance output.
        :param density_activation: activation name passed to ``get_act`` for
            the density output.
        :param kwargs: forwarded unchanged to ``MLPModel.__init__``.
        """
        super().__init__(n_output=2, output_activation="identity", **kwargs)
        final_bias = self.mlp[-1].bias
        final_bias[0].data.fill_(initial_bias)
        self.sdf_activation = get_act(sdf_activation)
        self.density_activation = get_act(density_activation)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        Evaluate the field at ``query``.

        :param query: provides ``position`` and ``direction``, both forwarded
            to ``_mlp``.
        :param params: optional parameter overrides forwarded to ``_mlp``.
        :param options: optional settings forwarded to ``_mlp``.
        :return: AttrDict with ``density`` and ``signed_distance``.
        """
        # query.direction is typically None for SDF models and during
        # training; it is handed to _mlp unchanged.
        raw_out, _unused_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        sdf_raw, density_raw = raw_out.split(1, dim=-1)
        density = self.density_activation(density_raw)
        signed_distance = self.sdf_activation(sdf_raw)
        return AttrDict(density=density, signed_distance=signed_distance)
class MLPNeRSTFModel(MLPModel):
    """
    MLP field that predicts a signed distance, a density, and color channels.

    The MLP's raw outputs are sliced into named groups by two index mappings
    built with ``indices_for_output_mode``:

    - ``h_map`` slices the final output ``h``.
    - ``h_directionless_map`` slices the ``h_directionless`` tensor that
      ``_mlp`` returns alongside ``h``.

    Each group is then passed through its activation (``sdf_activation``,
    ``density_activation`` or ``channel_activation``).
    """

    def __init__(
        self,
        sdf_activation="tanh",
        density_activation="exp",
        channel_activation="sigmoid",
        # To be able to load old models. Set this to be False in future models.
        direction_dependent_shape: bool = True,
        separate_nerf_channels: bool = False,
        separate_coarse_channels: bool = False,
        initial_density_bias: float = 0.0,
        initial_sdf_bias: float = -0.1,
        **kwargs,
    ):
        """
        :param sdf_activation: activation name passed to ``get_act`` for the
            signed-distance output.
        :param density_activation: activation name for the density output.
        :param channel_activation: activation name for the color channels.
        :param direction_dependent_shape: if True, sdf and density are read
            from the final output ``h``; if False, from ``h_directionless``.
        :param separate_nerf_channels: if True, NeRF rendering uses its own
            color channels instead of sharing the STF ones.
        :param separate_coarse_channels: if True, coarse and fine levels get
            separate density and NeRF color channels. Requires
            ``separate_nerf_channels`` (asserted in ``indices_for_output_mode``).
        :param initial_density_bias: initial bias for output channel 1 (density)
            of the layer selected below.
        :param initial_sdf_bias: initial bias for output channel 0 (sdf) of
            that layer.
        :param kwargs: forwarded to ``MLPModel.__init__``.
        """
        h_map, h_directionless_map = indices_for_output_mode(
            direction_dependent_shape=direction_dependent_shape,
            separate_nerf_channels=separate_nerf_channels,
            separate_coarse_channels=separate_coarse_channels,
        )
        # The final layer must be wide enough for every range in h_map.
        n_output = index_mapping_max(h_map)
        super().__init__(
            n_output=n_output,
            output_activation="identity",
            **kwargs,
        )
        self.direction_dependent_shape = direction_dependent_shape
        self.separate_nerf_channels = separate_nerf_channels
        self.separate_coarse_channels = separate_coarse_channels
        self.sdf_activation = get_act(sdf_activation)
        self.density_activation = get_act(density_activation)
        self.channel_activation = get_act(channel_activation)
        self.h_map = h_map
        self.h_directionless_map = h_directionless_map

        self.mlp[-1].bias.data.zero_()
        # sdf occupies channel 0 and density channel 1 in both mappings. When the
        # shape is direction-independent, the biases of layer insert_direction_at
        # are initialized instead of the final layer's.
        # NOTE(review): confirm that layer is the one whose output becomes
        # h_directionless in MLPModel._mlp.
        layer = -1 if self.direction_dependent_shape else self.insert_direction_at
        self.mlp[layer].bias[0].data.fill_(initial_sdf_bias)
        self.mlp[layer].bias[1].data.fill_(initial_density_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        Evaluate the field at ``query``.

        :param query: provides ``position`` and ``direction``, forwarded to
            ``_mlp``.
        :param params: optional parameter overrides forwarded to ``_mlp``.
        :param options: optional settings. ``nerf_level == "coarse"`` selects
            the coarse density/NeRF channels; any other value, or no value,
            selects the fine ones. ``rendering_mode`` (default ``"stf"``) equal
            to ``"nerf"`` selects NeRF color channels; otherwise the STF color
            channels are used.
        :return: AttrDict with ``density``, ``signed_distance`` and ``channels``.
        """
        options = AttrDict() if options is None else AttrDict(options)
        h, h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        activations = map_indices_to_keys(self.h_map, h)
        activations.update(map_indices_to_keys(self.h_directionless_map, h_directionless))

        # Read with .get (like rendering_mode below) so an absent nerf_level,
        # e.g. with the default options=None, selects the fine branch. This
        # avoids relying on how AttrDict handles a missing attribute.
        use_coarse = options.get("nerf_level") == "coarse"
        h_density = activations.density_coarse if use_coarse else activations.density_fine

        if options.get("rendering_mode", "stf") == "nerf":
            h_channels = activations.nerf_coarse if use_coarse else activations.nerf_fine
        else:
            h_channels = activations.stf

        return AttrDict(
            density=self.density_activation(h_density),
            signed_distance=self.sdf_activation(activations.sdf),
            channels=self.channel_activation(h_channels),
        )
# Maps an output name to a half-open ``(start, end)`` range over the last
# dimension of a tensor; map_indices_to_keys slices ``data[..., start:end]``.
IndexMapping = AttrDict[
    str,
    Tuple[int, int],
]
def indices_for_output_mode(
    direction_dependent_shape: bool,
    separate_nerf_channels: bool,
    separate_coarse_channels: bool,
) -> Tuple[IndexMapping, IndexMapping]:
    """
    Get output mappings for (h, h_directionless).

    Each mapping assigns an output name (``sdf``, ``density_coarse``,
    ``density_fine``, ``stf``, ``nerf_coarse``, ``nerf_fine``) to a half-open
    ``(start, end)`` channel range. Names may share a range when the
    corresponding outputs are not separated.

    :param direction_dependent_shape: if True, every output lives in ``h``.
        Otherwise sdf and density live in ``h_directionless`` and ``h`` holds
        only color channels.
    :param separate_nerf_channels: give NeRF colors their own channels instead
        of sharing the STF ones.
    :param separate_coarse_channels: give coarse and fine levels separate
        density and NeRF color channels. This requires
        ``separate_nerf_channels`` (checked with ``assert``).
    :return: ``(h_map, h_directionless_map)``.
    """
    if separate_coarse_channels:
        # Separate coarse/fine colors are only defined on top of separate NeRF colors.
        assert separate_nerf_channels

    h_map = AttrDict()
    h_directionless_map = AttrDict()

    if not direction_dependent_shape:
        # Shape outputs come from the directionless features; h carries colors only.
        h_directionless_map.sdf = (0, 1)
        h_directionless_map.density_coarse = (1, 2)
        h_directionless_map.density_fine = (2, 3) if separate_coarse_channels else (1, 2)
        h_map.stf = (0, 3)
        if separate_coarse_channels:
            h_map.nerf_coarse = (3, 6)
            h_map.nerf_fine = (6, 9)
        else:
            shared_nerf = (3, 6) if separate_nerf_channels else (0, 3)
            h_map.nerf_coarse = shared_nerf
            h_map.nerf_fine = shared_nerf
        return h_map, h_directionless_map

    # Direction-dependent shape: everything is packed into h.
    h_map.sdf = (0, 1)
    if separate_coarse_channels:
        h_map.density_coarse = (1, 2)
        h_map.density_fine = (2, 3)
        h_map.stf = (3, 6)
        h_map.nerf_coarse = (6, 9)
        h_map.nerf_fine = (9, 12)
    else:
        h_map.density_coarse = (1, 2)
        h_map.density_fine = (1, 2)
        h_map.stf = (2, 5)
        shared_nerf = (5, 8) if separate_nerf_channels else (2, 5)
        h_map.nerf_coarse = shared_nerf
        h_map.nerf_fine = shared_nerf
    return h_map, h_directionless_map
def map_indices_to_keys(mapping: IndexMapping, data: torch.Tensor) -> AttrDict[str, torch.Tensor]:
    """
    Slice ``data`` along its last dimension according to ``mapping``.

    :param mapping: output name -> ``(start, end)`` channel range.
    :param data: tensor whose last dimension is sliced.
    :return: AttrDict with the same keys, in the same order, as ``mapping``;
        each value is ``data[..., start:end]``.
    """
    pieces = {}
    for key, (lo, hi) in mapping.items():
        pieces[key] = data[..., lo:hi]
    return AttrDict(pieces)
def index_mapping_max(mapping: IndexMapping) -> int:
    """
    Return the largest ``end`` bound among the ranges in ``mapping``.

    This is the minimum channel count a tensor needs for every range in
    ``mapping`` to fit. Like ``max``, it raises ValueError for an empty mapping.
    """
    return max(hi for _lo, hi in mapping.values())