File size: 740 Bytes
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import pytest

import torch
from mmcv import Config

from risk_biased.models.mlp import MLP


@pytest.fixture(scope="module")
def params():
    torch.manual_seed(0)
    cfg = Config()
    cfg.batch_size = 4
    cfg.input_dim = 10
    cfg.output_dim = 15
    cfg.latent_dim = 3
    cfg.h_dim = 64
    cfg.num_h_layers = 2
    cfg.device = "cpu"
    cfg.is_mlp_residual = True
    return cfg


def test_mlp(params):
    mlp = MLP(
        params.input_dim,
        params.output_dim,
        params.h_dim,
        params.num_h_layers,
        params.is_mlp_residual,
    )

    input = torch.rand(params.batch_size, params.input_dim)
    output = mlp(input)
    # check shape
    assert output.shape == (params.batch_size, params.output_dim)