File size: 9,095 Bytes
19c4ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import torch

from shap_e.models.nn.meta import MetaModule
from shap_e.models.nn.utils import ArrayType, safe_divide, to_torch


@dataclass
class VolumeRange:
    """
    Per-ray parametric interval [t0, t1] plus a mask of which rays hit.

    The three fields must all have the same shape; this is checked when the
    object is constructed.
    """

    t0: torch.Tensor  # start of the interval along each ray
    t1: torch.Tensor  # end of the interval along each ray
    intersected: torch.Tensor  # per-ray hit mask (used as a torch.where condition)

    def __post_init__(self):
        # Reject mismatched field shapes up front (raises AssertionError).
        assert self.t0.shape == self.t1.shape and self.t1.shape == self.intersected.shape

    def next_t0(self):
        """
        Where rays leave this volume, for use as the start of an enclosing one.

        Intended for an inner convex volume nested in an outer convex volume:
        for rays that hit the inner volume the result is its exit parameter t1,
        i.e. where the ray enters (outer \\ inner). The hit mask is multiplied
        in, so rays that missed get 0 (as long as their t1 is finite).
        """
        hit_mask = self.intersected.float()
        return self.t1 * hit_mask

    def extend(self, another: "VolumeRange") -> "VolumeRange":
        """
        Merge this range with `another`, ray by ray.

        - start: taken from `self` where `self` was hit, otherwise from `another`;
        - end: taken from `another` where `another` was hit, otherwise from `self`;
        - a ray counts as intersected if either range hit it.
        """
        start = torch.where(self.intersected, self.t0, another.t0)
        end = torch.where(another.intersected, another.t1, self.t1)
        hit = torch.logical_or(self.intersected, another.intersected)
        return VolumeRange(t0=start, t1=end, intersected=hit)

    def partition(self, ts) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Cut [t0, t1] into one sub-interval per sample in `ts`.

        Interior cut points are the midpoints between consecutive samples. The
        first sub-interval starts at t0 and the last one ends at t1. For sorted
        samples that lie inside [t0, t1], every sample falls inside its own
        [lower, upper].

        :param ts: [batch_size, *shape, n_samples, 1]
        :return: (lower, upper, delta), each [batch_size, *shape, n_samples, 1],
            where delta = upper - lower.
        """
        midpoints = 0.5 * (ts[..., :-1, :] + ts[..., 1:, :])
        lower = torch.cat([self.t0.unsqueeze(-2), midpoints], dim=-2)
        upper = torch.cat([midpoints, self.t1.unsqueeze(-2)], dim=-2)
        delta = upper - lower
        assert lower.shape == upper.shape == delta.shape == ts.shape
        return lower, upper, delta


class Volume(ABC):
    """
    An abstraction of rendering volume.

    Concrete volumes implement `intersect`, which returns, for every ray, the
    parametric range over which the ray lies inside the volume.
    """

    @abstractmethod
    def intersect(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0_lower: Optional[torch.Tensor] = None,
        params: Optional[Dict] = None,
        epsilon: float = 1e-6,
    ) -> VolumeRange:
        """
        :param origin: [batch_size, *shape, 3]
        :param direction: [batch_size, *shape, 3]
        :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
        :param params: Optional meta parameters in case Volume is parametric
        :param epsilon: to stabilize calculations

        :return: a VolumeRange whose t0, t1 and intersected fields each have shape
            [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is
            in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed
            to be on the boundary of the volume.
        """


class BoundingBoxVolume(MetaModule, Volume):
    """
    Axis-aligned bounding box defined by the two opposite corners.
    """

    def __init__(
        self,
        *,
        bbox_min: ArrayType,
        bbox_max: ArrayType,
        min_dist: float = 0.0,
        min_t_range: float = 1e-3,
        device: torch.device = torch.device("cuda"),
    ):
        """
        :param bbox_min: the left/bottommost corner of the bounding box
        :param bbox_max: the other corner of the bounding box
        :param min_dist: all rays should start at least this distance away from the origin.
        :param min_t_range: a ray only counts as intersecting when t1 exceeds t0
            by more than this amount; must be positive.
        :param device: device on which the box corners are stored.
        """
        super().__init__()

        self.bbox_min = to_torch(bbox_min).to(device)
        self.bbox_max = to_torch(bbox_max).to(device)
        self.min_dist = min_dist
        self.min_t_range = min_t_range
        # Stacked corners, shape (2, 3): row 0 is bbox_min, row 1 is bbox_max.
        self.bbox = torch.stack([self.bbox_min, self.bbox_max])
        assert self.bbox.shape == (2, 3)
        assert self.min_dist >= 0.0
        assert self.min_t_range > 0.0
        self.device = device

    def intersect(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0_lower: Optional[torch.Tensor] = None,
        params: Optional[Dict] = None,
        epsilon: float = 1e-6,
    ) -> VolumeRange:
        """
        Intersect rays with the box by clipping against the three axis slabs.

        :param origin: [batch_size, *shape, 3]
        :param direction: [batch_size, *shape, 3]
        :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
        :param params: Optional meta parameters in case Volume is parametric (unused here)
        :param epsilon: to stabilize calculations (forwarded to safe_divide)

        :return: a VolumeRange whose t0, t1 and intersected fields each have shape
            [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is
            in the volume for all t in [t0, t1], and t1 is on the boundary of the box.
            Rays that miss get the placeholder range t0=0, t1=1.
        """
        batch_size, *shape, _ = origin.shape
        ones = [1] * len(shape)
        # Reshape the (2, 3) corners so they broadcast against origin[..., None, :]
        # ([batch_size, *shape, 1, 3]). The corners were placed on the device
        # chosen at construction time; move them to the rays' device so that a
        # device mismatch (e.g. CPU rays vs. the default "cuda" corners) does not
        # raise. `.to` is a no-op when the devices already match.
        bbox = self.bbox.view(1, *ones, 2, 3).to(origin.device)
        # ts[..., k, i]: ray parameter at which the ray crosses the axis-i plane
        # through corner k; shape [batch_size, *shape, 2, 3]. safe_divide presumably
        # guards zero direction components using epsilon — see its definition.
        ts = safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon)

        # Cases to think about:
        #
        #   1. t1 <= t0: the ray does not pass through the AABB.
        #   2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin.
        #   3. t0 <= 0 <= t1: the ray starts from inside the BB
        #   4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice.
        #
        # 1 and 4 are clearly handled from t0 < t1 below.
        # Making t0 at least min_dist (>= 0) takes care of 2 and 3.
        t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist)
        t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values
        assert t0.shape == t1.shape == (batch_size, *shape, 1)
        if t0_lower is not None:
            assert t0.shape == t0_lower.shape
            t0 = torch.maximum(t0, t0_lower)

        # Intervals no longer than min_t_range are treated as misses.
        intersected = t0 + self.min_t_range < t1
        # Give missed rays a placeholder range [0, 1].
        t0 = torch.where(intersected, t0, torch.zeros_like(t0))
        t1 = torch.where(intersected, t1, torch.ones_like(t1))

        return VolumeRange(t0=t0, t1=t1, intersected=intersected)


class UnboundedVolume(MetaModule, Volume):
    """
    Originally used in NeRF. Unbounded volume but with a limited visibility
    when rendering (e.g. objects that are farther away than the max_dist from
    the ray origin are not considered)
    """

    def __init__(
        self,
        *,
        max_dist: float,
        min_dist: float = 0.0,
        min_t_range: float = 1e-3,
        device: torch.device = torch.device("cuda"),
    ):
        """
        :param max_dist: length of the visible window along each ray, measured
            from the ray's starting parameter (0, or t0_lower when given).
        :param min_dist: rays start at least this far from the origin; must be non-negative.
        :param min_t_range: a ray only counts as intersecting when t1 exceeds t0
            by more than this amount; must be positive.
        :param device: stored on the instance; intersect() itself allocates on origin's device.
        """
        super().__init__()
        self.max_dist = max_dist
        self.min_dist = min_dist
        self.min_t_range = min_t_range
        assert self.min_dist >= 0.0
        assert self.min_t_range > 0.0
        self.device = device

    def intersect(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0_lower: Optional[torch.Tensor] = None,
        params: Optional[Dict] = None,
        epsilon: float = 1e-6,
    ) -> VolumeRange:
        """
        :param origin: [batch_size, *shape, 3]
        :param direction: [batch_size, *shape, 3] (unused: the range does not depend on it)
        :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
        :param params: Optional meta parameters in case Volume is parametric (unused here)
        :param epsilon: unused; accepted so the signature matches `Volume.intersect`
            and generic callers can pass it to any volume.

        :return: a VolumeRange whose t0, t1 and intersected fields each have shape
            [batch_size, *shape, 1]. t1 is the start parameter plus max_dist; t0 is
            the start parameter clamped to at least min_dist.
        """
        batch_size, *shape, _ = origin.shape
        t0 = torch.zeros(batch_size, *shape, 1, dtype=origin.dtype, device=origin.device)
        if t0_lower is not None:
            t0 = torch.maximum(t0, t0_lower)
        # The far bound is measured from the start *before* the min_dist clamp,
        # so min_dist shortens the visible window rather than shifting it.
        t1 = t0 + self.max_dist
        t0 = t0.clamp(self.min_dist)
        return VolumeRange(t0=t0, t1=t1, intersected=t0 + self.min_t_range < t1)


class SphericalVolume(MetaModule, Volume):
    """
    Sphere-shaped rendering volume, as used in NeRF++.

    Ray intersection is not implemented; the class is kept only in case
    NeRF++ results ever need to be reproduced.
    """

    def __init__(
        self,
        *,
        radius: float,
        center: ArrayType = (0.0, 0.0, 0.0),
        min_dist: float = 0.0,
        min_t_range: float = 1e-3,
        device: torch.device = torch.device("cuda"),
    ):
        """
        :param radius: radius of the sphere.
        :param center: center of the sphere, converted with `to_torch` and moved to `device`.
        :param min_dist: must be non-negative.
        :param min_t_range: must be positive.
        :param device: device on which `center` is stored.
        """
        super().__init__()

        self.radius = radius
        sphere_center = to_torch(center)
        self.center = sphere_center.to(device)
        self.min_dist, self.min_t_range = min_dist, min_t_range
        # Validate the stored configuration (raises AssertionError).
        assert self.min_dist >= 0.0 and self.min_t_range > 0.0
        self.device = device

    def intersect(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0_lower: Optional[torch.Tensor] = None,
        params: Optional[Dict] = None,
        epsilon=1e-6,
    ) -> VolumeRange:
        """
        Not implemented: always raises NotImplementedError.
        """
        raise NotImplementedError