from typing import Optional

import torch
from torch import Tensor
from torch.distributions import MultivariateNormal


def _heading_to_unit_vector(angle: torch.Tensor) -> torch.Tensor:
    """Map angles (...) to unit vectors (..., [cos, sin])."""
    return torch.stack([torch.cos(angle), torch.sin(angle)], -1)


def _mean_euclidean_error(
    pred: torch.Tensor,
    truth: torch.Tensor,
    mask_loss: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Average Euclidean distance between the first two features of each input."""
    squared_distance = torch.sum(torch.square(pred[..., :2] - truth[..., :2]), -1)
    # Clamp before the square root: the gradient of sqrt is infinite at zero.
    distance = torch.sqrt(squared_distance.clamp_min(1e-6))
    if mask_loss is None:
        return torch.mean(distance)
    assert mask_loss.any(), "mask_loss must select at least one element"
    weights = mask_loss.float()
    return torch.sum(distance * weights) / torch.sum(weights).clamp_min(1)


def reconstruction_loss(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Mean Euclidean reconstruction error between predicted and reference states.

    The terms included depend on the smallest feature size of the two inputs:
      * 2: position error only.
      * 3: position error + heading error, the heading being compared through
        its (cos, sin) embedding so that the error is invariant to angle
        wrapping.
      * 5 or more: position error + heading error + velocity error. The
        heading term is zeroed wherever the squared norm of the reference
        velocity is not greater than 1.

    Args:
        pred (Tensor): (..., time, [x,y,(a),(vx,vy)])
        truth (Tensor): (..., time, [x,y,(a),(vx,vy)])
        mask_loss (Tensor): (..., time) Elements where the loss is computed.
            Defaults to None, in which case all elements are averaged.

    Returns:
        Scalar loss value.

    Raises:
        ValueError: If the smallest feature size is not 2, 3, or at least 5.
        AssertionError: If the feature sizes are inconsistent, or if
            mask_loss is given but selects no element.
    """
    num_features = min(pred.shape[-1], truth.shape[-1])

    if num_features == 2:
        return _mean_euclidean_error(pred, truth, mask_loss)

    if num_features == 3:
        assert pred.shape[-1] == truth.shape[-1]
        position_error = _mean_euclidean_error(pred, truth, mask_loss)
        heading_error = _mean_euclidean_error(
            _heading_to_unit_vector(pred[..., 2]),
            _heading_to_unit_vector(truth[..., 2]),
            mask_loss,
        )
        return position_error + heading_error

    if num_features >= 5:
        assert pred.shape[-1] <= truth.shape[-1]
        squared_speed = torch.sum(torch.square(truth[..., 3:5]), -1, keepdim=True)
        # Multiplying both embeddings by the mask makes them equal (zero)
        # wherever the reference velocity is too small.
        is_moving = squared_speed > 1
        position_error = _mean_euclidean_error(pred, truth, mask_loss)
        heading_error = _mean_euclidean_error(
            _heading_to_unit_vector(pred[..., 2]) * is_moving,
            _heading_to_unit_vector(truth[..., 2]) * is_moving,
            mask_loss,
        )
        velocity_error = _mean_euclidean_error(
            pred[..., 3:5], truth[..., 3:5], mask_loss
        )
        return position_error + heading_error + velocity_error

    # Previously this case fell through and silently returned None.
    raise ValueError(
        "Unsupported feature size: expected 2, 3, or at least 5 features, "
        f"got pred {pred.shape[-1]} and truth {truth.shape[-1]}."
    )


def map_penalized_reconstruction_loss(
    pred: torch.Tensor,
    truth: torch.Tensor,
    map: torch.Tensor,
    mask_map: torch.Tensor,
    mask_loss: Optional[torch.Tensor] = None,
    map_importance: float = 0.1,
) -> torch.Tensor:
    """
    Reconstruction loss plus a penalty on the distance between the last
    predicted position of each agent and the map points.

    Args:
        pred (Tensor): (batch_size, num_agents, time, [x,y,(a),(vx,vy)])
        truth (Tensor): (batch_size, num_agents, time, [x,y,(a),(vx,vy)])
        map (Tensor): (batch_size, num_objects, object_sequence_length, [x, y, ...])
        mask_map (Tensor): Currently not used by the computation; kept for
            interface compatibility.
        mask_loss (Tensor): (..., time) Defaults to None. When given, its last
            time step selects the agents contributing to the map penalty.
        map_importance (float): Weight of the map penalty in the total loss.

    Returns:
        Scalar loss value.
    """
    # (b, 1, o, s, 2) - (b, a, 1, 1, 2) -> (b, a, o, s, 2)
    offsets = map[:, None, :, :, :2] - pred[:, :, None, -1, None, :2]
    # Minimum over the object dimension -> (b, a, s)
    squared_distance, _ = offsets.square().sum(-1).min(2)
    # Clamping the squared distance to [0.5**2, 3**2] before the square root
    # yields the same value as sqrt().clamp(0.5, 3) but avoids the NaN
    # gradient produced by sqrt at exactly zero.
    map_distance = squared_distance.clamp(0.25, 9).sqrt()

    # The mask applied here is mask_loss, so the branch must be gated on it
    # (it used to test mask_map and crashed when mask_loss was None).
    if mask_loss is not None:
        last_step_mask = mask_loss[..., -1:]
        # clamp_min(1) avoids 0/0 when no agent is valid at the last step,
        # consistently with reconstruction_loss.
        num_valid = mask_loss[..., -1].sum().clamp_min(1)
        map_loss = (map_distance * last_step_mask).sum() / num_valid
    else:
        map_loss = map_distance.mean()

    rec_loss = reconstruction_loss(pred, truth, mask_loss)
    return rec_loss + map_importance * map_loss


def cce_loss_with_logits(pred_logits: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
    """
    Categorical cross-entropy between logits and target probabilities.

    Args:
        pred_logits: (..., num_classes) Unnormalized scores; a log-softmax is
            applied over the last dimension.
        truth: Target weights multiplied with the log-probabilities.

    Returns:
        Scalar loss value, averaged over all leading dimensions.
    """
    log_probabilities = pred_logits.log_softmax(-1)
    return -(log_probabilities * truth).sum(-1).mean()


def risk_loss_function(
    pred: torch.Tensor,
    truth: torch.Tensor,
    mask: torch.Tensor,
    factor: float = 100.0,
) -> torch.Tensor:
    """
    Loss function for the risk comparison. This is asymmetric because it is
    preferred that the model over-estimates the risk rather than
    under-estimates it.

    The error (pred - truth) is multiplied by factor. Scaled errors greater
    than 1 (over-estimation) are penalized with their logarithm, all others
    with their absolute value.

    NOTE(review): the penalty is discontinuous at a scaled error of 1
    (absolute value 1 versus log(1) ~ 0) - confirm this is intended.

    Args:
        pred: (same_shape) The predicted risks
        truth: (same_shape) The reference risks to match
        mask: (same_shape) A mask with 1 where the loss should be computed and
            0 elsewhere.
        factor: Scale applied to the error before the asymmetric penalty. It
            sets the raw error (1 / factor) above which over-estimation is
            penalized logarithmically.

    Returns:
        Scalar loss value
    """
    scaled_error = (pred - truth) * factor
    # torch.where differentiates both branches. Clamping the argument of the
    # logarithm leaves the selected values (scaled_error > 1) untouched while
    # keeping the unselected branch finite, which avoids NaN gradients.
    over_estimation_penalty = (scaled_error.clamp_min(1) + 1e-6).log()
    penalty = torch.where(
        scaled_error > 1, over_estimation_penalty, scaled_error.abs()
    )
    return (penalty * mask).sum() / mask.sum()