Spaces:
Running
Running
File size: 1,175 Bytes
5769ee4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import torch
import torch.nn as nn
from risk_biased.models.nn_blocks import (
SequenceEncoderLSTM,
SequenceEncoderMLP,
SequenceEncoderMaskedLSTM,
)
from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.mlp import MLP
class MapEncoderNN(nn.Module):
    """Encode map objects with an MLP-based sequence encoder.

    Thin wrapper that builds a ``SequenceEncoderMLP`` from the supplied
    parameters and delegates the forward pass to it.

    Args:
        params: dataclass defining the necessary parameters. The fields read
            here are ``map_state_dim``, ``hidden_dim``, ``num_hidden_layers``,
            ``max_size_lane`` and ``is_mlp_residual``.
    """

    def __init__(self, params: CVAEParams) -> None:
        super().__init__()
        # Positional order matters: it is forwarded as-is to SequenceEncoderMLP.
        encoder_args = (
            params.map_state_dim,
            params.hidden_dim,
            params.num_hidden_layers,
            params.max_size_lane,
            params.is_mlp_residual,
        )
        self._encoder = SequenceEncoderMLP(*encoder_args)

    def forward(self, map, mask_map):
        """Encode map object feature sequences into per-object features.

        Args:
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim)
                tensor of encoded map objects. The name shadows the builtin
                ``map`` but is kept because callers may pass it by keyword.
            mask_map: (batch_size, num_objects, object_sequence_length) tensor
                of bool mask

        Returns:
            Whatever the wrapped ``SequenceEncoderMLP`` returns for these
            inputs (presumably one feature vector per map object - confirm
            against ``SequenceEncoderMLP``).
        """
        return self._encoder(map, mask_map)
|