# File size: 3,768 Bytes
# 515f781
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm
from functools import partial

from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like

def append_dims(x, target_dims):
    """Append trailing axes to ``x`` until it has ``target_dims`` dimensions.

    Args:
        x: Object exposing ``ndim`` and supporting ``...``/``None`` indexing
            (e.g. a ``torch.Tensor``).
        target_dims: Total number of dimensions the result should have.

    Returns:
        ``x`` indexed with ``target_dims - x.ndim`` trailing ``None`` entries,
        which adds that many new axes at the end.

    Raises:
        ValueError: If ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        # Ellipsis keeps all existing axes; every None adds one new last axis.
        index = (Ellipsis, *([None] * missing))
        return x[index]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')

def default_noise_sampler(x):
    """Build a noise sampler that draws standard-normal noise shaped like ``x``.

    Args:
        x: Reference tensor handed to ``torch.randn_like`` on every draw.

    Returns:
        A callable ``(sigma, sigma_next) -> Tensor``. Both arguments are
        accepted but ignored; each call returns ``torch.randn_like(x)``.
    """
    def sample_noise(sigma, sigma_next):
        # The two sigma arguments are part of the call signature only.
        return torch.randn_like(x)

    return sample_noise

def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Compute the ``(sigma_down, sigma_up)`` pair for one ancestral step.

    Args:
        sigma_from: Sigma at the start of the step.
        sigma_to: Sigma at the end of the step.
        eta: Multiplier applied to ``sigma_up`` before it is capped; a falsy
            value (``0``, ``0.``, ``None``) short-circuits the computation.

    Returns:
        A ``(sigma_down, sigma_up)`` tuple. ``sigma_up`` is capped at
        ``sigma_to`` and ``sigma_down`` is chosen so that
        ``sigma_down ** 2 + sigma_up ** 2`` equals ``sigma_to ** 2`` up to
        rounding. When ``eta`` is falsy the result is ``(sigma_to, 0.)``.
    """
    if eta:
        sq_from = sigma_from ** 2
        sq_to = sigma_to ** 2
        uncapped_up = eta * (sq_to * (sq_from - sq_to) / sq_from) ** 0.5
        # Never inject more noise than the target sigma allows.
        sigma_up = min(sigma_to, uncapped_up)
        sigma_down = (sq_to - sigma_up ** 2) ** 0.5
        return sigma_down, sigma_up
    return sigma_to, 0.

def to_d(x, sigma, denoised):
    """Return ``(x - denoised) / sigma`` with ``sigma`` expanded to ``x``'s rank.

    Args:
        x: Current sample.
        sigma: Sigma value(s); trailing axes are appended with ``append_dims``
            until it has ``x.ndim`` dimensions so the division broadcasts.
        denoised: Value subtracted from ``x`` (the network output in the
            sampler's loop).

    Returns:
        The element-wise quotient; the sampler multiplies it by the sigma
        increment in its Euler update.
    """
    residual = x - denoised
    sigma_expanded = append_dims(sigma, x.ndim)
    return residual / sigma_expanded

