Spaces:
Sleeping
Sleeping
import pytest | |
import torch | |
from mmcv import Config | |
from risk_biased.models.mlp import MLP | |
def params(): | |
torch.manual_seed(0) | |
cfg = Config() | |
cfg.batch_size = 4 | |
cfg.input_dim = 10 | |
cfg.output_dim = 15 | |
cfg.latent_dim = 3 | |
cfg.h_dim = 64 | |
cfg.num_h_layers = 2 | |
cfg.device = "cpu" | |
cfg.is_mlp_residual = True | |
return cfg | |
def test_mlp(params): | |
mlp = MLP( | |
params.input_dim, | |
params.output_dim, | |
params.h_dim, | |
params.num_h_layers, | |
params.is_mlp_residual, | |
) | |
input = torch.rand(params.batch_size, params.input_dim) | |
output = mlp(input) | |
# check shape | |
assert output.shape == (params.batch_size, params.output_dim) | |