# Source: DDCM repository, initial commit (b273838)
from abc import ABC, abstractmethod
import numpy as np
import torch
from util.img_utils import dynamic_thresholding
# ====================
# Model Mean Processor
# ====================
# Registry mapping a string name to a MeanProcessor class.
__MODEL_MEAN_PROCESSOR__ = {}

def register_mean_processor(name: str):
    """Class decorator that registers a MeanProcessor under `name`.

    Raises:
        NameError: if `name` is already registered.
    """
    def wrapper(cls):
        # Membership test (not truthiness of .get) so a registered falsy
        # value could not silently allow re-registration.
        if name in __MODEL_MEAN_PROCESSOR__:
            raise NameError(f"Name {name} is already registered.")
        __MODEL_MEAN_PROCESSOR__[name] = cls
        return cls
    return wrapper

def get_mean_processor(name: str, **kwargs):
    """Instantiate the MeanProcessor registered under `name` with `kwargs`.

    Raises:
        NameError: if no processor was registered under `name`.
    """
    if name not in __MODEL_MEAN_PROCESSOR__:
        raise NameError(f"Name {name} is not defined.")
    return __MODEL_MEAN_PROCESSOR__[name](**kwargs)
class MeanProcessor(ABC):
    """Base class: predict x_start from a model output and compute the mean."""

    @abstractmethod
    def __init__(self, betas, dynamic_threshold, clip_denoised):
        # Post-processing flags applied to every predicted x_start.
        self.dynamic_threshold = dynamic_threshold
        self.clip_denoised = clip_denoised

    @abstractmethod
    def get_mean_and_xstart(self, x, t, model_output):
        pass

    def process_xstart(self, x):
        """Optionally dynamic-threshold, then optionally clamp x to [-1, 1]."""
        if self.dynamic_threshold:
            x = dynamic_thresholding(x, s=0.95)
        return x.clamp(-1, 1) if self.clip_denoised else x
