File size: 6,207 Bytes
19c4ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from typing import Any, Dict, Optional, Tuple

import torch

from shap_e.models.nn.ops import get_act
from shap_e.models.query import Query
from shap_e.models.stf.mlp import MLPModel
from shap_e.util.collections import AttrDict


class MLPDensitySDFModel(MLPModel):
    """
    MLP with two raw output channels, exposed as a signed distance and a
    density.

    The parent MLPModel is constructed with n_output=2 and an "identity"
    output activation. forward() splits the two channels and applies
    `sdf_activation` to channel 0 and `density_activation` to channel 1.
    """

    def __init__(
        self,
        initial_bias: float = -0.1,
        sdf_activation="tanh",
        density_activation="exp",
        **kwargs,
    ):
        """
        :param initial_bias: value written into element 0 (the SDF channel)
                             of the last layer's bias.
        :param sdf_activation: name handed to get_act() for the SDF channel.
        :param density_activation: name handed to get_act() for the density
                                   channel.
        :param kwargs: forwarded unchanged to MLPModel.__init__.
        """
        super().__init__(n_output=2, output_activation="identity", **kwargs)
        self.sdf_activation = get_act(sdf_activation)
        self.density_activation = get_act(density_activation)
        # Only the SDF channel of the output layer receives the initial bias;
        # the density channel keeps whatever value the parent class gave it.
        output_layer = self.mlp[-1]
        output_layer.bias.data[0].fill_(initial_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        Evaluate the MLP for `query` and return the activated outputs.

        :param query: supplies `position` and `direction` for the MLP.
        :param params: passed through to self._mlp.
        :param options: passed through to self._mlp.
        :return: an AttrDict with keys "density" and "signed_distance", each
                 keeping a trailing dimension of size 1.
        """
        # query.direction is None typically for SDF models and training
        raw_output, _unused_directionless = self._mlp(
            query.position,
            query.direction,
            params=params,
            options=options,
        )
        # Channel 0 is the SDF, channel 1 the density.
        raw_sdf, raw_density = torch.split(raw_output, 1, dim=-1)
        density = self.density_activation(raw_density)
        signed_distance = self.sdf_activation(raw_sdf)
        return AttrDict(density=density, signed_distance=signed_distance)


class MLPNeRSTFModel(MLPModel):
    """
    MLP whose raw outputs are sliced into an SDF, coarse/fine densities and
    the "stf" / "nerf_coarse" / "nerf_fine" channel groups.

    The slice layout comes from indices_for_output_mode() and is stored in
    `h_map` (applied to the first output of self._mlp) and
    `h_directionless_map` (applied to the second output). The number of MLP
    outputs is the largest end index found in `h_map`.
    """

    def __init__(
        self,
        sdf_activation="tanh",
        density_activation="exp",
        channel_activation="sigmoid",
        direction_dependent_shape: bool = True,  # To be able to load old models. Set this to be False in future models.
        separate_nerf_channels: bool = False,
        separate_coarse_channels: bool = False,
        initial_density_bias: float = 0.0,
        initial_sdf_bias: float = -0.1,
        **kwargs,
    ):
        """
        :param sdf_activation: name handed to get_act() for the SDF output.
        :param density_activation: name handed to get_act() for the density.
        :param channel_activation: name handed to get_act() for the channels.
        :param direction_dependent_shape: forwarded to
               indices_for_output_mode(); also selects which layer receives
               the initial SDF/density biases.
        :param separate_nerf_channels: forwarded to indices_for_output_mode().
        :param separate_coarse_channels: forwarded to
               indices_for_output_mode().
        :param initial_density_bias: value written into bias element 1 of the
               selected layer.
        :param initial_sdf_bias: value written into bias element 0 of the
               selected layer.
        :param kwargs: forwarded unchanged to MLPModel.__init__.
        """
        # The layout must be known before the parent constructor runs, since
        # it determines how many output channels the MLP is built with.
        h_map, h_directionless_map = indices_for_output_mode(
            direction_dependent_shape=direction_dependent_shape,
            separate_nerf_channels=separate_nerf_channels,
            separate_coarse_channels=separate_coarse_channels,
        )
        super().__init__(
            n_output=index_mapping_max(h_map),
            output_activation="identity",
            **kwargs,
        )

        self.direction_dependent_shape = direction_dependent_shape
        self.separate_nerf_channels = separate_nerf_channels
        self.separate_coarse_channels = separate_coarse_channels

        self.sdf_activation = get_act(sdf_activation)
        self.density_activation = get_act(density_activation)
        self.channel_activation = get_act(channel_activation)

        self.h_map = h_map
        self.h_directionless_map = h_directionless_map

        # Zero the output layer's bias first; if that is also the layer that
        # carries the shape outputs, the fills below overwrite elements 0/1.
        self.mlp[-1].bias.data.zero_()
        if self.direction_dependent_shape:
            shape_layer = self.mlp[-1]
        else:
            shape_layer = self.mlp[self.insert_direction_at]
        shape_bias = shape_layer.bias.data
        shape_bias[0].fill_(initial_sdf_bias)
        shape_bias[1].fill_(initial_density_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        """
        Evaluate the MLP for `query` and return activated density, signed
        distance and channels.

        Options consulted here:
          * `nerf_level`: when equal to "coarse" the coarse density (and, in
            nerf rendering mode, the coarse nerf channels) are used; any other
            value selects the fine ones. It is read via attribute access, so
            what happens when the key is absent depends on AttrDict's
            attribute lookup (not visible in this file).
          * `rendering_mode`: defaults to "stf"; only the value "nerf" selects
            the nerf channel groups, anything else selects the "stf" group.

        :param query: supplies `position` and `direction` for the MLP.
        :param params: passed through to self._mlp.
        :param options: optional mapping; copied into a fresh AttrDict before
                        use and passed through to self._mlp.
        :return: an AttrDict with keys "density", "signed_distance" and
                 "channels".
        """
        opts = AttrDict(options) if options is not None else AttrDict()
        raw, raw_directionless = self._mlp(
            query.position,
            query.direction,
            params=params,
            options=opts,
        )

        # Merge the named slices from both MLP outputs into one lookup.
        pieces = map_indices_to_keys(self.h_map, raw)
        pieces.update(map_indices_to_keys(self.h_directionless_map, raw_directionless))

        use_coarse = opts.nerf_level == "coarse"
        raw_density = pieces.density_coarse if use_coarse else pieces.density_fine

        if opts.get("rendering_mode", "stf") != "nerf":
            raw_channels = pieces.stf
        elif use_coarse:
            raw_channels = pieces.nerf_coarse
        else:
            raw_channels = pieces.nerf_fine

        density = self.density_activation(raw_density)
        signed_distance = self.sdf_activation(pieces.sdf)
        channels = self.channel_activation(raw_channels)
        return AttrDict(
            density=density,
            signed_distance=signed_distance,
            channels=channels,
        )


# Maps an output name to a (start, end) pair. map_indices_to_keys() uses the
# pair as the half-open slice data[..., start:end] on the last dimension.
IndexMapping = AttrDict[str, Tuple[int, int]]


def indices_for_output_mode(
    direction_dependent_shape: bool,
    separate_nerf_channels: bool,
    separate_coarse_channels: bool,
) -> Tuple[IndexMapping, IndexMapping]:
    """
    Get output mappings for (h, h_directionless).

    Each mapping associates an output name with a (start, end) pair of
    indices. Entries may share a range when the corresponding outputs are
    not separated.

    :param direction_dependent_shape: if True, all outputs (sdf, densities and
           channel groups) are mapped in the first mapping and the second one
           is left empty. If False, sdf and densities go in the second
           (directionless) mapping and only the channel groups in the first.
    :param separate_nerf_channels: if True, the nerf channel groups get their
           own range instead of sharing the "stf" range.
    :param separate_coarse_channels: if True, coarse and fine densities and
           nerf channels get distinct ranges. Requires
           separate_nerf_channels=True.
    :return: a tuple (h_map, h_directionless_map).
    :raises AssertionError: if separate_coarse_channels is True while
            separate_nerf_channels is False.
    """
    # Raised explicitly rather than through `assert` so that the check still
    # runs under `python -O`; the exception type is kept as AssertionError so
    # callers observe the same error as before.
    if separate_coarse_channels and not separate_nerf_channels:
        raise AssertionError(
            "separate_coarse_channels=True requires separate_nerf_channels=True"
        )

    h_map = AttrDict()
    h_directionless_map = AttrDict()
    if direction_dependent_shape:
        # Shape outputs and channel groups all live in the same mapping.
        h_map.sdf = (0, 1)
        h_map.density_coarse = (1, 2)
        if separate_coarse_channels:
            h_map.density_fine = (2, 3)
            h_map.stf = (3, 6)
            h_map.nerf_coarse = (6, 9)
            h_map.nerf_fine = (9, 12)
        else:
            # Coarse and fine share one density range and one nerf range.
            h_map.density_fine = (1, 2)
            h_map.stf = (2, 5)
            nerf_range = (5, 8) if separate_nerf_channels else (2, 5)
            h_map.nerf_coarse = nerf_range
            h_map.nerf_fine = nerf_range
    else:
        # Shape outputs go in the directionless mapping; channel groups are
        # indexed from 0 in their own mapping.
        h_directionless_map.sdf = (0, 1)
        h_directionless_map.density_coarse = (1, 2)
        h_directionless_map.density_fine = (2, 3) if separate_coarse_channels else (1, 2)
        h_map.stf = (0, 3)
        if separate_coarse_channels:
            h_map.nerf_coarse = (3, 6)
            h_map.nerf_fine = (6, 9)
        else:
            nerf_range = (3, 6) if separate_nerf_channels else (0, 3)
            h_map.nerf_coarse = nerf_range
            h_map.nerf_fine = nerf_range
    return h_map, h_directionless_map


def map_indices_to_keys(mapping: IndexMapping, data: torch.Tensor) -> AttrDict[str, torch.Tensor]:
    """
    Slice the last dimension of `data` into named pieces.

    For every (name, (start, end)) entry of `mapping`, the result holds
    data[..., start:end] under `name`. `data` is never indexed when `mapping`
    is empty.

    :param mapping: output names mapped to (start, end) index pairs.
    :param data: the tensor to slice along its last dimension.
    :return: an AttrDict with one slice per entry of `mapping`.
    """
    slices = {}
    for name, (lo, hi) in mapping.items():
        slices[name] = data[..., lo:hi]
    return AttrDict(slices)


def index_mapping_max(mapping: IndexMapping) -> int:
    """
    Return the largest end index among the (start, end) pairs in `mapping`.

    :param mapping: output names mapped to (start, end) index pairs.
    :return: the maximum `end` value.
    :raises ValueError: if `mapping` is empty (raised by max()).
    """
    end_indices = [end for _start, end in mapping.values()]
    return max(end_indices)