from abc import ABC, abstractmethod | |
from functools import partial | |
from typing import Any, Dict, Optional, Tuple | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from shap_e.models.nn.checkpoint import checkpoint | |
from shap_e.models.nn.encoding import encode_position, spherical_harmonics_basis | |
from shap_e.models.nn.meta import MetaModule, subdict | |
from shap_e.models.nn.ops import MLP, MetaMLP, get_act, mlp_init, zero_init | |
from shap_e.models.nn.utils import ArrayType | |
from shap_e.models.query import Query | |
from shap_e.util.collections import AttrDict | |
class NeRFModel(ABC):
    """
    Parametric scene representation whose outputs are integrated by NeRFRenderer
    """

    @abstractmethod
    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        :param query: the points in the field to query.
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: An AttrDict containing at least
            - density: [batch_size x ... x 1]
            - channels: [batch_size x ... x n_channels]
            - aux_losses: [batch_size x ... x 1]
        """
        # Marked abstract: the original stub body returned None silently if a
        # subclass forgot to override it, and `abstractmethod` was imported but
        # unused. All subclasses in this file implement forward(), so this is
        # backward-compatible.
class VoidNeRFModel(MetaModule, NeRFModel):
    """
    Default empty-space model: every queried point renders as a constant
    background color (optionally overridden per call via options).
    """

    def __init__(
        self,
        background: ArrayType,
        trainable: bool = False,
        channel_scale: float = 255.0,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()
        # Convert the raw background (e.g. 0-255 RGB) to a float tensor in
        # [0, 1] by dividing by channel_scale.
        bg_tensor = torch.from_numpy(np.array(background)).to(
            dtype=torch.float32, device=device
        )
        bg_param = nn.Parameter(bg_tensor / channel_scale)
        # Register as a learnable parameter only when requested; otherwise it
        # is a fixed buffer that still moves with .to()/state_dict().
        register = self.register_parameter if trainable else self.register_buffer
        register("background", bg_param)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        _ = params
        fallback = self.background[None]
        if options is not None:
            bg = options.get("background", fallback)
        else:
            bg = fallback
        # Broadcast the [1 x n_channels] background across the query's batch
        # dims: view as [batch, 1, ..., 1, n_channels], then expand.
        batch_shape = query.position.shape[:-1]
        n_channels = bg.shape[-1]
        middle = [1] * (len(batch_shape) - 1)
        bg = bg.view(bg.shape[0], *middle, n_channels)
        return torch.broadcast_to(bg, [*batch_shape, n_channels])
class MLPNeRFModel(MetaModule, NeRFModel):
    """
    NeRF field implemented as two MLPs: a density MLP over positionally
    encoded query points, and a channel MLP over the density features
    concatenated with a spherical-harmonics encoding of the view direction.
    When ``meta_parameters`` is True the MLP weights may be supplied
    externally through ``params`` (hypernetwork-style).
    """

    def __init__(
        self,
        # Positional encoding parameters
        n_levels: int = 10,
        # MLP parameters
        d_hidden: int = 256,
        n_density_layers: int = 4,
        n_channel_layers: int = 1,
        n_channels: int = 3,
        sh_degree: int = 4,
        activation: str = "relu",
        density_activation: str = "exp",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        output_activation: str = "sigmoid",
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        zero_out: bool = True,
        register_freqs: bool = True,
        posenc_version: str = "v1",
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()

        # Positional encoding
        if register_freqs:
            # not used anymore; kept only so old checkpoints with a "freqs"
            # buffer still load.
            self.register_buffer(
                "freqs",
                2.0 ** torch.arange(n_levels, device=device, dtype=torch.float).view(1, n_levels),
            )

        self.posenc_version = posenc_version
        # Probe encode_position with a dummy batch of 3D points to discover
        # the encoded feature width for the chosen posenc version.
        dummy = torch.eye(1, 3)
        d_input = encode_position(posenc_version, position=dummy).shape[-1]

        self.n_levels = n_levels

        self.sh_degree = sh_degree
        # A degree-d spherical-harmonics basis has d**2 coefficients.
        d_sh_coeffs = sh_degree**2

        self.meta_parameters = meta_parameters
        # When meta-parameters are enabled, build MetaMLPs whose projection
        # weights/biases come from externally supplied params.
        mlp_cls = (
            partial(
                MetaMLP,
                meta_scale=False,
                meta_shift=False,
                meta_proj=True,
                meta_bias=True,
                trainable_meta=trainable_meta,
            )
            if meta_parameters
            else MLP
        )

        # Density trunk: outputs d_hidden features; the first feature is the
        # density logit and the full vector feeds the channel MLP (see forward).
        self.density_mlp = mlp_cls(
            d_input=d_input,
            d_hidden=[d_hidden] * (n_density_layers - 1),
            d_output=d_hidden,
            act_name=activation,
            init_scale=init_scale,
        )

        # Channel head: consumes density features + SH direction encoding.
        self.channel_mlp = mlp_cls(
            d_input=d_hidden + d_sh_coeffs,
            d_hidden=[d_hidden] * n_channel_layers,
            d_output=n_channels,
            act_name=activation,
            init_scale=init_scale,
        )

        self.act = get_act(output_activation)
        self.density_act = get_act(density_activation)

        mlp_init(
            list(self.density_mlp.affines) + list(self.channel_mlp.affines),
            init=init,
            init_scale=init_scale,
        )
        # Zero-init the final channel layer so initial renders are neutral.
        if zero_out:
            zero_init(self.channel_mlp.affines[-1])

        self.to(device)

    def encode_position(self, query: Query) -> torch.Tensor:
        # Encode query positions with the configured positional-encoding scheme.
        h = encode_position(self.posenc_version, position=query.position)
        return h

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        :param query: positions (and optionally view directions) to evaluate.
        :param params: optional meta parameters for the MLPs, used when
            ``meta_parameters`` is True.
        :param options: optional hyperparameters; this method reads
            ``checkpoint_nerf_mlp`` (gradient-checkpointing flag) and
            ``return_h_density``.
        :return: AttrDict with ``density_logit``, ``density``, ``channels``,
            empty ``aux_losses``/``no_weight_grad_aux_losses``, and optionally
            ``h_density``.
        """
        params = self.update(params)

        options = AttrDict() if options is None else AttrDict(options)

        query = query.copy()

        h_position = self.encode_position(query)

        if self.meta_parameters:
            # Extract this MLP's externally supplied weights from the flat
            # params dict; they are also handed to checkpoint() as the tensors
            # requiring grad.
            density_params = subdict(params, "density_mlp")
            density_mlp = partial(
                self.density_mlp, params=density_params, options=options, log_prefix="density_"
            )
            density_mlp_parameters = list(density_params.values())
        else:
            density_mlp = partial(self.density_mlp, options=options, log_prefix="density_")
            density_mlp_parameters = self.density_mlp.parameters()

        h_density = checkpoint(
            density_mlp,
            (h_position,),
            density_mlp_parameters,
            options.checkpoint_nerf_mlp,
        )

        # Direction features: SH basis of the view direction, or zeros when
        # the query carries no directions.
        h_direction = maybe_get_spherical_harmonics_basis(
            sh_degree=self.sh_degree,
            coords_shape=query.position.shape,
            coords=query.direction,
            device=query.position.device,
        )

        if self.meta_parameters:
            channel_params = subdict(params, "channel_mlp")
            channel_mlp = partial(
                self.channel_mlp, params=channel_params, options=options, log_prefix="channel_"
            )
            channel_mlp_parameters = list(channel_params.values())
        else:
            channel_mlp = partial(self.channel_mlp, options=options, log_prefix="channel_")
            channel_mlp_parameters = self.channel_mlp.parameters()

        h_channel = checkpoint(
            channel_mlp,
            (torch.cat([h_density, h_direction], dim=-1),),
            channel_mlp_parameters,
            options.checkpoint_nerf_mlp,
        )

        # The first feature of the density trunk's output is the raw density
        # logit; density_act (default exp) maps it to a nonnegative density.
        density_logit = h_density[..., :1]

        res = AttrDict(
            density_logit=density_logit,
            density=self.density_act(density_logit),
            channels=self.act(h_channel),
            aux_losses=AttrDict(),
            no_weight_grad_aux_losses=AttrDict(),
        )
        # AttrDict yields a falsy value for absent keys, so this is opt-in.
        if options.return_h_density:
            res.h_density = h_density

        return res
def maybe_get_spherical_harmonics_basis(
    sh_degree: int,
    coords_shape: Tuple[int, ...],
    coords: Optional[torch.Tensor] = None,
    device: torch.device = torch.device("cuda"),
) -> torch.Tensor:
    """
    Compute the spherical-harmonics basis of ``coords``, or an all-zero basis
    when no coordinates are given (e.g. the query has no view directions).

    :param sh_degree: spherical harmonics degree; the basis has sh_degree**2
        coefficients.
    :param coords_shape: [*shape, 3]; only the leading dims are used for the
        zero fallback.
    :param coords: optional coordinate tensor of coords_shape.
    :param device: device for the zero fallback; ignored when coords is given.
    :return: tensor of shape [*coords_shape[:-1], sh_degree**2].
    """
    # Annotation fixed from Tuple[int] (a 1-tuple) to Tuple[int, ...].
    if coords is None:
        # Allocate directly on the target device rather than allocating on the
        # default device and copying with .to().
        return torch.zeros(*coords_shape[:-1], sh_degree**2, device=device)
    return spherical_harmonics_basis(coords, sh_degree)