File size: 1,517 Bytes
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import pytest

import torch
from mmcv import Config

from risk_biased.models.latent_distributions import GaussianLatentDistribution


@pytest.fixture(scope="module")
def params():
    torch.manual_seed(0)
    working_dir = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(
        working_dir, "..", "..", "..", "risk_biased", "config", "learning_config.py"
    )
    cfg = Config.fromfile(config_path)
    cfg.batch_size = 4
    cfg.latent_dim = 2
    return cfg


@pytest.mark.parametrize("threshold", [(1e-5), (10.0)])
def test_get_kl_loss(params, threshold: float):
    z_mean_log_std = torch.rand(params.batch_size, 1, params.latent_dim*2)
    
    distribution = GaussianLatentDistribution(z_mean_log_std)
    
    z_mean, z_log_var = torch.split(z_mean_log_std, params.latent_dim, dim=-1)
    z_log_std = z_log_var / 2.0

    kl_target = (
            (-0.5 * (1.0 + 2.0 * z_log_std - z_mean.square() - (2 * z_log_std).exp()))
            .clamp_min(threshold)
        ).mean()
        

    prior_z_mean_log_std = torch.zeros(params.latent_dim*2)
    prior_distribution = GaussianLatentDistribution(prior_z_mean_log_std)

    # Test kl loss is 0 on identical distributions
    assert torch.isclose(
        distribution.kl_loss(distribution, threshold=threshold),
        torch.zeros(1), atol=threshold
    )
    
    # test kl loss when prior is unit Gaussian
    assert torch.isclose(
        distribution.kl_loss(prior_distribution, threshold),
        kl_target,
    )