from dataclasses import dataclass
from typing import Literal
from jaxtyping import Float
from torch import Tensor
import torch
import torch.nn.functional as F
from src.dataset.types import BatchedExample
from src.model.decoder.decoder import DecoderOutput
from src.model.types import Gaussians
from .loss import Loss
@dataclass
class LossOpacityCfg:
    """Configuration for the opacity supervision loss."""

    # Scalar multiplier applied to the final loss value.
    weight: float
    # Variant selector for the opacity regularizer. NOTE(review): the code
    # paths that read this are commented out in LossOpacity.forward, so the
    # value is currently ignored — confirm before relying on it.
    type: Literal["exp", "mean", "exp+mean"] = "exp+mean"
@dataclass
class LossOpacityCfgWrapper:
    # Single-field wrapper keying the config under the loss's name; the Loss
    # base class presumably dispatches on this field name — TODO confirm
    # against src/model/... loss registration.
    opacity: LossOpacityCfg
class LossOpacity(Loss[LossOpacityCfg, LossOpacityCfgWrapper]):
    """Supervises the rendered alpha map against the context validity mask.

    NOTE(review): `cfg.type` ("exp" / "mean" / "exp+mean") is currently
    unused — the opacity-regularization variants it selected existed only as
    commented-out dead code (removed here). Either wire the selector back in
    or drop the field from LossOpacityCfg.
    """

    def forward(
        self,
        prediction: DecoderOutput,
        batch: BatchedExample,
        gaussians: Gaussians,
        depth_dict: dict | None,
        global_step: int,
    ) -> Float[Tensor, ""]:
        """Return the weighted MSE between predicted alpha and the valid mask.

        Args:
            prediction: Decoder output; only its `alpha` map is used.
            batch: Source of `batch["context"]["valid_mask"]` — assumed
                broadcast-compatible with `prediction.alpha`; TODO confirm.
            gaussians: Unused; kept for the shared Loss.forward interface.
            depth_dict: Unused; kept for the shared Loss.forward interface.
            global_step: Unused; kept for the shared Loss.forward interface.

        Returns:
            Scalar tensor `cfg.weight * MSE(alpha, valid_mask)`, with any
            NaN/Inf replaced by 0 so one bad batch cannot poison training.
        """
        alpha = prediction.alpha
        valid_mask = batch['context']['valid_mask'].float()
        # Default reduction='mean' equals the original
        # reduction='none' followed by .mean().
        opacity_loss = F.mse_loss(alpha, valid_mask)
        # Sanitize non-finite values (e.g. from an empty/degenerate mask)
        # before they propagate into the total training loss.
        return self.cfg.weight * torch.nan_to_num(
            opacity_loss, nan=0.0, posinf=0.0, neginf=0.0
        )