Spaces:
Running
Running
from typing import Optional

import torch
def FDE(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Final Displacement Error: mean Euclidean distance at the last time step.

    Only the first two coordinates of the last axis are compared.

    Args:
        pred: Predicted trajectories, shape (..., time, xy).
        truth: Ground-truth trajectories, shape (..., time, xy); must
            broadcast against ``pred``.
        mask_loss: Optional mask, shape (..., time). Only its last time step
            is used, as a (weighted) validity mask. Defaults to None, meaning
            every trajectory counts.

    Returns:
        Scalar tensor. Without a mask: the plain mean of the final-step
        distances. With a mask: sum(distance * mask) / max(sum(mask), 1),
        i.e. 0 when no trajectory is valid.
    """
    diff = pred[..., -1, :2] - truth[..., -1, :2]
    if not diff.is_floating_point():
        # vector_norm needs a float input; torch.sqrt promoted ints the same way.
        diff = diff.to(torch.get_default_dtype())
    # vector_norm has a zero subgradient at 0, whereas sqrt(sum(x**2)) backprops
    # 0/0 = NaN whenever prediction and truth coincide (e.g. padded entries).
    dist = torch.linalg.vector_norm(diff, dim=-1)
    if mask_loss is None:
        return dist.mean()
    weight = mask_loss[..., -1].float()
    return (dist * weight).sum() / weight.sum().clamp_min(1)
def ADE(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Average Displacement Error: mean Euclidean distance over all time steps.

    Only the first two coordinates of the last axis are compared.

    Args:
        pred: Predicted trajectories, shape (..., time, xy).
        truth: Ground-truth trajectories, shape (..., time, xy); must
            broadcast against ``pred``.
        mask_loss: Optional mask, shape (..., time), used as a (weighted)
            per-step validity mask. Defaults to None, meaning every step
            counts.

    Returns:
        Scalar tensor. Without a mask: the plain mean of the per-step
        distances. With a mask: sum(distance * mask) / max(sum(mask), 1),
        i.e. 0 when no step is valid.
    """
    diff = pred[..., :2] - truth[..., :2]
    if not diff.is_floating_point():
        # vector_norm needs a float input; torch.sqrt promoted ints the same way.
        diff = diff.to(torch.get_default_dtype())
    # vector_norm has a zero subgradient at 0, whereas sqrt(sum(x**2)) backprops
    # 0/0 = NaN whenever prediction and truth coincide (e.g. padded steps).
    dist = torch.linalg.vector_norm(diff, dim=-1)
    if mask_loss is None:
        return dist.mean()
    weight = mask_loss.float()
    return (dist * weight).sum() / weight.sum().clamp_min(1)
def minFDE(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Minimum-over-samples Final Displacement Error.

    For each trajectory, the Euclidean distance (first two coordinates)
    between every predicted sample and the ground truth is taken at the last
    time step. The smallest distance over the samples is kept, and these
    minima are averaged.

    Args:
        pred: Predicted samples, shape (..., n_samples, time, xy).
        truth: Ground truth, shape (..., time, xy). A tensor that already
            carries a sample axis, e.g. (..., 1, time, xy), is also accepted.
        mask_loss: Optional validity mask, shape (..., time), or with a sample
            axis (..., 1 | n_samples, time). Only the last time step is used;
            nonzero entries are valid. Defaults to None, meaning all valid.

    Returns:
        Scalar tensor: mean of the per-trajectory minima. With a mask,
        samples that are invalid at the final step are ignored, and
        trajectories with no valid sample are excluded from the mean. The
        result is 0 if none is valid.
    """
    batch_shape = pred.shape[:-3]
    # Insert the sample axis that the documented shapes omit. Otherwise
    # truth/mask would broadcast against pred's sample axis instead of its
    # batch axes.
    if truth.dim() == pred.dim() - 1 and truth.shape[:-2] == batch_shape:
        truth = truth.unsqueeze(-3)
    diff = pred[..., -1, :2] - truth[..., -1, :2]
    if not diff.is_floating_point():
        # vector_norm needs a float input; torch.sqrt promoted ints the same way.
        diff = diff.to(torch.get_default_dtype())
    # vector_norm has a zero subgradient at 0, whereas sqrt(sum(x**2)) backprops
    # 0/0 = NaN whenever a sample coincides with the truth.
    dist = torch.linalg.vector_norm(diff, dim=-1)  # (..., n_samples)
    if mask_loss is None:
        min_dist, _ = dist.min(dim=-1)
        return min_dist.mean()

    if mask_loss.dim() == pred.dim() - 2 and mask_loss.shape[:-1] == batch_shape:
        mask_loss = mask_loss.unsqueeze(-2)
    valid = mask_loss[..., -1] != 0  # (..., 1 | n_samples)
    # Invalid samples become +inf, so they can never be the minimum.
    masked_dist = torch.where(valid, dist, dist.new_tensor(float("inf")))
    min_dist, _ = masked_dist.min(dim=-1)
    has_valid = valid.any(dim=-1)
    # Rows without any valid sample hold inf. They are zeroed with `where`
    # rather than multiplication, since inf * 0 would be NaN.
    min_dist = torch.where(has_valid, min_dist, min_dist.new_zeros(()))
    return min_dist.sum() / has_valid.sum().clamp_min(1)