class Sampler(object):
    """Sigma-space sampler that drives ``net`` with an Euler-ancestral loop.

    The sigma schedule is derived from ``net.alphas_cumprod`` as
    ``sqrt((1 - alphas_cumprod) / alphas_cumprod)``.

    Args:
        net: Network wrapper. This class reads ``net.alphas_cumprod`` and calls
            ``net.get_device()``, ``net.get_dtype()`` and
            ``net.apply_model(x, sigma)``.
        type: Sampler name. Only the Euler-ancestral sampler is implemented;
            it is selected with ``'eular_a'`` (historical spelling) or
            ``'euler_a'``. Any other value makes :meth:`sample` raise.
        steps: Number of sigma values requested from :meth:`get_sigmas`.
        output_dim: ``[height, width]``; the initial noise has shape
            ``[n_samples, 4, height // 8, width // 8]``. Defaults to
            ``[512, 512]``.
        n_samples: Batch size of the initial noise.
        scale: Stored on the instance; not read anywhere in this class.
    """

    # 'eular_a' is the spelling existing callers use; 'euler_a' is accepted too.
    _EULER_ANCESTRAL_TYPES = ('eular_a', 'euler_a')

    def __init__(self, net, type="ddim", steps=50, output_dim=None, n_samples=4, scale=7.5):
        super().__init__()
        self.net = net
        self.type = type
        self.steps = steps
        # A fresh list per instance: a list literal as the default argument
        # would be shared (and mutable) across all Sampler objects.
        self.output_dim = [512, 512] if output_dim is None else output_dim
        self.n_samples = n_samples
        self.scale = scale
        self.sigmas = ((1 - net.alphas_cumprod) / net.alphas_cumprod) ** 0.5
        self.log_sigmas = self.sigmas.log()

    def t_to_sigma(self, t):
        """Interpolate sigma at (possibly fractional) timestep indices ``t``.

        Interpolation is linear in log-sigma between the two neighbouring
        integer indices, then exponentiated.

        Args:
            t: Tensor of indices into ``self.sigmas``.

        Returns:
            Tensor of sigmas with the same shape as ``t``.
        """
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()

    def get_sigmas(self, n=None):
        """Return the sigma schedule from the last index to the first, plus a final ``0``.

        Args:
            n: Number of sigmas to sample, evenly spaced in timestep index from
                ``len(self.sigmas) - 1`` down to ``0``. ``None`` returns every
                entry of ``self.sigmas`` in reversed order.

        Returns:
            1-D tensor of length ``n + 1`` (or ``len(self.sigmas) + 1``) whose
            last element is zero.
        """
        def append_zero(x):
            return torch.cat([x, x.new_zeros([1])])

        if n is None:
            return append_zero(self.sigmas.flip(0))
        t_max = len(self.sigmas) - 1
        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
        return append_zero(self.t_to_sigma(t))

    @torch.no_grad()
    def sample(self, x_info, c_info):
        """Prepare the starting tensor and run the configured sampler.

        Args:
            x_info: Mapping describing the start state. If ``x_info['xt']`` is
                present and not ``None`` it is used as the start tensor
                (tensor or array-like, moved to the network's device/dtype);
                otherwise standard-normal noise is drawn. The chosen tensor is
                written back to ``x_info['x']`` (the mapping is mutated).
            c_info: Forwarded unchanged to the sampling loop.

        Returns:
            The tensor produced by :meth:`sample_euler_ancestral`.

        Raises:
            ValueError: If ``self.type`` is not a supported sampler name.
                Previously this case silently returned ``None``.
            NotImplementedError: If ``x_info['x0']`` is given. The earlier
                code for this branch read an unassigned local ``timesteps``
                and a non-existent ``self.model``, so it could never run.
        """
        if self.type not in self._EULER_ANCESTRAL_TYPES:
            raise ValueError(
                f"unsupported sampler type {self.type!r}; "
                f"expected one of {self._EULER_ANCESTRAL_TYPES}")

        h, w = self.output_dim
        shape = [self.n_samples, 4, h // 8, w // 8]
        device, dtype = self.net.get_device(), self.net.get_dtype()

        if ('xt' in x_info) and (x_info['xt'] is not None):
            # torch tensors have no ``.astype``; as_tensor + ``.to`` handles
            # both tensors and numpy arrays.
            x_info['x'] = torch.as_tensor(x_info['xt']).to(device=device, dtype=dtype)
        elif ('x0' in x_info) and (x_info['x0'] is not None):
            raise NotImplementedError(
                "starting from 'x0' (forward-noising a clean sample) is not "
                "implemented for this sampler; pass 'xt' instead")
        else:
            x_info['x'] = torch.randn(shape, device=device, dtype=dtype)

        sigmas = self.get_sigmas(n=self.steps)
        return self.sample_euler_ancestral(x_info=x_info, c_info=c_info, sigmas=sigmas)

    @torch.no_grad()
    def sample_euler_ancestral(
            self, x_info, c_info, sigmas, eta=1., s_noise=1.,):
        """Run Euler-ancestral sampling over the given sigma schedule.

        Args:
            x_info: Mapping whose ``'x'`` entry is the start tensor; it is
                multiplied by ``sigmas[0]`` before the first step.
            c_info: Accepted but not used by this loop.
            sigmas: 1-D tensor of sigmas; consecutive pairs define the steps.
            eta: Passed to ``get_ancestral_step``.
            s_noise: Multiplier on the noise added after each step.

        Returns:
            The tensor after the last step.
        """
        # NOTE(review): neither ``c_info`` nor ``self.scale`` reaches
        # ``apply_model`` — confirm whether conditioning / guidance is expected.
        x = x_info['x'] * sigmas[0]

        noise_sampler = default_noise_sampler(x)

        # Per-sample vector of ones, used to expand a scalar sigma to the batch.
        s_in = x.new_ones([x.shape[0]])
        for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
            denoised = self.net.apply_model(x, sigma * s_in)

            sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=eta)
            d = to_d(x, sigma, denoised)
            # Euler method
            dt = sigma_down - sigma
            x = x + d * dt
            # No noise is added on a step that lands on sigma == 0.
            if sigma_next > 0:
                x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up
        return x