"""SAMPLING ONLY.""" | |
import torch | |
import numpy as np | |
from tqdm import tqdm | |
from functools import partial | |
from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like | |


def append_dims(x, target_dims):
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    return x[(...,) + (None,) * dims_to_append]
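
# Illustrative usage (a sketch, not part of the original module): a per-sample
# sigma of shape [B] is right-padded with singleton dims so it broadcasts
# against a latent of shape [B, C, H, W]:
#
#     sigma = torch.tensor([0.5, 0.25])   # shape [2]
#     sigma4 = append_dims(sigma, 4)      # shape [2, 1, 1, 1]
#     assert sigma4.shape == (2, 1, 1, 1)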


def default_noise_sampler(x):
    return lambda sigma, sigma_next: torch.randn_like(x)
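
# The default sampler ignores the (sigma, sigma_next) interval and draws fresh
# Gaussian noise shaped like x. Any callable with the same two-argument
# signature can be swapped in; a hypothetical seeded variant, for illustration
# only (not part of the original code):
#
#     def seeded_noise_sampler(x, seed=0):
#         gen = torch.Generator(device=x.device).manual_seed(seed)
#         return lambda sigma, sigma_next: torch.randn(
#             x.shape, generator=gen, device=x.device, dtype=x.dtype)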


def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    if not eta:
        return sigma_to, 0.
    sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
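
# The split preserves the marginal variance: sigma_down ** 2 + sigma_up ** 2
# equals sigma_to ** 2 whenever sigma_up is not clipped by the min(...). A
# quick numeric check (illustrative, eta=1):
#
#     down, up = get_ancestral_step(10.0, 5.0)
#     assert abs(down ** 2 + up ** 2 - 5.0 ** 2) < 1e-9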


def to_d(x, sigma, denoised):
    return (x - denoised) / append_dims(sigma, x.ndim)
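
# to_d converts the model's denoised prediction into the derivative dx/dsigma
# of the probability-flow ODE (Karras et al. parameterization): the direction
# from x toward `denoised`, scaled by 1/sigma. Sanity check (illustrative;
# x and denoised are any two same-shaped tensors):
#
#     sigma = torch.tensor(1.0)
#     d = to_d(x, sigma, denoised)
#     assert torch.allclose(x + d * -sigma, denoised)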


class Sampler(object):
    def __init__(self, net, type="ddim", steps=50, output_dim=[512, 512], n_samples=4, scale=7.5):
        super().__init__()
        self.net = net
        self.type = type
        self.steps = steps
        self.output_dim = output_dim
        self.n_samples = n_samples
        self.scale = scale
        # Convert the net's cumulative alphas into k-diffusion noise levels:
        # sigma = sqrt((1 - alpha_cumprod) / alpha_cumprod).
        self.sigmas = ((1 - net.alphas_cumprod) / net.alphas_cumprod) ** 0.5
        self.log_sigmas = self.sigmas.log()

    def t_to_sigma(self, t):
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()
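
    # t_to_sigma maps a (possibly fractional) timestep to a noise level by
    # interpolating linearly in log-sigma space, so t=3.5 gives the geometric
    # mean of sigmas[3] and sigmas[4] (with `sampler` a hypothetical instance):
    #
    #     sigma = sampler.t_to_sigma(torch.tensor(3.5))
    #     # sigma == (sampler.sigmas[3] * sampler.sigmas[4]).sqrt()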

    def get_sigmas(self, n=None):
        def append_zero(x):
            return torch.cat([x, x.new_zeros([1])])
        if n is None:
            return append_zero(self.sigmas.flip(0))
        t_max = len(self.sigmas) - 1
        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
        return append_zero(self.t_to_sigma(t))
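
    # get_sigmas returns a descending schedule of n noise levels with a final 0
    # appended, e.g. (illustrative, with `sampler` a hypothetical instance):
    #
    #     sigmas = sampler.get_sigmas(50)
    #     assert len(sigmas) == 51
    #     assert sigmas[-1] == 0 and sigmas[0] == sigmas.max()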

    def sample(self, x_info, c_info):
        h, w = self.output_dim
        shape = [self.n_samples, 4, h // 8, w // 8]
        device, dtype = self.net.get_device(), self.net.get_dtype()
        if ('xt' in x_info) and (x_info['xt'] is not None):
            # Start from a user-provided noisy latent. (The original called
            # .astype on a torch tensor, which does not exist.)
            xt = x_info['xt'].to(dtype).to(device)
            x_info['x'] = xt
        elif ('x0' in x_info) and (x_info['x0'] is not None):
            # Start from a clean latent diffused forward to an intermediate step.
            x0 = x_info['x0'].to(dtype).to(device)
            # The original referenced an undefined `timesteps` and `self.model`
            # here; a full integer schedule over the net's training steps is
            # assumed (`num_timesteps` is an assumed attribute).
            timesteps = np.arange(self.net.num_timesteps)
            ts = timesteps[x_info['x0_forward_timesteps']].repeat(self.n_samples)
            ts = torch.tensor(ts).long().to(device)
            # Truncate the reverse schedule to the forward step (kept from the
            # original structure; unused by the Euler-ancestral path below).
            timesteps = timesteps[:x_info['x0_forward_timesteps']]
            x0_nz = self.net.q_sample(x0, ts)
            x_info['x'] = x0_nz
        else:
            # Start from pure Gaussian noise.
            x_info['x'] = torch.randn(shape, device=device, dtype=dtype)
        sigmas = self.get_sigmas(n=self.steps)
        # 'eular_a' is accepted as an alias for the original misspelled tag.
        if self.type in ('euler_a', 'eular_a'):
            rv = self.sample_euler_ancestral(
                x_info=x_info,
                c_info=c_info,
                sigmas=sigmas)
            return rv
        raise NotImplementedError(f'sampler type {self.type} is not implemented')

    def sample_euler_ancestral(self, x_info, c_info, sigmas, eta=1., s_noise=1.):
        x = x_info['x']
        # Scale unit-variance initial noise up to the first (largest) sigma.
        x = x * sigmas[0]
        noise_sampler = default_noise_sampler(x)
        s_in = x.new_ones([x.shape[0]])
        for i in range(len(sigmas) - 1):
            # The conditioning signature of apply_model is an assumption; the
            # original call dropped c_info entirely.
            denoised = self.net.apply_model(x, sigmas[i] * s_in, c_info)
            sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
            d = to_d(x, sigmas[i], denoised)
            # Euler step down to sigma_down ...
            dt = sigma_down - sigmas[i]
            x = x + d * dt
            # ... then re-inject sigma_up worth of fresh noise (ancestral step).
            if sigmas[i + 1] > 0:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
        return x
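
# End-to-end usage, as a hedged sketch: `net` is assumed to be a diffusion
# wrapper exposing alphas_cumprod, get_device, get_dtype, q_sample, and
# apply_model (none of these are defined in this file):
#
#     sampler = Sampler(net, type='euler_a', steps=50,
#                       output_dim=[512, 512], n_samples=4)
#     x_info = {'xt': None, 'x0': None}          # sample from pure noise
#     c_info = {}                                # model-specific conditioning
#     latents = sampler.sample(x_info, c_info)   # -> [4, 4, 64, 64] latents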