File size: 3,220 Bytes
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from cmath import isnan
import pytest

import torch
from mmcv import Config

from risk_biased.models.nn_blocks import (
    SequenceDecoderLSTM,
    SequenceDecoderMLP,
    SequenceEncoderMaskedLSTM,
    SequenceEncoderMLP,
    AttentionBlock,
)


@pytest.fixture(scope="module")
def params():
    """Module-scoped test configuration shared by every nn_blocks test.

    Seeds torch's RNG once so the random tensors drawn in the tests are
    reproducible across runs.
    """
    torch.manual_seed(0)
    return Config(
        dict(
            batch_size=4,
            input_dim=10,
            output_dim=15,
            latent_dim=3,
            h_dim=32,
            num_attention_heads=4,
            num_h_layers=2,
            device="cpu",
        )
    )


def test_AttentionBlock(params):
    """AttentionBlock must map per-agent features to the same (batch, agents, h_dim)
    shape and never produce NaNs, even with partially masked agents and map objects."""
    attention = AttentionBlock(params.h_dim, params.num_attention_heads)
    n_agents, n_map_objects = 4, 8

    # Random features plus ~90%-true masks for agents and map objects.
    agent_features = torch.rand(params.batch_size, n_agents, params.h_dim)
    agent_mask = torch.rand(params.batch_size, n_agents) > 0.1
    absolute_features = torch.rand(params.batch_size, n_agents, params.h_dim)
    map_features = torch.rand(params.batch_size, n_map_objects, params.h_dim)
    map_mask = torch.rand(params.batch_size, n_map_objects) > 0.1

    result = attention(
        agent_features, agent_mask, absolute_features, map_features, map_mask
    )

    # One h_dim-sized embedding per agent, no NaNs anywhere.
    assert result.shape == (params.batch_size, n_agents, params.h_dim)
    assert not result.isnan().any()


def test_SequenceDecoder(params):
    """SequenceDecoderLSTM must expand one embedding per agent into a sequence:
    (batch, agents, h_dim) -> (batch, agents, sequence_length, h_dim), NaN-free.
    """
    decoder = SequenceDecoderLSTM(params.h_dim)
    num_agents = 8
    sequence_length = 16

    # Renamed from `input` to avoid shadowing the builtin.
    latent_input = torch.rand(params.batch_size, num_agents, params.h_dim)

    output = decoder(latent_input, sequence_length)

    assert output.shape == (
        params.batch_size,
        num_agents,
        sequence_length,
        params.h_dim,
    )
    assert not torch.isnan(output).any()


def test_SequenceDecoderMLP(params):
    """SequenceDecoderMLP must expand one embedding per agent into a sequence:
    (batch, agents, h_dim) -> (batch, agents, sequence_length, h_dim), NaN-free.
    """
    sequence_length = 16
    decoder = SequenceDecoderMLP(
        params.h_dim, params.num_h_layers, sequence_length, True
    )
    num_agents = 8

    # Renamed from `input` to avoid shadowing the builtin.
    latent_input = torch.rand(params.batch_size, num_agents, params.h_dim)

    output = decoder(latent_input, sequence_length)

    assert output.shape == (
        params.batch_size,
        num_agents,
        sequence_length,
        params.h_dim,
    )
    assert not torch.isnan(output).any()


def test_SequenceEncoder(params):
    """SequenceEncoderMaskedLSTM must compress each agent's (masked) sequence:
    (batch, agents, seq, input_dim) -> (batch, agents, h_dim), NaN-free.
    """
    encoder = SequenceEncoderMaskedLSTM(params.input_dim, params.h_dim)
    num_agents = 8
    sequence_length = 16

    # Renamed from `input` to avoid shadowing the builtin.
    sequence_input = torch.rand(
        params.batch_size, num_agents, sequence_length, params.input_dim
    )
    # ~90%-true mask over individual time steps.
    mask_input = torch.rand(params.batch_size, num_agents, sequence_length) > 0.1

    output = encoder(sequence_input, mask_input)

    assert output.shape == (params.batch_size, num_agents, params.h_dim)
    assert not torch.isnan(output).any()


def test_SequenceEncoderMLP(params):
    """SequenceEncoderMLP must compress each agent's (masked) sequence:
    (batch, agents, seq, input_dim) -> (batch, agents, h_dim), NaN-free.
    """
    sequence_length = 16
    num_agents = 8
    encoder = SequenceEncoderMLP(
        params.input_dim, params.h_dim, params.num_h_layers, sequence_length, True
    )

    # Renamed from `input` to avoid shadowing the builtin.
    sequence_input = torch.rand(
        params.batch_size, num_agents, sequence_length, params.input_dim
    )
    # ~90%-true mask over individual time steps.
    mask_input = torch.rand(params.batch_size, num_agents, sequence_length) > 0.1

    output = encoder(sequence_input, mask_input)

    assert output.shape == (params.batch_size, num_agents, params.h_dim)
    assert not torch.isnan(output).any()