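# NOTE(review): the three lines below look like web-page residue rather than
# source code; kept as comments so the module still parses.
# forSubAnony's picture
# ade20k
# 1cae162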
"""
Helpers for various likelihood-based losses. These are ported from the original
Ho et al. diffusion models codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
"""
import numpy as np
import torch as th
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Elementwise KL divergence KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).

    Arguments broadcast against each other, so a batch can be compared with
    scalars. At least one of the four arguments has to be a Tensor; it is
    used as the dtype/device template for scalar log-variances.

    :return: a Tensor of KL values (in nats) with the broadcast shape.
    """
    # The first Tensor argument decides dtype and device for scalar inputs.
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, th.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    def _as_tensor(value):
        # th.exp() rejects Python scalars, so log-variances must be Tensors.
        if isinstance(value, th.Tensor):
            return value
        return th.tensor(value).to(reference)

    logvar1 = _as_tensor(logvar1)
    logvar2 = _as_tensor(logvar2)

    variance_ratio = th.exp(logvar1 - logvar2)
    scaled_sq_diff = ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_diff)
def approx_standard_normal_cdf(x):
    """
    Cheap tanh-based approximation of the standard normal CDF, evaluated
    elementwise on the Tensor ``x``.
    """
    scale = np.sqrt(2.0 / np.pi)
    cubic_term = x + 0.044715 * th.pow(x, 3)
    return 0.5 * (1.0 + th.tanh(scale * cubic_term))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Log-likelihood of ``x`` under a Gaussian that has been discretized into
    bins of half-width 1/255.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor (same shape as x).
    :param log_scales: the Gaussian log stddev Tensor (same shape as x).
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape

    half_bin = 1.0 / 255.0
    floor = 1e-12  # keeps th.log() away from log(0)

    residual = x - means
    inverse_std = th.exp(-log_scales)

    # CDF evaluated at the upper and lower edge of the bin around x.
    upper_cdf = approx_standard_normal_cdf(inverse_std * (residual + half_bin))
    lower_cdf = approx_standard_normal_cdf(inverse_std * (residual - half_bin))

    log_left_tail = th.log(upper_cdf.clamp(min=floor))
    log_right_tail = th.log((1.0 - lower_cdf).clamp(min=floor))
    log_interior = th.log((upper_cdf - lower_cdf).clamp(min=floor))

    # Values at the extremes of the range absorb the whole open tail;
    # everything else gets the mass of its own bin.
    interior_or_right = th.where(x > 0.999, log_right_tail, log_interior)
    log_probs = th.where(x < -0.999, log_left_tail, interior_or_right)

    assert log_probs.shape == x.shape
    return log_probs
def _gather_per_sample(table, timesteps, ndim):
    """Look up ``table[timesteps]`` as float and append singleton axes so the
    result broadcasts against an ``ndim``-dimensional, batch-first tensor."""
    values = table.to(device=timesteps.device)[timesteps].float()
    return values[(...,) + (None,) * (ndim - 1)]


def _mean_over_non_batch_dims(tensor):
    """Average over every axis except the first (batch) one."""
    return tensor.mean(dim=list(range(1, tensor.ndim)))


def variance_KL_loss(
    latents,
    noisy_latents,
    timesteps,
    model_pred_mean,
    model_pred_var,
    noise_scheduler,
    posterior_mean_coef1,
    posterior_mean_coef2,
    posterior_log_variance_clipped,
):
    """
    Variational-bound term used to train the predicted variance.

    For every sample the term is the decoder NLL when its timestep is 0 and
    KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) otherwise; both are averaged
    over the non-batch axes and expressed in bits. The per-sample terms are
    then averaged into a single scalar.

    :param latents: the clean sample x_0 (batch-first Tensor).
    :param noisy_latents: the noised sample x_t, same shape as latents.
    :param timesteps: integer Tensor of timestep indices used to index the
                      schedule tables; presumably one entry per sample.
    :param model_pred_mean: model output used as the noise prediction when
                            reconstructing x_0; it is detached, so no
                            gradient flows through it from this loss.
    :param model_pred_var: model variance output. Used directly as the log
                           variance when ``noise_scheduler.variance_type`` is
                           "learned"; otherwise mapped via (v + 1) / 2 to an
                           interpolation weight between the clipped posterior
                           log variance and log(beta_t).
    :param noise_scheduler: object exposing ``variance_type``, ``betas`` and
                            ``alphas_cumprod``.
    :param posterior_mean_coef1: 1-D table, coefficient of x_0 in the
                                 posterior mean.
    :param posterior_mean_coef2: 1-D table, coefficient of x_t in the
                                 posterior mean.
    :param posterior_log_variance_clipped: 1-D table of clipped posterior
                                           log variances.
    :return: scalar Tensor, the batch mean of the per-sample terms.
    """
    ndim = latents.ndim

    # Stop-gradient: this loss must only update the variance prediction.
    model_pred_mean = model_pred_mean.detach()

    coef_x0 = _gather_per_sample(posterior_mean_coef1, timesteps, ndim)
    coef_xt = _gather_per_sample(posterior_mean_coef2, timesteps, ndim)

    # Mean and log variance of the true posterior q(x_{t-1} | x_t, x_0).
    true_mean = coef_x0 * latents + coef_xt * noisy_latents
    true_log_variance_clipped = _gather_per_sample(
        posterior_log_variance_clipped, timesteps, ndim
    )

    if noise_scheduler.variance_type == "learned":
        model_log_variance = model_pred_var
    else:
        # Interpolate in log space between the posterior variance (lower end)
        # and beta_t (upper end).
        min_log = true_log_variance_clipped
        max_log = th.log(_gather_per_sample(noise_scheduler.betas, timesteps, ndim))
        frac = (model_pred_var + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log

    # Reconstruct x_0 from x_t and the (detached) noise prediction.
    sqrt_recip_alphas_cumprod = th.sqrt(1.0 / noise_scheduler.alphas_cumprod)
    sqrt_recipm1_alphas_cumprod = th.sqrt(1.0 / noise_scheduler.alphas_cumprod - 1)
    pred_xstart = (
        _gather_per_sample(sqrt_recip_alphas_cumprod, timesteps, ndim) * noisy_latents
        - _gather_per_sample(sqrt_recipm1_alphas_cumprod, timesteps, ndim) * model_pred_mean
    )
    model_mean = coef_x0 * pred_xstart + coef_xt * noisy_latents

    kl = normal_kl(
        true_mean, true_log_variance_clipped, model_mean, model_log_variance
    )
    # Reduce per sample (not over the batch) so the timestep-based selection
    # below picks each sample's own term; nats -> bits via log(2).
    kl = _mean_over_non_batch_dims(kl) / np.log(2.0)

    # NOTE(review): discretized_gaussian_log_likelihood documents its target
    # as uint8 data rescaled to [-1, 1]; confirm this holds for `latents`.
    decoder_nll = -discretized_gaussian_log_likelihood(
        latents, means=model_mean, log_scales=0.5 * model_log_variance
    )
    assert decoder_nll.shape == latents.shape
    decoder_nll = _mean_over_non_batch_dims(decoder_nll) / np.log(2.0)

    # At the first timestep return the decoder NLL,
    # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
    kl_loss = th.where((timesteps == 0), decoder_nll, kl).mean()
    return kl_loss
def get_variance(noise_scheduler):
    """
    Build the per-timestep tables describing the diffusion posterior
    q(x_{t-1} | x_t, x_0) from a scheduler's noise schedule.

    :param noise_scheduler: object exposing 1-D Tensors ``betas``, ``alphas``
                            and ``alphas_cumprod`` of equal length (at least
                            two entries, since index 1 is read).
    :return: a tuple ``(posterior_mean_coef1, posterior_mean_coef2,
             posterior_log_variance_clipped)``: the coefficient of x_0 in the
             posterior mean, the coefficient of x_t in the posterior mean,
             and the log posterior variance with entry 0 replaced by entry 1.
    """
    betas = noise_scheduler.betas
    alphas = noise_scheduler.alphas
    alphas_cumprod = noise_scheduler.alphas_cumprod

    # Create the leading 1.0 on the schedule's device so th.cat() also works
    # when the scheduler tensors are not on the CPU.
    leading_one = th.tensor([1.0], device=alphas_cumprod.device)
    alphas_cumprod_prev = th.cat([leading_one, alphas_cumprod[:-1]])

    one_minus_cumprod = 1.0 - alphas_cumprod
    posterior_mean_coef1 = betas * th.sqrt(alphas_cumprod_prev) / one_minus_cumprod
    posterior_mean_coef2 = (
        (1.0 - alphas_cumprod_prev) * th.sqrt(alphas) / one_minus_cumprod
    )
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / one_minus_cumprod

    # The variance at index 0 is zero (alphas_cumprod_prev[0] == 1), so its
    # log would be -inf; reuse the value at index 1 for that slot instead.
    posterior_log_variance_clipped = th.log(
        th.cat([posterior_variance[1][..., None], posterior_variance[1:]])
    )
    return posterior_mean_coef1, posterior_mean_coef2, posterior_log_variance_clipped