Spaces:
Sleeping
Sleeping
File size: 1,517 Bytes
5769ee4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import os
import pytest
import torch
from mmcv import Config
from risk_biased.models.latent_distributions import GaussianLatentDistribution
@pytest.fixture(scope="module")
def params():
    """Load the learning config and shrink it for fast unit tests.

    Seeds torch's global RNG with 0 (once per module, since the fixture is
    module-scoped), loads ``risk_biased/config/learning_config.py`` located
    three directories above this test file, then overrides ``batch_size``
    to 4 and ``latent_dim`` to 2.

    Returns:
        The loaded mmcv ``Config`` with the overrides applied.
    """
    torch.manual_seed(0)
    this_dir = os.path.dirname(os.path.realpath(__file__))
    repo_root = os.path.join(this_dir, "..", "..", "..")
    cfg = Config.fromfile(
        os.path.join(repo_root, "risk_biased", "config", "learning_config.py")
    )
    # Small sizes keep the tensors in the tests tiny.
    cfg.batch_size = 4
    cfg.latent_dim = 2
    return cfg
@pytest.mark.parametrize("threshold", [1e-5, 10.0])
def test_get_kl_loss(params, threshold: float):
    """Check ``GaussianLatentDistribution.kl_loss`` against a closed-form target.

    The reference value below clamps each KL term from below by ``threshold``
    before averaging, mirroring the ``threshold`` argument passed to
    ``kl_loss``.

    Args:
        params: Module-scoped config fixture providing ``batch_size`` and
            ``latent_dim``.
        threshold: Lower bound applied to each KL term.
    """
    z_mean_log_std = torch.rand(params.batch_size, 1, params.latent_dim * 2)
    distribution = GaussianLatentDistribution(z_mean_log_std)

    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per element, clamped from below
    # by ``threshold`` and then averaged. The second half of the last axis is
    # treated here as a log-variance (halved to obtain the log-std).
    # NOTE(review): confirm this matches how GaussianLatentDistribution
    # interprets its input tensor.
    z_mean, z_log_var = torch.split(z_mean_log_std, params.latent_dim, dim=-1)
    z_log_std = z_log_var / 2.0
    kl_target = (
        (-0.5 * (1.0 + 2.0 * z_log_std - z_mean.square() - (2 * z_log_std).exp()))
        .clamp_min(threshold)
        .mean()
    )

    # All-zero parameters: zero mean and zero log-scale, presumably a unit Gaussian.
    prior_z_mean_log_std = torch.zeros(params.latent_dim * 2)
    prior_distribution = GaussianLatentDistribution(prior_z_mean_log_std)

    # Identical distributions have zero KL, but the loss is clamped up to
    # ``threshold``, so it is only required to lie within ``threshold`` of 0.
    # ``.all()`` keeps the assertion well-defined whatever the output shape.
    kl_self = distribution.kl_loss(distribution, threshold=threshold)
    assert torch.isclose(kl_self, torch.zeros(1), atol=threshold).all()

    # Against the unit-Gaussian prior the loss must match the closed form.
    kl_prior = distribution.kl_loss(prior_distribution, threshold=threshold)
    assert torch.isclose(kl_prior, kl_target).all()
|