File size: 4,368 Bytes
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from typing import Optional

import torch
from torch import Tensor
from torch.distributions import MultivariateNormal


def reconstruction_loss(
    pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
):
    """Masked mean Euclidean reconstruction error between trajectories.

    Recursively splits the feature dimension into position, heading and
    velocity terms depending on how many features are present.

    Args:
        pred (Tensor): (..., time, [x,y,(a),(vx,vy)]) predicted trajectory.
        truth (Tensor): (..., time, [x,y,(a),(vx,vy)]) reference trajectory.
        mask_loss (Tensor): (..., time) optional mask, 1 where the loss is
            computed. Defaults to None (no masking).

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if the common feature size is not 2, 3 or >= 5.
    """
    min_feat_shape = min(pred.shape[-1], truth.shape[-1])
    if min_feat_shape == 3:
        # [x, y, angle]: position term + heading term. The heading is
        # compared through its (cos, sin) embedding so the loss stays
        # continuous across the +/- pi wrap-around.
        assert pred.shape[-1] == truth.shape[-1]
        return reconstruction_loss(
            pred[..., :2], truth[..., :2], mask_loss
        ) + reconstruction_loss(
            torch.stack([torch.cos(pred[..., 2]), torch.sin(pred[..., 2])], -1),
            torch.stack([torch.cos(truth[..., 2]), torch.sin(truth[..., 2])], -1),
            mask_loss,
        )
    elif min_feat_shape >= 5:
        # [x, y, angle, vx, vy]: position + heading + velocity terms.
        assert pred.shape[-1] <= truth.shape[-1]
        # Heading is only penalized where the ground-truth squared speed
        # exceeds 1: heading is ill-defined when nearly standing still.
        # (assumes velocity units make "1" a sensible threshold — TODO confirm)
        v_norm = torch.sum(torch.square(truth[..., 3:5]), -1, keepdim=True)
        v_mask = v_norm > 1
        return (
            reconstruction_loss(pred[..., :2], truth[..., :2], mask_loss)
            + reconstruction_loss(
                torch.stack([torch.cos(pred[..., 2]), torch.sin(pred[..., 2])], -1)
                * v_mask,
                torch.stack([torch.cos(truth[..., 2]), torch.sin(truth[..., 2])], -1)
                * v_mask,
                mask_loss,
            )
            + reconstruction_loss(pred[..., 3:5], truth[..., 3:5], mask_loss)
        )
    elif min_feat_shape == 2:
        # Base case: mean Euclidean distance between (x, y) points.
        # clamp_min(1e-6) keeps sqrt away from 0, whose gradient is infinite.
        if mask_loss is None:
            return torch.mean(
                torch.sqrt(
                    torch.sum(
                        torch.square(pred[..., :2] - truth[..., :2]), -1
                    ).clamp_min(1e-6)
                )
            )
        else:
            assert mask_loss.any()
            mask_loss = mask_loss.float()
            return torch.sum(
                torch.sqrt(
                    torch.sum(
                        torch.square(pred[..., :2] - truth[..., :2]), -1
                    ).clamp_min(1e-6)
                )
                * mask_loss
            ) / torch.sum(mask_loss).clamp_min(1)
    else:
        # Previously fell through and silently returned None for feature
        # sizes 0, 1 and 4; fail loudly instead.
        raise ValueError(
            f"Unsupported feature size {min_feat_shape}; expected 2, 3 or >= 5."
        )


def map_penalized_reconstruction_loss(
    pred: torch.Tensor,
    truth: torch.Tensor,
    map: torch.Tensor,
    mask_map: torch.Tensor,
    mask_loss: Optional[torch.Tensor] = None,
    map_importance: float = 0.1,
):
    """Reconstruction loss plus a penalty keeping final positions near the map.

    Args:
        pred (Tensor): (batch_size, num_agents, time, [x,y,(a),(vx,vy)])
        truth (Tensor): (batch_size, num_agents, time, [x,y,(a),(vx,vy)])
        map (Tensor): (batch_size, num_objects, object_sequence_length, [x, y, ...])
        mask_map (Tensor): (...) NOTE(review): currently unused in the distance
            term — looks like it was meant to mask padded map objects; confirm.
        mask_loss (Tensor): (..., time) optional mask, 1 where the loss is
            computed. Defaults to None.
        map_importance (float): weight of the map penalty relative to the
            reconstruction term.

    Returns:
        Scalar loss tensor.
    """
    # Squared distance from each agent's final predicted position to every
    # map point, reduced over objects (dim 2) to the closest one.
    #        b,  a,   o,  s,  f         b, a,  o,   t,  s,    f
    map_distance, _ = (
        (map[:, None, :, :, :2] - pred[:, :, None, -1, None, :2])
        .square()
        .sum(-1)
        .min(2)
    )
    # Clamp so the penalty saturates: no reward below 0.5, capped at 3.
    map_distance = map_distance.sqrt().clamp(0.5, 3)
    # BUG FIX: this branch previously tested `mask_map` (a required argument,
    # so effectively always true) but indexed `mask_loss`, crashing with a
    # TypeError whenever mask_loss was None.
    if mask_loss is not None:
        map_loss = (map_distance * mask_loss[..., -1:]).sum() / mask_loss[..., -1].sum()
    else:
        map_loss = map_distance.mean()

    rec_loss = reconstruction_loss(pred, truth, mask_loss)

    return rec_loss + map_importance * map_loss


def cce_loss_with_logits(pred_logits: torch.Tensor, truth: torch.Tensor):
    """Categorical cross-entropy between logits and a target distribution.

    Args:
        pred_logits (Tensor): (..., num_classes) unnormalized scores.
        truth (Tensor): (..., num_classes) target probabilities
            (one-hot or soft labels).

    Returns:
        Scalar mean cross-entropy over all leading dimensions.
    """
    log_probs = torch.log_softmax(pred_logits, dim=-1)
    per_sample = torch.sum(truth * log_probs, dim=-1)
    return -per_sample.mean()


def risk_loss_function(
    pred: torch.Tensor,
    truth: torch.Tensor,
    mask: torch.Tensor,
    factor: float = 100.0,
) -> torch.Tensor:
    """Asymmetric loss for the risk comparison.

    Over-estimating the risk is preferred to under-estimating it: scaled
    errors above 1 (over-estimation) are penalized logarithmically, while
    all other errors (including any under-estimation) are penalized
    linearly via the absolute value.

    Args:
        pred: (same_shape) The predicted risks.
        truth: (same_shape) The reference risks to match.
        mask: (same_shape) A mask with 1 where the loss should be computed
            and 0 elsewhere.
        factor: Scale applied to the raw error before the asymmetric
            transform; the larger it is, the smaller the over-estimation
            needed to enter the logarithmic regime.

    Returns:
        Scalar loss value.
    """
    error = (pred - truth) * factor
    # 1e-6 keeps log away from log(1) == 0 producing -inf gradients near
    # the boundary. NOTE(review): the transform is discontinuous at
    # error == 1 (abs gives 1, log gives ~0) — confirm this is intended.
    error = torch.where(error > 1, (error + 1e-6).log(), error.abs())
    # clamp_min(1) guards an all-zero mask (returns 0 instead of NaN),
    # matching the masked branch of reconstruction_loss.
    return (error * mask).sum() / mask.sum().clamp_min(1)