File size: 5,527 Bytes
19c4ddf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
from typing import Any, Callable, Dict, Optional, List

import torch
import torch.nn as nn

from .gaussian_diffusion import GaussianDiffusion
from .k_diffusion import karras_sample, karras_sample_addition_condition

# Default Karras-sampler settings. Nothing in the visible part of this module
# reads them; presumably callers import them to fill the `karras_steps`,
# `sigma_min`, `sigma_max` and `s_churn` arguments of the samplers below
# — TODO confirm against the call sites.
DEFAULT_KARRAS_STEPS = 64
DEFAULT_KARRAS_SIGMA_MIN = 1e-3
DEFAULT_KARRAS_SIGMA_MAX = 160
DEFAULT_KARRAS_S_CHURN = 0.0


def uncond_guide_model(
    model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
    """Wrap ``model`` in a guidance step controlled by ``scale``.

    The returned callable has the signature ``(x_t, ts, **kwargs)``. It keeps
    only the first half of ``x_t`` along dim 0, duplicates it, and runs
    ``model`` on the duplicated batch. The model output is split into two
    halves along dim 0; the first is treated as the conditional prediction and
    the second as the unconditional one. They are blended as
    ``uncond + scale * (cond - uncond)`` and the blended half is returned
    twice, so the output batch has the same size as the duplicated input.

    :param model: callable evaluated as ``model(x, ts, **kwargs)``.
    :param scale: guidance weight applied to ``cond - uncond``.
    :return: the guided wrapper around ``model``.
    """

    def guided_fn(x_t, ts, **kwargs):
        # The second half of x_t is ignored: both halves of the model input
        # are the same tensor.
        n_half = len(x_t) // 2
        first_half = x_t[:n_half]
        doubled = torch.cat((first_half, first_half), dim=0)

        cond_pred, uncond_pred = model(doubled, ts, **kwargs).chunk(2, dim=0)
        guided = uncond_pred + scale * (cond_pred - uncond_pred)

        # Repeat the guided result so the output batch matches the input batch.
        return torch.cat((guided, guided), dim=0)

    return guided_fn


def sample_latents(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
    initial_noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Sample latents of shape ``(batch_size, model.d_latent)`` from ``model``.

    :param batch_size: number of latents requested.
    :param model: network to sample from; must expose ``d_latent``. If it has a
        ``cached_model_kwargs`` method, that method is used to pre-process
        ``model_kwargs``.
    :param diffusion: diffusion object handed to the sampler.
    :param model_kwargs: conditioning tensors forwarded to the model. The
        caller's dict is never modified.
    :param guidance_scale: guidance weight. When it is neither 1.0 nor 0.0,
        every tensor in ``model_kwargs`` is extended along dim 0 with an
        all-zero copy of itself (the unconditional half of the batch).
    :param clip_denoised: forwarded to the sampler.
    :param use_fp16: enables ``torch.autocast`` around the sampling call.
    :param use_karras: use ``karras_sample`` if true, otherwise
        ``diffusion.p_sample_loop``.
    :param karras_steps: forwarded to ``karras_sample`` as ``steps``.
    :param sigma_min: forwarded to ``karras_sample``.
    :param sigma_max: forwarded to ``karras_sample``.
    :param s_churn: forwarded to ``karras_sample``.
    :param device: device to sample on; defaults to the device of the model's
        first parameter.
    :param progress: forwarded to the sampler.
    :param initial_noise: forwarded to ``karras_sample`` only.
    :return: the samples produced by the selected sampler. The intermediate
        sequence returned by ``karras_sample`` is discarded.
    """
    if device is None:
        device = next(model.parameters()).device

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)
    if guidance_scale != 1.0 and guidance_scale != 0.0:
        # Build a new dict instead of assigning into ``model_kwargs``: when the
        # model has no ``cached_model_kwargs`` this is the caller's own dict,
        # and mutating it would double the tensors again on every call.
        model_kwargs = {
            key: torch.cat([value, torch.zeros_like(value)], dim=0)
            for key, value in model_kwargs.items()
        }

    sample_shape = (batch_size, model.d_latent)
    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            samples, _ = karras_sample(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                guidance_scale=guidance_scale,
                progress=progress,
                initial_noise=initial_noise,
            )
        else:
            # NOTE(review): this branch wraps the model whenever
            # guidance_scale != 1.0, including 0.0, although model_kwargs is
            # only doubled when the scale is neither 1.0 nor 0.0. It also does
            # not use ``initial_noise`` and returns the full doubled batch.
            # Behavior kept as-is — confirm whether this is intended.
            internal_batch_size = batch_size
            if guidance_scale != 1.0:
                model = uncond_guide_model(model, guidance_scale)
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )

    return samples


def sample_latents_with_additional_latent(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    text_guidance_scale: float,
    image_guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
    condition_latent: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Sample latents conditioned on ``model_kwargs`` and an extra latent.

    When either guidance scale is neither 1.0 nor 0.0, the conditioning is
    expanded to three stacked copies along dim 0:

    * ``model_kwargs`` tensors become ``[value, zeros, zeros]``;
    * ``condition_latent`` becomes ``[latent, latent, zeros]``.

    :param batch_size: number of latents requested.
    :param model: network to sample from; must expose ``d_latent``. If it has a
        ``cached_model_kwargs`` method, that method is used to pre-process
        ``model_kwargs``.
    :param diffusion: diffusion object handed to the sampler.
    :param model_kwargs: conditioning tensors forwarded to the model. The
        caller's dict is never modified.
    :param text_guidance_scale: forwarded to
        ``karras_sample_addition_condition``; also used as the guidance weight
        of the non-Karras branch.
    :param image_guidance_scale: forwarded to
        ``karras_sample_addition_condition``.
    :param clip_denoised: forwarded to the sampler.
    :param use_fp16: enables ``torch.autocast`` around the sampling call.
    :param use_karras: use ``karras_sample_addition_condition`` if true,
        otherwise ``diffusion.p_sample_loop``.
    :param karras_steps: forwarded to the Karras sampler as ``steps``.
    :param sigma_min: forwarded to the Karras sampler.
    :param sigma_max: forwarded to the Karras sampler.
    :param s_churn: forwarded to the Karras sampler.
    :param device: device to sample on; defaults to the device of the model's
        first parameter.
    :param progress: forwarded to the sampler.
    :param condition_latent: additional conditioning latent, forwarded to the
        Karras sampler. ``None`` is passed through unchanged.
    :return: the samples produced by the selected sampler. The intermediate
        sequence returned by the Karras sampler is discarded.
    """
    if device is None:
        device = next(model.parameters()).device

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)

    text_guided = text_guidance_scale != 1.0 and text_guidance_scale != 0.0
    image_guided = image_guidance_scale != 1.0 and image_guidance_scale != 0.0
    if text_guided or image_guided:
        # Build a new dict so the caller's ``model_kwargs`` is left untouched.
        model_kwargs = {
            key: torch.cat(
                [value, torch.zeros_like(value), torch.zeros_like(value)], dim=0
            )
            for key, value in model_kwargs.items()
        }
        # Expand the latent exactly once, independent of how many entries
        # ``model_kwargs`` has (previously it was tripled once per key).
        if condition_latent is not None:
            condition_latent = torch.cat(
                [
                    condition_latent,
                    condition_latent,
                    torch.zeros_like(condition_latent),
                ],
                dim=0,
            )

    sample_shape = (batch_size, model.d_latent)
    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            samples, _ = karras_sample_addition_condition(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                text_guidance_scale=text_guidance_scale,
                image_guidance_scale=image_guidance_scale,
                progress=progress,
                condition_latent=condition_latent,
            )
        else:
            # NOTE(review): this branch applies the two-way guidance wrapper
            # with a doubled batch, while ``model_kwargs`` may have been
            # tripled above, and it uses neither ``condition_latent`` nor
            # ``image_guidance_scale``. Behavior kept as-is — confirm whether
            # the non-Karras path is actually supported for this function.
            internal_batch_size = batch_size
            if text_guidance_scale != 1.0:
                model = uncond_guide_model(model, text_guidance_scale)
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )

    return samples