|  | from dataclasses import dataclass | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from einops import reduce | 
					
						
						|  | from jaxtyping import Float | 
					
						
						|  | from torch import Tensor | 
					
						
						|  |  | 
					
						
						|  | from src.dataset.types import BatchedExample | 
					
						
						|  | from src.model.decoder.decoder import DecoderOutput | 
					
						
						|  | from src.model.types import Gaussians | 
					
						
						|  | from .loss import Loss | 
					
						
						|  | from typing import Generic, TypeVar | 
					
						
						|  | from dataclasses import fields | 
					
						
						|  | import torch.nn.functional as F | 
					
						
						|  | import sys | 
					
						
						|  | from pytorch3d.loss import chamfer_distance | 
					
						
						|  | import os | 
					
						
# Path hack: make the repository root importable (for the `src.*` imports
# below) when this module is loaded outside the normal package context.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
					
						
						|  |  | 
					
						
						|  | from src.misc.utils import vis_depth_map | 
					
						
						|  |  | 
					
						
# Type parameters mirroring the generic `Loss[cfg, wrapper]` base class:
# the unwrapped config type and the single-field wrapper type carrying it.
T_cfg = TypeVar("T_cfg")
T_wrapper = TypeVar("T_wrapper")
					
						
						|  |  | 
					
						
						|  |  | 
					
						
@dataclass
class LossChamferDistanceCfg:
    """Configuration for the chamfer-distance loss."""

    # Scalar multiplier applied to the final averaged chamfer loss.
    weight: float
    # Fraction of each point set kept when randomly down-sampling
    # (used as `int(num_points * down_sample_ratio)` in `forward`).
    down_sample_ratio: float
    # NOTE(review): unused in this file — presumably an image-space sigma for
    # a related loss variant; confirm against callers before relying on it.
    sigma_image: float | None
					
						
						|  |  | 
					
						
						|  |  | 
					
						
@dataclass
class LossChamferDistanceCfgWrapper:
    # Single-field wrapper: the field's name doubles as the loss's name
    # (it is read back via `dataclasses.fields` in the loss constructor).
    chamfer_distance: LossChamferDistanceCfg
					
						
						|  |  | 
					
						
						|  | class LossChamferDistance(Loss[LossChamferDistanceCfg, LossChamferDistanceCfgWrapper]): | 
					
						
						|  | def __init__(self, cfg: T_wrapper) -> None: | 
					
						
						|  | super().__init__(cfg) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | (field,) = fields(type(cfg)) | 
					
						
						|  | self.cfg = getattr(cfg, field.name) | 
					
						
						|  | self.name = field.name | 
					
						
						|  |  | 
					
						
						|  | def forward( | 
					
						
						|  | self, | 
					
						
						|  | prediction: DecoderOutput, | 
					
						
						|  | batch: BatchedExample, | 
					
						
						|  | gaussians: Gaussians, | 
					
						
						|  | depth_dict: dict, | 
					
						
						|  | global_step: int, | 
					
						
						|  | ) -> Float[Tensor, ""]: | 
					
						
						|  |  | 
					
						
						|  | b, v, h, w, _ = depth_dict['distill_infos']['pts_all'].shape | 
					
						
						|  | pred_pts = depth_dict['distill_infos']['pts_all'].flatten(0, 1) | 
					
						
						|  |  | 
					
						
						|  | conf_mask = depth_dict['distill_infos']['conf_mask'] | 
					
						
						|  | gaussian_meas = gaussians.means | 
					
						
						|  |  | 
					
						
						|  | pred_pts = pred_pts.view(b, v, h, w, -1) | 
					
						
						|  | conf_mask = conf_mask.view(b, v, h, w) | 
					
						
						|  |  | 
					
						
						|  | pts_mask = torch.abs(gaussian_meas[..., -1]) < 1e2 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | cd_losses = 0.0 | 
					
						
						|  | for b_idx in range(b): | 
					
						
						|  | batch_pts, batch_conf, batch_gaussian = pred_pts[b_idx], conf_mask[b_idx], gaussian_meas[b_idx][pts_mask[b_idx]] | 
					
						
						|  | batch_pts = batch_pts[batch_conf] | 
					
						
						|  | batch_pts = batch_pts[torch.randperm(batch_pts.shape[0])[:int(batch_pts.shape[0] * self.cfg.down_sample_ratio)]] | 
					
						
						|  | batch_gaussian = batch_gaussian[torch.randperm(batch_gaussian.shape[0])[:int(batch_gaussian.shape[0] * self.cfg.down_sample_ratio)]] | 
					
						
						|  | cd_loss = chamfer_distance(batch_pts.unsqueeze(0), batch_gaussian.unsqueeze(0))[0] | 
					
						
						|  | cd_losses = cd_losses + cd_loss | 
					
						
						|  | return self.cfg.weight * torch.nan_to_num(cd_losses / b, nan=0.0) |