Spaces:
Sleeping
Sleeping
File size: 740 Bytes
5769ee4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import pytest
import torch
from mmcv import Config
from risk_biased.models.mlp import MLP
@pytest.fixture(scope="module")
def params():
    """Build the configuration object shared by the tests in this module.

    The torch RNG is seeded with 0 before the configuration is created so
    that random tensors drawn afterwards are reproducible.

    Returns:
        A ``Config`` whose attributes hold the batch size, the layer
        dimensions, the device name and the residual flag.
    """
    torch.manual_seed(0)

    # Settings are applied in this order, one attribute per entry.
    settings = {
        "batch_size": 4,
        "input_dim": 10,
        "output_dim": 15,
        "latent_dim": 3,
        "h_dim": 64,
        "num_h_layers": 2,
        "device": "cpu",
        "is_mlp_residual": True,
    }

    config = Config()
    for key, value in settings.items():
        setattr(config, key, value)
    return config
def test_mlp(params):
    """Check that the MLP output has shape ``(batch_size, output_dim)``.

    Args:
        params: Configuration fixture providing ``input_dim``,
            ``output_dim``, ``h_dim``, ``num_h_layers``,
            ``is_mlp_residual`` and ``batch_size``.
    """
    mlp = MLP(
        params.input_dim,
        params.output_dim,
        params.h_dim,
        params.num_h_layers,
        params.is_mlp_residual,
    )

    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = torch.rand(params.batch_size, params.input_dim)
    output = mlp(inputs)

    # check shape
    assert output.shape == (params.batch_size, params.output_dim)
|