Spaces:
Runtime error
Runtime error
from functools import partial | |
from typing import Any, Dict, Optional, Sequence, Tuple, Union | |
import torch | |
from shap_e.models.nerf.model import NeRFModel | |
from shap_e.models.nerf.ray import RayVolumeIntegral, StratifiedRaySampler, render_rays | |
from shap_e.models.nn.meta import subdict | |
from shap_e.models.nn.utils import to_torch | |
from shap_e.models.query import Query | |
from shap_e.models.renderer import RayRenderer, render_views_from_rays | |
from shap_e.models.stf.base import Model | |
from shap_e.models.stf.renderer import STFRendererBase, render_views_from_stf | |
from shap_e.models.volume import BoundingBoxVolume, Volume | |
from shap_e.rendering.blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR | |
from shap_e.util.collections import AttrDict | |
class NeRSTFRenderer(RayRenderer, STFRendererBase): | |
def __init__( | |
self, | |
sdf: Optional[Model], | |
tf: Optional[Model], | |
nerstf: Optional[Model], | |
void: NeRFModel, | |
volume: Volume, | |
grid_size: int, | |
n_coarse_samples: int, | |
n_fine_samples: int, | |
importance_sampling_options: Optional[Dict[str, Any]] = None, | |
separate_shared_samples: bool = False, | |
texture_channels: Sequence[str] = ("R", "G", "B"), | |
channel_scale: Sequence[float] = (255.0, 255.0, 255.0), | |
ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR, | |
diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR, | |
specular_color: Union[float, Tuple[float]] = 0.0, | |
output_srgb: bool = True, | |
device: torch.device = torch.device("cuda"), | |
**kwargs, | |
): | |
super().__init__(**kwargs) | |
assert isinstance(volume, BoundingBoxVolume), "cannot sample points in unknown volume" | |
assert (nerstf is not None) ^ (sdf is not None and tf is not None) | |
self.sdf = sdf | |
self.tf = tf | |
self.nerstf = nerstf | |
self.void = void | |
self.volume = volume | |
self.grid_size = grid_size | |
self.n_coarse_samples = n_coarse_samples | |
self.n_fine_samples = n_fine_samples | |
self.importance_sampling_options = AttrDict(importance_sampling_options or {}) | |
self.separate_shared_samples = separate_shared_samples | |
self.texture_channels = texture_channels | |
self.channel_scale = to_torch(channel_scale).to(device) | |
self.ambient_color = ambient_color | |
self.diffuse_color = diffuse_color | |
self.specular_color = specular_color | |
self.output_srgb = output_srgb | |
self.device = device | |
self.patch_size=128 | |
self.use_patch=False | |
self.to(device) | |
def _query( | |
self, | |
query: Query, | |
params: AttrDict[str, torch.Tensor], | |
options: AttrDict[str, Any], | |
) -> AttrDict: | |
no_dir_query = query.copy() | |
no_dir_query.direction = None | |
if options.get("rendering_mode", "stf") == "stf": | |
assert query.direction is None | |
if self.nerstf is not None: | |
sdf = tf = self.nerstf( | |
query, | |
params=subdict(params, "nerstf"), | |
options=options, | |
) | |
else: | |
sdf = self.sdf(no_dir_query, params=subdict(params, "sdf"), options=options) | |
tf = self.tf(query, params=subdict(params, "tf"), options=options) | |
return AttrDict( | |
density=sdf.density, | |
signed_distance=sdf.signed_distance, | |
channels=tf.channels, | |
aux_losses=dict(), | |
) | |
def render_rays( | |
self, | |
batch: AttrDict, | |
params: Optional[Dict] = None, | |
options: Optional[AttrDict] = None, | |
) -> AttrDict: | |
""" | |
:param batch: has | |
- rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray. | |
:param options: Optional[Dict] | |
""" | |
params = self.update(params) | |
options = AttrDict() if options is None else AttrDict(options) | |
# Necessary to tell the TF to use specific NeRF channels. | |
options.rendering_mode = "nerf" | |
model = partial(self._query, params=params, options=options) | |
# First, render rays with coarse, stratified samples. | |
options.nerf_level = "coarse" | |
parts = [ | |
RayVolumeIntegral( | |
model=model, | |
volume=self.volume, | |
sampler=StratifiedRaySampler(), | |
n_samples=self.n_coarse_samples, | |
), | |
] | |
coarse_results, samplers, coarse_raw_outputs = render_rays( | |
batch.rays, | |
parts, | |
self.void, | |
shared=not self.separate_shared_samples, | |
render_with_direction=options.render_with_direction, | |
importance_sampling_options=self.importance_sampling_options, | |
) | |
# Then, render with additional importance-weighted ray samples. | |
options.nerf_level = "fine" | |
parts = [ | |
RayVolumeIntegral( | |
model=model, | |
volume=self.volume, | |
sampler=samplers[0], | |
n_samples=self.n_fine_samples, | |
), | |
] | |
fine_results, _, raw_outputs = render_rays( | |
batch.rays, | |
parts, | |
self.void, | |
shared=not self.separate_shared_samples, | |
prev_raw_outputs=coarse_raw_outputs, | |
render_with_direction=options.render_with_direction, | |
) | |
raw = raw_outputs[0] | |
aux_losses = fine_results.output.aux_losses.copy() | |
if self.separate_shared_samples: | |
for key, val in coarse_results.output.aux_losses.items(): | |
aux_losses[key + "_coarse"] = val | |
channels = fine_results.output.channels | |
shape = [1] * (channels.ndim - 1) + [len(self.texture_channels)] | |
channels = channels * self.channel_scale.view(*shape) | |
res = AttrDict( | |
channels=channels, | |
transmittance=fine_results.transmittance, | |
raw_signed_distance=raw.signed_distance, | |
raw_density=raw.density, | |
distances=fine_results.output.distances, | |
t0=fine_results.volume_range.t0, | |
t1=fine_results.volume_range.t1, | |
intersected=fine_results.volume_range.intersected, | |
aux_losses=aux_losses, | |
) | |
if self.separate_shared_samples: | |
res.update( | |
dict( | |
channels_coarse=( | |
coarse_results.output.channels * self.channel_scale.view(*shape) | |
), | |
distances_coarse=coarse_results.output.distances, | |
transmittance_coarse=coarse_results.transmittance, | |
) | |
) | |
return res | |
def render_views( | |
self, | |
batch: AttrDict, | |
params: Optional[Dict] = None, | |
options: Optional[AttrDict] = None, | |
) -> AttrDict: | |
""" | |
Returns a backproppable rendering of a view | |
:param batch: contains either ["poses", "camera"], or ["cameras"]. Can | |
optionally contain any of ["height", "width", "query_batch_size"] | |
:param params: Meta parameters | |
contains rendering_mode in ["stf", "nerf"] | |
:param options: controls checkpointing, caching, and rendering. | |
Can provide a `rendering_mode` in ["stf", "nerf"] | |
""" | |
params = self.update(params) | |
options = AttrDict() if options is None else AttrDict(options) | |
if options.cache is None: | |
created_cache = True | |
options.cache = AttrDict() | |
else: | |
created_cache = False | |
rendering_mode = options.get("rendering_mode", "stf") | |
# import pdb; pdb.set_trace() | |
if rendering_mode == "nerf": | |
output = render_views_from_rays( | |
self.render_rays, | |
batch, | |
params=params, | |
options=options, | |
device=self.device, | |
patch_size=self.patch_size, | |
use_patch=self.use_patch | |
) | |
elif rendering_mode == "stf": | |
sdf_fn = tf_fn = nerstf_fn = None | |
if self.nerstf is not None: | |
nerstf_fn = partial( | |
self.nerstf.forward_batched, | |
params=subdict(params, "nerstf"), | |
options=options, | |
) | |
else: | |
sdf_fn = partial( | |
self.sdf.forward_batched, | |
params=subdict(params, "sdf"), | |
options=options, | |
) | |
tf_fn = partial( | |
self.tf.forward_batched, | |
params=subdict(params, "tf"), | |
options=options, | |
) | |
output = render_views_from_stf( | |
batch, | |
options, | |
sdf_fn=sdf_fn, | |
tf_fn=tf_fn, | |
nerstf_fn=nerstf_fn, | |
volume=self.volume, | |
grid_size=self.grid_size, | |
channel_scale=self.channel_scale, | |
texture_channels=self.texture_channels, | |
ambient_color=self.ambient_color, | |
diffuse_color=self.diffuse_color, | |
specular_color=self.specular_color, | |
output_srgb=self.output_srgb, | |
device=self.device, | |
) | |
else: | |
raise NotImplementedError | |
if created_cache: | |
del options["cache"] | |
return output | |
def get_signed_distance( | |
self, | |
query: Query, | |
params: Dict[str, torch.Tensor], | |
options: AttrDict[str, Any], | |
) -> torch.Tensor: | |
if self.sdf is not None: | |
return self.sdf(query, params=subdict(params, "sdf"), options=options).signed_distance | |
assert self.nerstf is not None | |
return self.nerstf(query, params=subdict(params, "nerstf"), options=options).signed_distance | |
def get_texture( | |
self, | |
query: Query, | |
params: Dict[str, torch.Tensor], | |
options: AttrDict[str, Any], | |
) -> torch.Tensor: | |
if self.tf is not None: | |
return self.tf(query, params=subdict(params, "tf"), options=options).channels | |
assert self.nerstf is not None | |
return self.nerstf(query, params=subdict(params, "nerstf"), options=options).channels | |