"""SAMPLING ONLY."""
import torch
def append_dims(x, target_dims):
    """Appends trailing singleton dimensions to `x` until it has `target_dims`
    dimensions, so it broadcasts against image-shaped tensors."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    return x[(...,) + (None,) * dims_to_append]
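# A minimal broadcasting sketch (illustrative values only): a per-batch sigma of
# shape [B] is expanded to [B, 1, 1, 1] so it can divide a [B, C, H, W] latent.
#
#     sigma = torch.tensor([1.0, 2.0])    # shape [2]
#     x = torch.zeros(2, 4, 64, 64)       # shape [2, 4, 64, 64]
#     append_dims(sigma, x.ndim)          # shape [2, 1, 1, 1]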
def default_noise_sampler(x):
    # Fresh Gaussian noise shaped like `x`; the (sigma, sigma_next) arguments
    # are ignored here but allow custom noise samplers to depend on them.
    return lambda sigma, sigma_next: torch.randn_like(x)
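# A hypothetical swap-in with the same (sigma, sigma_next) -> noise interface,
# e.g. for reproducible ancestral steps (`seeded_noise_sampler` is not part of
# the original module):
#
#     def seeded_noise_sampler(x, seed=0):
#         gen = torch.Generator(device=x.device).manual_seed(seed)
#         return lambda sigma, sigma_next: torch.randn(
#             x.shape, generator=gen, device=x.device, dtype=x.dtype)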
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Splits the move from `sigma_from` down to `sigma_to` into a deterministic
    step to `sigma_down` plus re-injected noise of scale `sigma_up`, such that
    sigma_down ** 2 + sigma_up ** 2 == sigma_to ** 2."""
    if not eta:
        return sigma_to, 0.
    sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
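# A worked example (illustrative numbers): stepping from sigma_from = 10.0 to
# sigma_to = 5.0 with eta = 1.0 gives
#
#     sigma_up   = min(5.0, (25 * (100 - 25) / 100) ** 0.5) ~= 4.33
#     sigma_down = (25 - sigma_up ** 2) ** 0.5               = 2.50
#
# i.e. the sampler steps deterministically down to 2.5, then adds noise with
# standard deviation ~4.33, landing at an overall noise level of exactly 5.0.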
def to_d(x, sigma, denoised):
    """Converts a denoised (x0) prediction into the derivative dx/dsigma, i.e.
    the direction pointing from `x` toward `denoised`."""
    return (x - denoised) / append_dims(sigma, x.ndim)
class Sampler(object):
    def __init__(self, net, type="euler_a", steps=50, output_dim=(512, 512), n_samples=4, scale=7.5):
        super().__init__()
        self.net = net
        self.type = type
        self.steps = steps
        self.output_dim = output_dim
        self.n_samples = n_samples
        self.scale = scale
        # Karras-style noise levels derived from the training schedule:
        # sigma_t = sqrt((1 - alphas_cumprod_t) / alphas_cumprod_t).
        self.sigmas = ((1 - net.alphas_cumprod) / net.alphas_cumprod) ** 0.5
        self.log_sigmas = self.sigmas.log()
    def t_to_sigma(self, t):
        # Continuous timestep -> sigma via linear interpolation in log-sigma
        # space between the two nearest discrete training timesteps.
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()
    def get_sigmas(self, n=None):
        # Builds a descending noise schedule of `n` sigmas (the full training
        # schedule if `n` is None), terminated with an exact zero.
        def append_zero(x):
            return torch.cat([x, x.new_zeros([1])])
        if n is None:
            return append_zero(self.sigmas.flip(0))
        t_max = len(self.sigmas) - 1
        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
        return append_zero(self.t_to_sigma(t))
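    # A shape sketch (illustrative): with 1000 training sigmas, get_sigmas(n=5)
    # spaces t linearly over [999, 0], maps each t through t_to_sigma, and
    # appends the terminal zero, so the returned schedule holds n + 1 = 6
    # values running from sigma_max down to 0.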
    @torch.no_grad()
    def sample(self, x_info, c_info):
        h, w = self.output_dim
        # Latents have 4 channels at 1/8 of the output resolution.
        shape = [self.n_samples, 4, h//8, w//8]
        device, dtype = self.net.get_device(), self.net.get_dtype()
        if ('xt' in x_info) and (x_info['xt'] is not None):
            # Resume from a provided noisy latent.
            xt = x_info['xt'].to(device=device, dtype=dtype)
            x_info['x'] = xt
        elif ('x0' in x_info) and (x_info['x0'] is not None):
            # Forward-diffuse a clean latent to the requested training timestep.
            x0 = x_info['x0'].to(device=device, dtype=dtype)
            ts = torch.full(
                [self.n_samples], x_info['x0_forward_timesteps'],
                dtype=torch.long, device=device)
            x0_nz = self.net.q_sample(x0, ts)
            x_info['x'] = x0_nz
        else:
            # Start from pure Gaussian noise.
            x_info['x'] = torch.randn(shape, device=device, dtype=dtype)
        sigmas = self.get_sigmas(n=self.steps)
        if self.type == 'euler_a':
            return self.sample_euler_ancestral(
                x_info=x_info,
                c_info=c_info,
                sigmas=sigmas)
        raise NotImplementedError(f'sampler type {self.type} is not supported')
    @torch.no_grad()
    def sample_euler_ancestral(
            self, x_info, c_info, sigmas, eta=1., s_noise=1.):
        x = x_info['x']
        # Scale the unit-variance starting noise up to the largest sigma.
        x = x * sigmas[0]
        noise_sampler = default_noise_sampler(x)
        s_in = x.new_ones([x.shape[0]])
        for i in range(len(sigmas) - 1):
            # `apply_model` is expected to return a denoised (x0) prediction
            # for the conditioning in `c_info` at the current noise level.
            denoised = self.net.apply_model(x, sigmas[i] * s_in, c_info)
            sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
            d = to_d(x, sigmas[i], denoised)
            # Deterministic Euler step down to sigma_down ...
            dt = sigma_down - sigmas[i]
            x = x + d * dt
            # ... then re-inject fresh noise, except on the final (sigma = 0) step.
            if sigmas[i + 1] > 0:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
        return x
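# Usage sketch: `DummyNet` below is a hypothetical stand-in, not the real
# latent diffusion wrapper; an actual `net` must expose `alphas_cumprod`,
# `apply_model(x, t, cond)`, `q_sample(x0, t)`, `get_device()` and `get_dtype()`.
if __name__ == '__main__':
    class DummyNet:
        def __init__(self, n_timesteps=1000):
            betas = torch.linspace(1e-4, 2e-2, n_timesteps)
            self.alphas_cumprod = torch.cumprod(1. - betas, dim=0)

        def get_device(self):
            return torch.device('cpu')

        def get_dtype(self):
            return torch.float32

        def apply_model(self, x, t, cond):
            # Pretend the model always predicts a fully denoised x0 of zero.
            return torch.zeros_like(x)

        def q_sample(self, x0, t):
            ac = self.alphas_cumprod[t].view(-1, 1, 1, 1)
            return ac.sqrt() * x0 + (1. - ac).sqrt() * torch.randn_like(x0)

    sampler = Sampler(DummyNet(), type='euler_a', steps=10,
                      output_dim=(64, 64), n_samples=1)
    out = sampler.sample(x_info={}, c_info=None)
    print(out.shape)  # -> torch.Size([1, 4, 8, 8])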