@register_mean_processor(name='previous_x')
class PreviousXMeanProcessor(MeanProcessor):
    """Interprets the model output as x_{t-1}; inverts the posterior mean for x_0."""

    def __init__(self, betas, dynamic_threshold, clip_denoised):
        super().__init__(betas, dynamic_threshold, clip_denoised)
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        # Coefficients of the posterior mean: coef1 * x_0 + coef2 * x_t.
        self.posterior_mean_coef1 = (
            np.sqrt(alphas_cumprod_prev) * betas / (1.0 - alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            np.sqrt(alphas) * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

    def predict_xstart(self, x_t, t, x_prev):
        """Solve the posterior-mean relation for x_0 given x_{t-1} (= x_prev)."""
        inv_c1 = extract_and_expand(1.0 / self.posterior_mean_coef1, t, x_t)
        ratio = extract_and_expand(
            self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t
        )
        return inv_c1 * x_prev - ratio * x_t

    def get_mean_and_xstart(self, x, t, model_output):
        # The model output already is the mean (x_{t-1} prediction).
        pred_xstart = self.process_xstart(self.predict_xstart(x, t, model_output))
        return model_output, pred_xstart
@register_mean_processor(name='start_x')
class StartXMeanProcessor(MeanProcessor):
    """Interprets the model output directly as a prediction of x_0."""

    def __init__(self, betas, dynamic_threshold, clip_denoised):
        super().__init__(betas, dynamic_threshold, clip_denoised)
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        # Coefficients of the posterior mean: coef1 * x_0 + coef2 * x_t.
        self.posterior_mean_coef1 = (
            np.sqrt(alphas_cumprod_prev) * betas / (1.0 - alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            np.sqrt(alphas) * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

    def q_posterior_mean(self, x_start, x_t, t):
        """Mean of the diffusion posterior q(x_{t-1} | x_t, x_0)."""
        assert x_start.shape == x_t.shape
        c1 = extract_and_expand(self.posterior_mean_coef1, t, x_start)
        c2 = extract_and_expand(self.posterior_mean_coef2, t, x_t)
        return c1 * x_start + c2 * x_t

    def get_mean_and_xstart(self, x, t, model_output):
        pred_xstart = self.process_xstart(model_output)
        return self.q_posterior_mean(x_start=pred_xstart, x_t=x, t=t), pred_xstart
@register_mean_processor(name='epsilon')
class EpsilonXMeanProcessor(MeanProcessor):
    """Interprets the model output as noise eps; recovers x_0, then the posterior mean."""

    def __init__(self, betas, dynamic_threshold, clip_denoised):
        super().__init__(betas, dynamic_threshold, clip_denoised)
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        # Terms for inverting the forward process: x_0 = c1 * x_t - c2 * eps.
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
        # Coefficients of the posterior mean: coef1 * x_0 + coef2 * x_t.
        self.posterior_mean_coef1 = (
            np.sqrt(alphas_cumprod_prev) * betas / (1.0 - alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            np.sqrt(alphas) * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

    def q_posterior_mean(self, x_start, x_t, t):
        """Mean of the diffusion posterior q(x_{t-1} | x_t, x_0)."""
        assert x_start.shape == x_t.shape
        c1 = extract_and_expand(self.posterior_mean_coef1, t, x_start)
        c2 = extract_and_expand(self.posterior_mean_coef2, t, x_t)
        return c1 * x_start + c2 * x_t

    def predict_xstart(self, x_t, t, eps):
        """Recover x_0 from x_t and the predicted noise eps."""
        c1 = extract_and_expand(self.sqrt_recip_alphas_cumprod, t, x_t)
        c2 = extract_and_expand(self.sqrt_recipm1_alphas_cumprod, t, eps)
        return c1 * x_t - c2 * eps

    def get_mean_and_xstart(self, x, t, model_output):
        pred_xstart = self.process_xstart(self.predict_xstart(x, t, model_output))
        return self.q_posterior_mean(pred_xstart, x, t), pred_xstart
# =========================
# Model Variance Processor
# =========================
# Registry mapping a string name to a VarianceProcessor class.
__MODEL_VAR_PROCESSOR__ = {}

def register_var_processor(name: str):
    """Class decorator that registers a VarianceProcessor under `name`.

    Raises:
        NameError: if `name` is already registered.
    """
    def wrapper(cls):
        # Membership test (not truthiness of .get) so a registered falsy
        # value could not silently allow re-registration.
        if name in __MODEL_VAR_PROCESSOR__:
            raise NameError(f"Name {name} is already registered.")
        __MODEL_VAR_PROCESSOR__[name] = cls
        return cls
    return wrapper

def get_var_processor(name: str, **kwargs):
    """Instantiate the VarianceProcessor registered under `name` with `kwargs`.

    Raises:
        NameError: if no processor was registered under `name`.
    """
    if name not in __MODEL_VAR_PROCESSOR__:
        raise NameError(f"Name {name} is not defined.")
    return __MODEL_VAR_PROCESSOR__[name](**kwargs)
class VarianceProcessor(ABC):
    """Interface for computing the reverse-process variance at timestep t."""
    @abstractmethod
    def __init__(self, betas):
        pass
    @abstractmethod
    def get_variance(self, x, t):
        # Implementations return (variance, log_variance) broadcast like x.
        pass
@register_var_processor(name='fixed_small')
class FixedSmallVarianceProcessor(VarianceProcessor):
    """Uses the true posterior variance beta_t * (1 - abar_{t-1}) / (1 - abar_t)."""

    def __init__(self, betas):
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        # Variance of the posterior q(x_{t-1} | x_t, x_0).
        self.posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

    def get_variance(self, x, t):
        # NOTE(review): posterior_variance[0] is 0, so its log is -inf at t=0;
        # other implementations clip this entry — confirm t=0 is never requested.
        log_schedule = np.log(self.posterior_variance)
        variance = extract_and_expand(self.posterior_variance, t, x)
        log_variance = extract_and_expand(log_schedule, t, x)
        return variance, log_variance
@register_var_processor(name='fixed_large')
class FixedLargeVarianceProcessor(VarianceProcessor):
    """Uses beta_t as the variance (posterior variance substituted at index 0)."""

    def __init__(self, betas):
        self.betas = betas
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        # Variance of the posterior q(x_{t-1} | x_t, x_0); only entry 1 is used below.
        self.posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

    def get_variance(self, x, t):
        # Substitute posterior_variance[1] at index 0 so the log stays finite.
        schedule = np.append(self.posterior_variance[1], self.betas[1:])
        log_schedule = np.log(schedule)
        return (
            extract_and_expand(schedule, t, x),
            extract_and_expand(log_schedule, t, x),
        )
@register_var_processor(name='learned')
class LearnedVarianceProcessor(VarianceProcessor):
    """Takes the input `x` directly as the log-variance."""

    def __init__(self, betas):
        # No schedule-derived state needed; the variance comes from the input.
        pass

    def get_variance(self, x, t):
        # `x` is interpreted as log(sigma^2); exponentiate to get the variance.
        return torch.exp(x), x
@register_var_processor(name='learned_range')
class LearnedRangeVarianceProcessor(VarianceProcessor):
    """Input in [-1, 1] interpolates log-variance between beta_tilde_t and beta_t."""

    def __init__(self, betas):
        self.betas = betas
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # posterior_variance[0] is 0 at the start of the chain; reuse index 1
        # before taking the log so the lower bound stays finite.
        self.posterior_log_variance_clipped = np.log(
            np.append(posterior_variance[1], posterior_variance[1:])
        )

    def get_variance(self, x, t):
        lower = extract_and_expand(self.posterior_log_variance_clipped, t, x)
        upper = extract_and_expand(np.log(self.betas), t, x)
        # Map the input from [-1, 1] to an interpolation fraction in [0, 1].
        frac = (x + 1.0) / 2.0
        log_variance = frac * upper + (1 - frac) * lower
        return torch.exp(log_variance), log_variance
# ================
# Helper function
# ================
def extract_and_expand(array, time, target):
    """Pick entry `time` from the 1-D numpy schedule `array` and broadcast it.

    Returns a float tensor on `target`'s device, expanded to `target`'s shape.
    """
    values = torch.from_numpy(array).to(target.device)[time].float()
    # Append singleton dims until the value can be expanded over `target`.
    for _ in range(target.ndim - values.ndim):
        values = values.unsqueeze(-1)
    return values.expand_as(target)
def expand_as(array, target):
    """Broadcast `array` (numpy array, scalar, or tensor) to `target`'s shape/device.

    Fix: the original tested `isinstance(array, np.float)`; `np.float` was
    deprecated in NumPy 1.20 and removed in 1.24, so that branch raised
    AttributeError for any non-ndarray input on modern NumPy. Builtin
    `float`/`int` and `np.floating` cover the original intent (np.float was
    an alias of builtin float) and generalize to integer scalars.
    """
    if isinstance(array, np.ndarray):
        array = torch.from_numpy(array)
    elif isinstance(array, (int, float, np.floating)):
        array = torch.tensor([array])
    # Append singleton dims so the value can be expanded over `target`.
    while array.ndim < target.ndim:
        array = array.unsqueeze(-1)
    return array.expand_as(target).to(target.device)