"""
Helpers for various likelihood-based losses. These are ported from the original
Ho et al. diffusion models codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
"""

import numpy as np
import torch as th


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )
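

# Hedged usage sketch (an addition, not part of the original port): a quick
# self-check of normal_kl. The KL of a Gaussian against itself is zero, and
# broadcasting lets scalar parameters stand in for full tensors. All names
# below are illustrative.
def _example_normal_kl():
    mean = th.zeros(4, 3)
    logvar = th.zeros(4, 3)
    # KL(N(0,1) || N(0,1)) = 0 elementwise.
    assert th.allclose(normal_kl(mean, logvar, mean, logvar), th.zeros(4, 3))
    # Scalars broadcast against the batch: KL(N(0,1) || N(1,1)) = 0.5 nats.
    kl = normal_kl(mean, logvar, 1.0, 0.0)
    assert th.allclose(kl, th.full((4, 3), 0.5))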


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
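

# Hedged sanity-check sketch (an addition, not part of the original code):
# the tanh-based formula above has the same form as the tanh GELU
# approximation, and it should track the exact standard-normal CDF, which
# can be written with th.erf.
def _example_check_cdf_approximation():
    x = th.linspace(-4.0, 4.0, 101)
    exact = 0.5 * (1.0 + th.erf(x / np.sqrt(2.0)))
    approx = approx_standard_normal_cdf(x)
    # The approximation error stays small across the usual input range.
    assert th.max(th.abs(exact - approx)) < 1e-2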


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1.0 / 255.0)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
    log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
    cdf_delta = cdf_plus - cdf_min
    log_probs = th.where(
        x < -0.999,
        log_cdf_plus,
        th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
    )
    assert log_probs.shape == x.shape
    return log_probs
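

# Hedged usage sketch (illustrative, not from the original codebase):
# evaluating the discretized log-likelihood on fake uint8 images rescaled to
# [-1, 1], with a unit-variance Gaussian centered on the data itself.
def _example_discretized_log_likelihood():
    imgs = th.randint(0, 256, (2, 3, 8, 8)).float()
    x = imgs / 127.5 - 1.0  # rescale uint8 values into [-1, 1]
    log_probs = discretized_gaussian_log_likelihood(
        x, means=x, log_scales=th.zeros_like(x)
    )
    # Average negative log-likelihood in bits per dimension.
    return -log_probs.mean() / np.log(2.0)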


def variance_KL_loss(
    latents,
    noisy_latents,
    timesteps,
    model_pred_mean,
    model_pred_var,
    noise_scheduler,
    posterior_mean_coef1,
    posterior_mean_coef2,
    posterior_log_variance_clipped,
):
    # Only the variance head is trained by this loss; the mean prediction
    # (the model's epsilon output) is detached so the KL term does not
    # backpropagate into it.
    model_pred_mean = model_pred_mean.detach()

    # True posterior q(x_{t-1} | x_t, x_0): mean and clipped log variance.
    true_mean = (
        posterior_mean_coef1.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * latents
        + posterior_mean_coef2.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * noisy_latents
    )
    true_log_variance_clipped = posterior_log_variance_clipped.to(device=timesteps.device)[timesteps].float()[
        ..., None, None, None
    ]

    if noise_scheduler.variance_type == "learned":
        # The model outputs the log variance directly.
        model_log_variance = model_pred_var
    else:
        # "learned_range": the model outputs a value in [-1, 1] that
        # interpolates between the posterior (min) and beta_t (max) log
        # variances.
        min_log = true_log_variance_clipped
        max_log = th.log(noise_scheduler.betas.to(device=timesteps.device)[timesteps].float()[..., None, None, None])
        frac = (model_pred_var + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log

    # Predict x_0 from the (detached) epsilon prediction, then form the
    # model's posterior mean with the same coefficients as the true posterior.
    sqrt_recip_alphas_cumprod = th.sqrt(1.0 / noise_scheduler.alphas_cumprod)
    sqrt_recipm1_alphas_cumprod = th.sqrt(1.0 / noise_scheduler.alphas_cumprod - 1)
    pred_xstart = (
        sqrt_recip_alphas_cumprod.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * noisy_latents
        - sqrt_recipm1_alphas_cumprod.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * model_pred_mean
    )
    model_mean = (
        posterior_mean_coef1.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * pred_xstart
        + posterior_mean_coef2.to(device=timesteps.device)[timesteps].float()[..., None, None, None] * noisy_latents
    )

    kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
    # Reduce over all non-batch dimensions so each sample keeps its own term,
    # and convert from nats to bits.
    kl = kl.mean(dim=list(range(1, kl.ndim))) / np.log(2.0)

    decoder_nll = -discretized_gaussian_log_likelihood(
        latents, means=model_mean, log_scales=0.5 * model_log_variance
    )
    assert decoder_nll.shape == latents.shape
    decoder_nll = decoder_nll.mean(dim=list(range(1, decoder_nll.ndim))) / np.log(2.0)

    # At the first timestep return the decoder NLL,
    # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)).
    kl_loss = th.where(timesteps == 0, decoder_nll, kl).mean()
    return kl_loss
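

# Hedged end-to-end sketch (illustrative wiring, not the original training
# loop): combining get_variance with variance_KL_loss. It assumes a
# diffusers-style DDPMScheduler exposing `betas`, `alphas`, `alphas_cumprod`,
# and `variance_type`, and a model whose output stacks the epsilon prediction
# and the variance head along the channel dimension; all names and shapes
# here are assumptions.
def _example_vlb_step(noise_scheduler, model_output, latents, noisy_latents, timesteps):
    coef1, coef2, log_var_clipped = get_variance(noise_scheduler)
    # Split the 2*C-channel output into epsilon prediction and variance head.
    model_pred_mean, model_pred_var = th.chunk(model_output, 2, dim=1)
    return variance_KL_loss(
        latents,
        noisy_latents,
        timesteps,
        model_pred_mean,
        model_pred_var,
        noise_scheduler,
        coef1,
        coef2,
        log_var_clipped,
    )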


def get_variance(noise_scheduler):
    """
    Precompute the coefficients of the posterior q(x_{t-1} | x_t, x_0):
    the two mean coefficients and the clipped log variance, indexed by timestep.
    """
    alphas_cumprod_prev = th.cat([th.tensor([1.0]), noise_scheduler.alphas_cumprod[:-1]])
    posterior_mean_coef1 = (
        noise_scheduler.betas * th.sqrt(alphas_cumprod_prev) / (1.0 - noise_scheduler.alphas_cumprod)
    )
    posterior_mean_coef2 = (
        (1.0 - alphas_cumprod_prev)
        * th.sqrt(noise_scheduler.alphas)
        / (1.0 - noise_scheduler.alphas_cumprod)
    )
    posterior_variance = (
        noise_scheduler.betas * (1.0 - alphas_cumprod_prev) / (1.0 - noise_scheduler.alphas_cumprod)
    )
    # The posterior variance is 0 at t = 0, so clip by replacing the first
    # entry with the value at t = 1 before taking the log.
    posterior_log_variance_clipped = th.log(
        th.cat([posterior_variance[1:2], posterior_variance[1:]])
    )
    return posterior_mean_coef1, posterior_mean_coef2, posterior_log_variance_clipped
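

# Hedged smoke-test sketch (all names assumed): a minimal stand-in scheduler
# so get_variance can be exercised without pulling in a diffusion library.
class _ToyScheduler:
    def __init__(self, num_steps=1000):
        self.betas = th.linspace(1e-4, 0.02, num_steps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = th.cumprod(self.alphas, dim=0)
        self.variance_type = "learned_range"


def _example_get_variance():
    coef1, coef2, log_var = get_variance(_ToyScheduler())
    # All three lookup tables share the (num_steps,) shape.
    assert coef1.shape == coef2.shape == log_var.shape
    return coef1, coef2, log_var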