Spaces:
Runtime error
Runtime error
from abc import ABC, abstractmethod | |
from dataclasses import dataclass | |
from typing import Dict, Optional, Tuple | |
import torch | |
from shap_e.models.nn.meta import MetaModule | |
from shap_e.models.nn.utils import ArrayType, safe_divide, to_torch | |
@dataclass
class VolumeRange:
    """
    Per-ray interval [t0, t1] of ray parameters, plus a mask telling which
    rays actually intersect the volume.

    The three fields must all have the same shape; this is checked right
    after construction.
    """

    # Ray parameter at the start of the interval.
    t0: torch.Tensor
    # Ray parameter at the end of the interval.
    t1: torch.Tensor
    # Boolean mask, True where the ray intersects the volume (it is used as
    # the condition of torch.where and as an operand of torch.logical_or).
    intersected: torch.Tensor

    def __post_init__(self):
        # Invoked by the dataclass-generated __init__.
        assert self.t0.shape == self.t1.shape == self.intersected.shape

    def next_t0(self):
        """
        Given convex volume1 and volume2, where volume1 is contained in
        volume2, this function returns the t0 at which rays leave volume1 and
        intersect with volume2 \\ volume1.

        :return: t1 where the ray intersected this range, 0 elsewhere.
        """
        return self.t1 * self.intersected.float()

    def extend(self, another: "VolumeRange") -> "VolumeRange":
        """
        The ranges at which rays intersect with either one, or both, or none of
        the self and another are merged together.

        :param another: the range to merge with; same shape as self.
        :return: a new VolumeRange whose t0 is taken from self where self was
            intersected (otherwise from another), whose t1 is taken from
            another where another was intersected (otherwise from self), and
            whose mask is the logical OR of both masks.
        """
        return VolumeRange(
            t0=torch.where(self.intersected, self.t0, another.t0),
            t1=torch.where(another.intersected, another.t1, self.t1),
            intersected=torch.logical_or(self.intersected, another.intersected),
        )

    def partition(self, ts) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Partitions t0 and t1 into n_samples intervals.

        :param ts: [batch_size, *shape, n_samples, 1]
        :return: a tuple of (
            lower: [batch_size, *shape, n_samples, 1]
            upper: [batch_size, *shape, n_samples, 1]
            delta: [batch_size, *shape, n_samples, 1]
        ) where

            ts \\in [lower, upper]
            deltas = upper - lower
        """
        # Interval boundaries are the midpoints between consecutive samples,
        # closed off by t0 at the front and t1 at the back.
        mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5
        lower = torch.cat([self.t0[..., None, :], mids], dim=-2)
        upper = torch.cat([mids, self.t1[..., None, :]], dim=-2)
        delta = upper - lower
        assert lower.shape == upper.shape == delta.shape == ts.shape
        return lower, upper, delta
class Volume(ABC):
    """
    An abstraction of rendering volume.

    Concrete volumes must implement `intersect`; the base class cannot be
    instantiated on its own.
    """

    @abstractmethod
    def intersect(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0_lower: Optional[torch.Tensor] = None,
        params: Optional[Dict] = None,
        epsilon: float = 1e-6,
    ) -> "VolumeRange":
        """
        Compute where each ray enters and leaves this volume.

        :param origin: [batch_size, *shape, 3]
        :param direction: [batch_size, *shape, 3]
        :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
        :param params: Optional meta parameters in case Volume is parametric
        :param epsilon: to stabilize calculations

        :return: A VolumeRange holding (t0, t1, intersected) where each has a shape
            [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is
            in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed
            to be on the boundary of the volume.
        """
class BoundingBoxVolume(MetaModule, Volume):
    """
    Axis-aligned bounding box defined by the two opposite corners.
    """

    def __init__(
        self,
        *,
        bbox_min: ArrayType,
        bbox_max: ArrayType,
        min_dist: float = 0.0,
        min_t_range: float = 1e-3,
        device: torch.device = torch.device("cuda"),
    ):
        """
        :param bbox_min: the left/bottommost corner of the bounding box
        :param bbox_max: the other corner of the bounding box
        :param min_dist: all rays should start at least this distance away from the origin.
        :param min_t_range: a ray only counts as intersecting when t1 exceeds
            t0 by more than this margin; must be strictly positive.
        :param device: device the corner tensors are moved to.
        """
        super().__init__()

        # Both corners go through to_torch and are placed on the target device.
        self.bbox_min = to_torch(bbox_min).to(device)
        self.bbox_max = to_torch(bbox_max).to(device)
        self.min_dist, self.min_t_range = min_dist, min_t_range

        # Row 0 holds the min corner, row 1 the max corner.
        self.bbox = torch.stack([self.bbox_min, self.bbox_max])

        assert self.bbox.shape == (2, 3)
        assert self.min_dist >= 0.0
        assert self.min_t_range > 0.0

        self.device = device

    def intersect(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0_lower: Optional[torch.Tensor] = None,
        params: Optional[Dict] = None,
        epsilon=1e-6,
    ) -> VolumeRange:
        """
        Intersect rays with the box.

        :param origin: [batch_size, *shape, 3]
        :param direction: [batch_size, *shape, 3]
        :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
        :param params: Optional meta parameters in case Volume is parametric
        :param epsilon: to stabilize calculations

        :return: A VolumeRange of (t0, t1, intersected) where each has a shape
            [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is
            in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed
            to be on the boundary of the volume. Rays that miss get t0 = 0 and t1 = 1.
        """
        batch_size, *shape, _ = origin.shape
        expected_shape = (batch_size, *shape, 1)

        # Broadcast the two corners against every ray: [1, 1, ..., 2, 3].
        corners = self.bbox.view(1, *([1] * len(shape)), 2, 3)

        # Ray parameter at which each ray meets each of the six bounding
        # planes: [batch_size, *shape, 2, 3].
        plane_ts = safe_divide(
            corners - origin[..., None, :],
            direction[..., None, :],
            epsilon=epsilon,
        )

        # Possible configurations:
        #
        #   1. t1 <= t0: the ray misses the box.
        #   2. t0 < t1 <= 0: the box lies behind the ray origin.
        #   3. t0 <= 0 <= t1: the ray starts inside the box.
        #   4. 0 <= t0 < t1: the ray starts outside and crosses the box twice.
        #
        # The t0 < t1 test below separates 1 from 4, while clamping t0 to
        # min_dist (>= 0) deals with 2 and 3.
        entry_per_axis = plane_ts.min(dim=-2).values
        exit_per_axis = plane_ts.max(dim=-2).values
        t0 = entry_per_axis.max(dim=-1, keepdim=True).values.clamp(min=self.min_dist)
        t1 = exit_per_axis.min(dim=-1, keepdim=True).values
        assert t0.shape == t1.shape == expected_shape

        if t0_lower is not None:
            assert t0.shape == t0_lower.shape
            t0 = torch.maximum(t0, t0_lower)

        # Require the interval to be wider than min_t_range.
        intersected = t0 + self.min_t_range < t1

        # Non-intersecting rays get the placeholder interval [0, 1].
        t0 = torch.where(intersected, t0, torch.zeros_like(t0))
        t1 = torch.where(intersected, t1, torch.ones_like(t1))

        return VolumeRange(t0=t0, t1=t1, intersected=intersected)
class UnboundedVolume(MetaModule, Volume):
    """
    Originally used in NeRF. Unbounded volume but with a limited visibility
    when rendering (e.g. objects that are farther away than the max_dist from
    the ray origin are not considered)
    """

    def __init__(
        self,
        *,
        max_dist: float,
        min_dist: float = 0.0,
        min_t_range: float = 1e-3,
        device: torch.device = torch.device("cuda"),
    ):
        """
        :param max_dist: length added to the starting t0 to obtain t1.
        :param min_dist: lower clamp applied to t0; must be non-negative.
        :param min_t_range: a ray only counts as intersecting when t1 exceeds
            t0 by more than this margin; must be strictly positive.
        :param device: stored on the instance; tensors produced by `intersect`
            follow the device of `origin` instead.
        """
        super().__init__()
        self.max_dist = max_dist
        self.min_dist = min_dist
        self.min_t_range = min_t_range
        assert self.min_dist >= 0.0
        assert self.min_t_range > 0.0
        self.device = device

    def intersect(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0_lower: Optional[torch.Tensor] = None,
        params: Optional[Dict] = None,
        epsilon: float = 1e-6,
    ) -> VolumeRange:
        """
        Build the visible range for each ray.

        :param origin: [batch_size, *shape, 3]
        :param direction: [batch_size, *shape, 3]
        :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume.
        :param params: Optional meta parameters in case Volume is parametric
        :param epsilon: accepted so the signature matches the other volumes'
            `intersect`; not used by this volume.

        :return: A VolumeRange of (t0, t1, intersected) where each has a shape
            [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is
            in the volume for all t in [t0, t1].
        """
        batch_size, *shape, _ = origin.shape
        t0 = torch.zeros(batch_size, *shape, 1, dtype=origin.dtype, device=origin.device)
        if t0_lower is not None:
            t0 = torch.maximum(t0, t0_lower)
        # t1 is derived from t0 *before* the min_dist clamp below.
        t1 = t0 + self.max_dist
        t0 = t0.clamp(self.min_dist)
        return VolumeRange(t0=t0, t1=t1, intersected=t0 + self.min_t_range < t1)
class SphericalVolume(MetaModule, Volume):
    """
    Used in NeRF++ but will not be used probably unless we want to reproduce
    their results.

    Only the constructor is functional; `intersect` is not implemented.
    """

    def __init__(
        self,
        *,
        radius: float,
        center: ArrayType = (0.0, 0.0, 0.0),
        min_dist: float = 0.0,
        min_t_range: float = 1e-3,
        device: torch.device = torch.device("cuda"),
    ):
        """
        :param radius: radius of the sphere.
        :param center: center of the sphere; converted with to_torch and
            moved to `device`.
        :param min_dist: must be non-negative.
        :param min_t_range: must be strictly positive.
        :param device: device the center tensor is placed on.
        """
        super().__init__()

        self.radius = radius
        self.center = to_torch(center).to(device)
        self.min_dist, self.min_t_range = min_dist, min_t_range

        assert self.min_dist >= 0.0
        assert self.min_t_range > 0.0

        self.device = device

    def intersect(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0_lower: Optional[torch.Tensor] = None,
        params: Optional[Dict] = None,
        epsilon=1e-6,
    ) -> VolumeRange:
        """
        Not implemented for spherical volumes.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError