Spaces:
Runtime error
Runtime error
File size: 7,451 Bytes
19c4ddf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
from functools import partial
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, maybe_encode_direction
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MetaLinear, get_act, mlp_init
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict
from .base import Model
class MLPModel(MetaModule, Model):
    """
    MLP evaluated on encoded query positions.

    The position is passed through `encode_position`, then through a stack of
    `n_hidden_layers + 1` linear layers. Optionally, an encoded direction is
    concatenated to the input of one chosen layer, and a contiguous range of
    layers can be built as `MetaLinear` instead of `nn.Linear`.
    """

    def __init__(
        self,
        n_output: int,
        output_activation: str,
        # Positional encoding parameters
        posenc_version: str = "v1",
        # Direction related channel prediction
        insert_direction_at: Optional[int] = None,
        # MLP parameters
        d_hidden: int = 256,
        n_hidden_layers: int = 4,
        activation: str = "relu",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        meta_proj: bool = True,
        meta_bias: bool = True,
        meta_start: int = 0,
        meta_stop: Optional[int] = None,
        n_meta_layers: Optional[int] = None,
        register_freqs: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        """
        :param n_output: output width of the last linear layer
        :param output_activation: name given to `get_act`; the resulting
            callable is applied to the final output in `forward`
        :param posenc_version: version string forwarded to `encode_position`
            and `maybe_encode_direction`
        :param insert_direction_at: index of the layer whose input receives
            the encoded direction (None disables it). Negative values index
            from the end, like a list index.
        :param d_hidden: width of every hidden layer
        :param n_hidden_layers: number of hidden layers; the MLP has
            `n_hidden_layers + 1` linear layers in total
        :param activation: name given to `get_act`; applied after every layer
            except the last
        :param init: forwarded to `mlp_init`
        :param init_scale: forwarded to `mlp_init`
        :param meta_parameters: if False, every layer is a plain `nn.Linear`
        :param trainable_meta: forwarded to `MetaLinear`
        :param meta_proj: forwarded to `MetaLinear`
        :param meta_bias: forwarded to `MetaLinear`
        :param meta_start: first layer index (inclusive) built as `MetaLinear`
        :param meta_stop: last layer index (inclusive) built as `MetaLinear`;
            when None it is derived from `n_meta_layers`, or defaults to
            `n_hidden_layers`
        :param n_meta_layers: number of meta layers starting at `meta_start`;
            only consulted when `meta_stop` is None
        :param register_freqs: if True, register a `freqs` buffer holding
            `2 ** arange(10)` with shape [1 x 10]
        :param device: device the layers are created on and the module is
            moved to
        :raises IndexError: if `insert_direction_at` is outside the layer range
        """
        super().__init__()

        if register_freqs:
            self.register_buffer("freqs", 2.0 ** torch.arange(10, device=device).view(1, 10))

        # Positional encoding: probe the encoders with a dummy point to learn
        # the encoded feature widths.
        self.posenc_version = posenc_version
        dummy = torch.eye(1, 3)
        d_posenc_pos = encode_position(posenc_version, position=dummy).shape[-1]
        d_posenc_dir = maybe_encode_direction(posenc_version, position=dummy).shape[-1]

        # Instantiate the MLP
        mlp_widths = [d_hidden] * n_hidden_layers
        input_widths = [d_posenc_pos, *mlp_widths]
        output_widths = mlp_widths + [n_output]
        n_layers = len(input_widths)

        self.meta_parameters = meta_parameters

        # When this model is used jointly to express NeRF, it may have to
        # process directions as well in which case we simply concatenate
        # the direction representation at the specified layer.
        if insert_direction_at is not None:
            if not -n_layers <= insert_direction_at < n_layers:
                raise IndexError(
                    f"insert_direction_at={insert_direction_at} is out of range "
                    f"for an MLP with {n_layers} layers"
                )
            # _run_mlp compares this value against non-negative enumerate()
            # indices, so a negative index must be normalized here; otherwise
            # the layer would be widened but the direction never concatenated.
            insert_direction_at %= n_layers
            input_widths[insert_direction_at] += d_posenc_dir
        self.insert_direction_at = insert_direction_at

        if meta_stop is None:
            if n_meta_layers is not None:
                assert n_meta_layers > 0, "n_meta_layers must be positive"
                meta_stop = meta_start + n_meta_layers - 1
            else:
                meta_stop = n_hidden_layers

        # One flag per linear layer: True means the layer is a MetaLinear.
        if meta_parameters:
            metas = [meta_start <= layer <= meta_stop for layer in range(n_hidden_layers + 1)]
        else:
            metas = [False] * (n_hidden_layers + 1)

        meta_linear = partial(
            MetaLinear,
            meta_scale=False,
            meta_shift=False,
            meta_proj=meta_proj,
            meta_bias=meta_bias,
            trainable_meta=trainable_meta,
        )
        self.mlp = nn.ModuleList(
            [
                (meta_linear if is_meta else nn.Linear)(d_in, d_out, device=device)
                for is_meta, d_in, d_out in zip(metas, input_widths, output_widths)
            ]
        )

        mlp_init(self.mlp, init=init, init_scale=init_scale)

        self.activation = get_act(activation)
        self.output_activation = get_act(output_activation)

        self.device = device
        self.to(device)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        Evaluate the MLP at `query.position` (and `query.direction`, if any)
        and apply the output activation.

        NOTE: despite the return annotation, this method returns the bare
        result of `self.output_activation`; the subclasses in this file wrap
        that value in an AttrDict.

        :param query: its `position` and `direction` attributes are read
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: the output activation applied to the final MLP output
        """
        # query.direction is None typically for SDF models and training
        h_final, _h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        return self.output_activation(h_final)

    def _run_mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor],
        params: AttrDict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the layer stack on the encoded position.

        :param position: passed to `encode_position`
        :param direction: passed to `maybe_encode_direction` at the insertion
            layer; `_mlp` may pass None here
        :param params: meta parameters; the entries under `mlp.{i}` are handed
            to the i-th layer when that layer is a MetaLinear
        :return: the final and directionless activations at the given query
        """
        h_preact = h = encode_position(self.posenc_version, position=position)
        h_directionless = None
        for i, layer in enumerate(self.mlp):
            if i == self.insert_direction_at:
                # Remember the last value computed before any direction
                # information is mixed in.
                h_directionless = h_preact
                h_direction = maybe_encode_direction(
                    self.posenc_version, position=position, direction=direction
                )
                h = torch.cat([h, h_direction], dim=-1)
            if isinstance(layer, MetaLinear):
                h = layer(h, params=subdict(params, f"mlp.{i}"))
            else:
                h = layer(h)
            h_preact = h
            # The hidden activation is skipped on the last layer; `forward`
            # applies the output activation instead.
            if i < len(self.mlp) - 1:
                h = self.activation(h)
        h_final = h
        if h_directionless is None:
            # No direction was inserted, so the final pre-activation is
            # already direction-free.
            h_directionless = h_preact
        return h_final, h_directionless

    def _mlp(
        self,
        position: torch.Tensor,
        direction: Optional[torch.Tensor] = None,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run `_run_mlp` through `checkpoint`.

        :param position: [batch_size x ... x 3]
        :param direction: optional direction forwarded to `_run_mlp`
        :param params: Meta parameters
        :param options: Optional hyperparameters; `checkpoint_stf_model` is
            read and passed to `checkpoint`
        :return: the final and directionless activations at the given query
        """
        # `update` is not defined in this class; it comes from a base class.
        params = self.update(params)

        options = AttrDict() if options is None else AttrDict(options)

        mlp = partial(self._run_mlp, direction=direction, params=params)

        # Collect every tensor the layers use so it can be handed to
        # `checkpoint` alongside the position input.
        parameters = []
        for i, layer in enumerate(self.mlp):
            if isinstance(layer, MetaLinear):
                parameters.extend(subdict(params, f"mlp.{i}").values())
            else:
                parameters.extend(layer.parameters())

        h_final, h_directionless = checkpoint(
            mlp, (position,), parameters, options.checkpoint_stf_model
        )

        return h_final, h_directionless
class MLPSDFModel(MLPModel):
    """
    MLPModel with a single output channel and the "identity" output
    activation, whose result is returned under the `signed_distance` key.
    """

    def __init__(self, initial_bias: float = -0.1, **kwargs):
        """
        :param initial_bias: constant written into the bias of the last layer
        :param kwargs: forwarded unchanged to `MLPModel.__init__`
        """
        super().__init__(n_output=1, output_activation="identity", **kwargs)
        # Overwrite the bias of the final layer in place.
        last_layer = self.mlp[-1]
        last_layer.bias.data.fill_(initial_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        :param query: forwarded to `MLPModel.forward`
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: an AttrDict with the parent's output under `signed_distance`
        """
        return AttrDict(
            signed_distance=super().forward(query=query, params=params, options=options)
        )
class MLPTextureFieldModel(MLPModel):
    """
    MLPModel with `n_channels` outputs and the "sigmoid" output activation,
    whose result is returned under the `channels` key.
    """

    def __init__(self, n_channels: int = 3, **kwargs):
        """
        :param n_channels: number of output channels of the MLP
        :param kwargs: forwarded unchanged to `MLPModel.__init__`
        """
        super().__init__(n_output=n_channels, output_activation="sigmoid", **kwargs)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        :param query: forwarded to `MLPModel.forward`
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: an AttrDict with the parent's output under `channels`
        """
        return AttrDict(
            channels=super().forward(query=query, params=params, options=options)
        )
|