File size: 1,175 Bytes
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn
from risk_biased.models.nn_blocks import (
    SequenceEncoderLSTM,
    SequenceEncoderMLP,
    SequenceEncoderMaskedLSTM,
)

from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.mlp import MLP


class MapEncoderNN(nn.Module):
    """Encoder module for map objects, backed by a ``SequenceEncoderMLP``.

    This class is a thin wrapper: it builds a single ``SequenceEncoderMLP``
    from the relevant fields of ``params`` and forwards every call to it.

    Args:
        params: dataclass defining the necessary parameters. The fields read
            here are ``map_state_dim``, ``hidden_dim``, ``num_hidden_layers``,
            ``max_size_lane`` and ``is_mlp_residual``.
    """

    def __init__(self, params: CVAEParams) -> None:
        super().__init__()
        # Positional arguments for SequenceEncoderMLP, collected in the exact
        # order its constructor is called with.
        encoder_args = (
            params.map_state_dim,
            params.hidden_dim,
            params.num_hidden_layers,
            params.max_size_lane,
            params.is_mlp_residual,
        )
        self._encoder = SequenceEncoderMLP(*encoder_args)

    def forward(self, map, mask_map):
        """Encode map object sequences of features into object features.

        Args:
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
            mask_map: (batch_size, num_objects, object_sequence_length) tensor of bool mask

        Returns:
            Whatever the wrapped ``SequenceEncoderMLP`` returns for these
            inputs (presumably one feature vector per map object; the exact
            shape is defined by ``SequenceEncoderMLP`` -- TODO confirm).
        """
        # Pure delegation: no pre- or post-processing is applied here.
        return self._encoder(map, mask_map)