# jmercat's picture
# Removed history to avoid any unverified information being released
# 5769ee4
from typing import Optional
import torch
from torch import Tensor
from torch.distributions import MultivariateNormal
def _heading_unit_vector(angle: torch.Tensor) -> torch.Tensor:
    """Encode angles as [cos, sin] pairs stacked on a new last axis."""
    return torch.stack([torch.cos(angle), torch.sin(angle)], -1)


def reconstruction_loss(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Average Euclidean error between predicted and reference trajectories.

    The number of features compared is the smaller last-axis size of `pred` and `truth`:
      * 2: distance between the [x, y] positions.
      * 3: position term plus a distance between the [cos(a), sin(a)] encodings of the angle.
      * 5 or more: position term, angle term, and a distance between the [vx, vy] velocities.
        The angle term is zeroed wherever the squared reference speed is not greater than 1.

    Args:
        pred (Tensor): (..., time, [x,y,(a),(vx,vy)])
        truth (Tensor): (..., time, [x,y,(a),(vx,vy)])
        mask_loss (Tensor): (..., time) Defaults to None. When given, only elements where the
            mask is non-zero contribute and the sum is divided by the mask sum (at least 1).

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if the compared feature count is not 2, 3 or at least 5 (for example 4),
            which previously fell through every branch and returned None.
    """
    num_features = min(pred.shape[-1], truth.shape[-1])

    if num_features == 2:
        # Clamp the squared distance before the square root so the gradient stays
        # finite when the prediction matches the reference exactly.
        distance = torch.sqrt(
            torch.sum(torch.square(pred[..., :2] - truth[..., :2]), -1).clamp_min(1e-6)
        )
        if mask_loss is None:
            return torch.mean(distance)
        assert mask_loss.any()
        weights = mask_loss.float()
        return torch.sum(distance * weights) / torch.sum(weights).clamp_min(1)

    if num_features == 3:
        assert pred.shape[-1] == truth.shape[-1]
        position_term = reconstruction_loss(pred[..., :2], truth[..., :2], mask_loss)
        angle_term = reconstruction_loss(
            _heading_unit_vector(pred[..., 2]),
            _heading_unit_vector(truth[..., 2]),
            mask_loss,
        )
        return position_term + angle_term

    if num_features >= 5:
        assert pred.shape[-1] <= truth.shape[-1]
        # Squared reference speed, kept with a trailing axis so it broadcasts over [cos, sin].
        squared_speed = torch.sum(torch.square(truth[..., 3:5]), -1, keepdim=True)
        is_moving = squared_speed > 1
        position_term = reconstruction_loss(pred[..., :2], truth[..., :2], mask_loss)
        # Both encodings are zeroed where the reference is (nearly) static, so the
        # angle error does not contribute there beyond the clamped minimum.
        angle_term = reconstruction_loss(
            _heading_unit_vector(pred[..., 2]) * is_moving,
            _heading_unit_vector(truth[..., 2]) * is_moving,
            mask_loss,
        )
        velocity_term = reconstruction_loss(pred[..., 3:5], truth[..., 3:5], mask_loss)
        return position_term + angle_term + velocity_term

    raise ValueError(
        "reconstruction_loss supports 2, 3 or at least 5 features on the last axis, "
        f"got pred with {pred.shape[-1]} and truth with {truth.shape[-1]}"
    )
def map_penalized_reconstruction_loss(
    pred: torch.Tensor,
    truth: torch.Tensor,
    map: torch.Tensor,
    mask_map: torch.Tensor,
    mask_loss: Optional[torch.Tensor] = None,
    map_importance: float = 0.1,
):
    """
    Reconstruction loss plus a penalty on the distance from the last predicted position to the map.

    Args:
        pred (Tensor): (batch_size, num_agents, time, [x,y,(a),(vx,vy)])
        truth (Tensor): (batch_size, num_agents, time, [x,y,(a),(vx,vy)])
        map (Tensor): (batch_size, num_objects, object_sequence_length, [x, y, ...])
        mask_map (Tensor): (...) Only compared against None to select the masked average;
            its values are not used to exclude map points.
            NOTE(review): padded map points therefore take part in the minimum - confirm intent.
        mask_loss (Tensor): (..., time) Defaults to None. Its last time step weights the map penalty.
        map_importance (float): Weight of the map penalty relative to the reconstruction loss.

    Returns:
        reconstruction_loss(pred, truth, mask_loss) + map_importance * map_loss
    """
    # Map points (b, 1, o, s, 2) are broadcast against the last predicted
    # position of every agent (b, a, 1, 1, 2), giving squared distances (b, a, o, s).
    last_position = pred[:, :, None, -1, None, :2]
    squared_distance = (map[:, None, :, :, :2] - last_position).square().sum(-1)
    # Minimum over the object axis -> (b, a, s).
    min_squared_distance, _ = squared_distance.min(2)
    # Clamping the squared distance to [0.5**2, 3**2] before the square root gives the
    # same value as sqrt().clamp(0.5, 3) but avoids the 0 * inf = NaN gradient that
    # sqrt produces when a predicted position coincides exactly with a map point.
    map_distance = min_squared_distance.clamp(0.25, 9.0).sqrt()
    # The masked average reads mask_loss, so it can only be used when mask_loss is
    # provided; otherwise fall back to the plain mean instead of failing on None.
    if mask_map is not None and mask_loss is not None:
        last_step_mask = mask_loss[..., -1:]
        # clamp_min(1) mirrors reconstruction_loss and avoids 0 / 0 when no agent
        # is valid at the last time step.
        map_loss = (map_distance * last_step_mask).sum() / last_step_mask.sum().clamp_min(1)
    else:
        map_loss = map_distance.mean()
    rec_loss = reconstruction_loss(pred, truth, mask_loss)
    return rec_loss + map_importance * map_loss
def cce_loss_with_logits(pred_logits: torch.Tensor, truth: torch.Tensor):
    """
    Categorical cross-entropy between unnormalized logits and target probabilities.

    Args:
        pred_logits: (..., num_classes) Unnormalized scores; a log-softmax is applied on the last axis.
        truth: Target weights multiplied element-wise with the log-probabilities.

    Returns:
        Scalar tensor: the cross-entropy summed over the last axis and averaged over the rest.
    """
    log_probabilities = torch.log_softmax(pred_logits, dim=-1)
    weighted_log_likelihood = torch.sum(log_probabilities * truth, dim=-1)
    return -torch.mean(weighted_log_likelihood)
def risk_loss_function(
    pred: torch.Tensor,
    truth: torch.Tensor,
    mask: torch.Tensor,
    factor: float = 100.0,
) -> torch.Tensor:
    """
    Loss function for the risk comparison. This is asymmetric because it is preferred that the model
    over-estimates the risk rather than under-estimates it.

    The error `pred - truth` is scaled by `factor`. Scaled errors greater than 1 (over-estimation)
    are penalized by their logarithm; all other errors are penalized by their absolute value.

    Args:
        pred: (same_shape) The predicted risks
        truth: (same_shape) The reference risks to match
        mask: (same_shape) A mask with 1 where the loss should be computed and 0 elsewhere.
        factor: Multiplier applied to the error before the asymmetric penalty. The larger this value,
            the smaller the over-estimation needed to fall in the logarithmic regime.

    Returns:
        Scalar loss value: masked sum of the penalties divided by the mask sum (at least 1).
    """
    scaled_error = (pred - truth) * factor
    # torch.where evaluates both branches on every element. Clamping the log input to
    # at least 1 leaves the selected values (scaled_error > 1) unchanged while keeping
    # the unselected ones away from log(<= 0), whose gradient would otherwise turn the
    # whole backward pass into NaN when scaled_error + 1e-6 is exactly 0.
    log_penalty = (scaled_error.clamp_min(1) + 1e-6).log()
    penalty = torch.where(scaled_error > 1, log_penalty, scaled_error.abs())
    # clamp_min(1) avoids 0 / 0 when the mask selects nothing.
    return (penalty * mask).sum() / mask.sum().clamp_min(1)