from typing import Optional

import torch


def reconstruction_loss(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Euclidean reconstruction loss over trajectory features.

    Dispatches on the feature dimension: position only ([x, y]), position and
    heading ([x, y, a]), or position, heading, and velocity ([x, y, a, vx, vy]).

    Args:
        pred (Tensor): (..., time, [x, y, (a), (vx, vy)]) predicted trajectories.
        truth (Tensor): (..., time, [x, y, (a), (vx, vy)]) reference trajectories.
        mask_loss (Tensor): (..., time) optional mask, 1 where the loss is
            computed and 0 elsewhere. Defaults to None.
    """
    # Number of features shared by the prediction and the ground truth.
    min_feat_shape = min(pred.shape[-1], truth.shape[-1])
    if min_feat_shape == 3:
        # [x, y, a]: position loss plus a heading loss computed on the unit
        # circle (cos, sin), which avoids angle wrap-around discontinuities.
        assert pred.shape[-1] == truth.shape[-1]
return reconstruction_loss(
pred[..., :2], truth[..., :2], mask_loss
) + reconstruction_loss(
torch.stack([torch.cos(pred[..., 2]), torch.sin(pred[..., 2])], -1),
torch.stack([torch.cos(truth[..., 2]), torch.sin(truth[..., 2])], -1),
mask_loss,
)
    elif min_feat_shape >= 5:
        assert pred.shape[-1] <= truth.shape[-1]
        # Mask out the heading loss for near-static agents (ground-truth
        # squared speed <= 1), where the heading is ill-defined.
        v_norm = torch.sum(torch.square(truth[..., 3:5]), -1, keepdim=True)
        v_mask = v_norm > 1
return (
reconstruction_loss(pred[..., :2], truth[..., :2], mask_loss)
+ reconstruction_loss(
torch.stack([torch.cos(pred[..., 2]), torch.sin(pred[..., 2])], -1)
* v_mask,
torch.stack([torch.cos(truth[..., 2]), torch.sin(truth[..., 2])], -1)
* v_mask,
mask_loss,
)
+ reconstruction_loss(pred[..., 3:5], truth[..., 3:5], mask_loss)
)
    elif min_feat_shape == 2:
        if mask_loss is None:
            # Mean Euclidean distance; clamping before the sqrt keeps the
            # gradient finite when prediction and truth coincide.
            return torch.mean(
torch.sqrt(
torch.sum(
torch.square(pred[..., :2] - truth[..., :2]), -1
).clamp_min(1e-6)
)
)
        else:
            assert mask_loss.any()
            mask_loss = mask_loss.float()
            # Masked mean: sum per-step distances where the mask is set and
            # normalize by the number of valid steps.
return torch.sum(
torch.sqrt(
torch.sum(
torch.square(pred[..., :2] - truth[..., :2]), -1
).clamp_min(1e-6)
)
* mask_loss
            ) / torch.sum(mask_loss).clamp_min(1)
    else:
        raise ValueError(
            f"Unsupported feature dimension {min_feat_shape}: expected 2, 3, or >= 5."
        )
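
# Illustrative usage only (not part of the original module): a minimal smoke
# test with made-up shapes following the (..., time, feature) convention.
if __name__ == "__main__":
    _pred = torch.randn(4, 10, 5)  # (batch, time, [x, y, a, vx, vy])
    _truth = torch.randn(4, 10, 5)
    _mask = torch.ones(4, 10, dtype=torch.bool)  # keep every timestep
    print("reconstruction:", reconstruction_loss(_pred, _truth, _mask).item())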


def map_penalized_reconstruction_loss(
    pred: torch.Tensor,
    truth: torch.Tensor,
    map: torch.Tensor,
    mask_map: torch.Tensor,
    mask_loss: Optional[torch.Tensor] = None,
    map_importance: float = 0.1,
) -> torch.Tensor:
    """Reconstruction loss plus a penalty on the distance from each agent's
    final predicted position to the closest map point.

    Args:
        pred (Tensor): (batch_size, num_agents, time, [x, y, (a), (vx, vy)])
        truth (Tensor): (batch_size, num_agents, time, [x, y, (a), (vx, vy)])
        map (Tensor): (batch_size, num_objects, object_sequence_length, [x, y, ...])
        mask_map (Tensor): mask of valid map points (accepted but not used here).
        mask_loss (Tensor): (..., time) optional loss mask. Defaults to None.
        map_importance (float): weight of the map penalty term.
    """
    # Squared distances between map points (b, 1, o, s, 2) and each agent's
    # final predicted position (b, a, 1, 1, 2), then min over objects o:
    # (b, a, o, s) -> (b, a, s).
    map_distance, _ = (
        (map[:, None, :, :, :2] - pred[:, :, None, -1, None, :2])
        .square()
        .sum(-1)
        .min(2)
    )
    # Saturate the penalty: no gradient below 0.5 or above 3.
    map_distance = map_distance.sqrt().clamp(0.5, 3)
    if mask_loss is not None:
        # Weight by the mask at the last timestep, the only step used above.
        map_loss = (map_distance * mask_loss[..., -1:]).sum() / mask_loss[
            ..., -1
        ].sum().clamp_min(1)
    else:
        map_loss = map_distance.mean()
rec_loss = reconstruction_loss(pred, truth, mask_loss)
return rec_loss + map_importance * map_loss
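
# Illustrative usage only (shapes are assumptions): batch=2, agents=3,
# time=10, map objects=6, points per object=20.
if __name__ == "__main__":
    _pred = torch.randn(2, 3, 10, 5)
    _truth = torch.randn(2, 3, 10, 5)
    _map = torch.randn(2, 6, 20, 2)
    _mask_map = torch.ones(2, 6, 20, dtype=torch.bool)
    _mask_loss = torch.ones(2, 3, 10)
    _loss = map_penalized_reconstruction_loss(
        _pred, _truth, _map, _mask_map, _mask_loss
    )
    print("map-penalized:", _loss.item())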


def cce_loss_with_logits(pred_logits: torch.Tensor, truth: torch.Tensor) -> torch.Tensor:
    """Categorical cross-entropy between logits and soft target distributions."""
    pred_log = pred_logits.log_softmax(-1)
    return -(pred_log * truth).sum(-1).mean()
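
# Illustrative usage only: soft labels are assumed to sum to 1 on the last axis.
if __name__ == "__main__":
    _logits = torch.randn(8, 4)  # unnormalized scores for 4 classes
    _soft_labels = torch.softmax(torch.randn(8, 4), -1)
    print("cce:", cce_loss_with_logits(_logits, _soft_labels).item())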


def risk_loss_function(
    pred: torch.Tensor,
    truth: torch.Tensor,
    mask: torch.Tensor,
    factor: float = 100.0,
) -> torch.Tensor:
    """
    Loss function for the risk comparison. It is asymmetric because it is
    preferred that the model over-estimates the risk rather than
    under-estimates it.
    Args:
        pred: (same_shape) The predicted risks.
        truth: (same_shape) The reference risks to match.
        mask: (same_shape) A mask with 1 where the loss should be computed and 0 elsewhere.
        factor: Scale applied to the error before the asymmetric transform;
            the larger it is, the smaller the over-estimation at which the
            penalty switches from linear to logarithmic.
    Returns:
        Scalar loss value
    """
    error = (pred - truth) * factor
    # Over-estimation beyond 1 / factor is penalized logarithmically
    # (sub-linear); anything else is penalized linearly. Clamping before the
    # log keeps the unselected branch from producing NaNs in the backward
    # pass, since torch.where can propagate NaN gradients from both branches.
    error = torch.where(error > 1, error.clamp_min(1e-6).log(), error.abs())
    return (error * mask).sum() / mask.sum()
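
# Illustrative usage only (values are made up): over- and under-estimation of
# the same magnitude yield different penalties.
if __name__ == "__main__":
    _truth = torch.full((4,), 0.5)
    _over = _truth + 0.1   # over-estimated risks
    _under = _truth - 0.1  # under-estimated risks
    _mask = torch.ones(4)
    print("over :", risk_loss_function(_over, _truth, _mask).item())
    print("under:", risk_loss_function(_under, _truth, _mask).item())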