JackyZhuo committed on
Commit
621f28d
·
verified ·
1 Parent(s): ee9b9b9

Upload folder using huggingface_hub

transport/__init__.py ADDED
@@ -0,0 +1,70 @@
1
+ from .transport import ModelType, PathType, Sampler, Transport, WeightType
2
+
3
+
4
+ def create_transport(
5
+ path_type="Linear",
6
+ prediction="velocity",
7
+ loss_weight=None,
8
+ train_eps=None,
9
+ sample_eps=None,
10
+ snr_type="uniform",
11
+ do_shift=True,
12
+ seq_len=1024, # corresponding to 512x512
13
+ ):
14
+ """function for creating Transport object
15
+ **Note**: model prediction defaults to velocity
16
+ Args:
17
+ - path_type: type of path to use; default to linear
18
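+ - prediction: model prediction target: "velocity" (default), "noise", or "score"
19
+ - loss_weight: loss weighting: "velocity", "likelihood", or None for no weighting
20
+ - snr_type, do_shift: passed through unchanged to Transport
21
+ - seq_len: passed through unchanged to Transport; the default 1024 corresponds to 512x512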
22
+ - train_eps: small epsilon for avoiding instability during training
23
+ - sample_eps: small epsilon for avoiding instability during sampling
24
+ """
25
+
26
+ if prediction == "noise":
27
+ model_type = ModelType.NOISE
28
+ elif prediction == "score":
29
+ model_type = ModelType.SCORE
30
+ else:
31
+ model_type = ModelType.VELOCITY
32
+
33
+ if loss_weight == "velocity":
34
+ loss_type = WeightType.VELOCITY
35
+ elif loss_weight == "likelihood":
36
+ loss_type = WeightType.LIKELIHOOD
37
+ else:
38
+ loss_type = WeightType.NONE
39
+
40
+ path_choice = {
41
+ "Linear": PathType.LINEAR,
42
+ "GVP": PathType.GVP,
43
+ "VP": PathType.VP,
44
+ }
45
+
46
+ path_type = path_choice[path_type]
47
+
48
+ if path_type in [PathType.VP]:
49
+ train_eps = 1e-5 if train_eps is None else train_eps
50
+ sample_eps = 1e-3 if sample_eps is None else sample_eps
51
+ elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
52
+ train_eps = 1e-3 if train_eps is None else train_eps
53
+ sample_eps = 1e-3 if sample_eps is None else sample_eps
54
+ else: # velocity & [GVP, LINEAR] is stable everywhere
55
+ train_eps = 0
56
+ sample_eps = 0
57
+
58
+ # create flow state
59
+ state = Transport(
60
+ model_type=model_type,
61
+ path_type=path_type,
62
+ loss_type=loss_type,
63
+ train_eps=train_eps,
64
+ sample_eps=sample_eps,
65
+ snr_type=snr_type,
66
+ do_shift=do_shift,
67
+ seq_len=seq_len,
68
+ )
69
+
70
+ return state
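
The epsilon defaults chosen in `create_transport` depend only on the path type and the prediction target. The following standalone sketch restates that branching with plain strings so it can be checked without the `transport` package. `resolve_eps` is a hypothetical helper, not part of this commit, and it assumes the intent is for each epsilon to fall back to its default only when that same argument is `None`.

```python
def resolve_eps(path_type, prediction, train_eps=None, sample_eps=None):
    """Hypothetical helper: return (train_eps, sample_eps) for a path/prediction pair."""
    # Anything other than "noise" or "score" is treated as velocity prediction,
    # matching the if/elif/else chain in create_transport.
    is_velocity = prediction not in ("noise", "score")

    if path_type == "VP":
        default_train, default_sample = 1e-5, 1e-3
    elif path_type in ("GVP", "Linear") and not is_velocity:
        default_train, default_sample = 1e-3, 1e-3
    else:
        # Velocity prediction on GVP/Linear paths: both epsilons are forced to 0,
        # regardless of what the caller passed in.
        return 0, 0

    # Assumption: each epsilon is defaulted independently of the other.
    train_eps = default_train if train_eps is None else train_eps
    sample_eps = default_sample if sample_eps is None else sample_eps
    return train_eps, sample_eps


print(resolve_eps("VP", "velocity"))
print(resolve_eps("GVP", "noise", train_eps=1e-4))
print(resolve_eps("Linear", "velocity", train_eps=0.1, sample_eps=0.2))
```

Note that in the last call the user-supplied values are discarded, which is what the final `else` branch of `create_transport` does as well.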
transport/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.63 kB). View file
 
transport/__pycache__/dpm_solver.cpython-310.pyc ADDED
Binary file (50.9 kB). View file
 
transport/__pycache__/integrators.cpython-310.pyc ADDED
Binary file (3.78 kB). View file
 
transport/__pycache__/path.cpython-310.pyc ADDED
Binary file (8.35 kB). View file
 
transport/__pycache__/transport.cpython-310.pyc ADDED
Binary file (15.6 kB). View file
 
transport/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.26 kB). View file
 
transport/dpm_solver.py ADDED
@@ -0,0 +1,1386 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
18
+ import os
19
+
20
+ import torch
21
+ from tqdm import tqdm
22
+
23
+
24
+ class NoiseScheduleFlow:
25
+ def __init__(
26
+ self,
27
+ schedule="discrete_flow",
28
+ ):
29
+ """Create a wrapper class for the forward SDE (EDM type)."""
30
+ self.T = 1
31
+ self.t0 = 0.001
32
+ self.schedule = schedule # ['continuous', 'discrete_flow']
33
+ self.total_N = 1000
34
+
35
+ def marginal_log_mean_coeff(self, t):
36
+ """
37
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
38
+ """
39
+ return torch.log(self.marginal_alpha(t))
40
+
41
+ def marginal_alpha(self, t):
42
+ """
43
+ Compute alpha_t of a given continuous-time label t in [0, T].
44
+ """
45
+ return 1 - t
46
+
47
+ @staticmethod
48
+ def marginal_std(t):
49
+ """
50
+ Compute sigma_t of a given continuous-time label t in [0, T].
51
+ """
52
+ return t
53
+
54
+ def marginal_lambda(self, t):
55
+ """
56
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
57
+ """
58
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
59
+ log_std = torch.log(self.marginal_std(t))
60
+ return log_mean_coeff - log_std
61
+
62
+ @staticmethod
63
+ def inverse_lambda(lamb):
64
+ """
65
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
66
+ """
67
+ return torch.exp(-lamb)
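
For this schedule, alpha_t = 1 - t and sigma_t = t, so lambda_t = log((1 - t) / t). The check below uses only the standard library and hypothetical helper names. It compares `exp(-lambda)`, the expression returned by `inverse_lambda`, with the algebraic inverse `1 / (1 + exp(lambda))`. Under these definitions `exp(-lambda)` equals sigma_t / alpha_t = t / (1 - t), which is close to t only for small t; whether that approximation is intended is not stated in the source.

```python
import math


def marginal_lambda(t):
    # lambda_t = log(alpha_t) - log(sigma_t) with alpha_t = 1 - t and sigma_t = t
    return math.log(1.0 - t) - math.log(t)


def exp_neg_lambda(lamb):
    # The expression used by NoiseScheduleFlow.inverse_lambda
    return math.exp(-lamb)


def sigmoid_neg_lambda(lamb):
    # Algebraic inverse of marginal_lambda: t = 1 / (1 + exp(lambda))
    return 1.0 / (1.0 + math.exp(lamb))


for t in (0.001, 0.1, 0.5, 0.9):
    lamb = marginal_lambda(t)
    # Columns: t, exp(-lambda) = t / (1 - t), exact inverse
    print(t, exp_neg_lambda(lamb), sigmoid_neg_lambda(lamb))
```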
68
+
69
+
70
+ def model_wrapper(
71
+ model,
72
+ noise_schedule,
73
+ model_type="noise",
74
+ model_kwargs={},
75
+ guidance_type="uncond",
76
+ condition=None,
77
+ unconditional_condition=None,
78
+ guidance_scale=1.0,
79
+ interval_guidance=[0, 1.0],
80
+ classifier_fn=None,
81
+ classifier_kwargs={},
82
+ ):
83
+ """Create a wrapper function for the noise prediction model.
84
+
85
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
86
+ first wrap the model function into a noise prediction model that accepts the continuous time as the input.
87
+
88
+ We support the following types of diffusion model by setting `model_type` (the code additionally accepts "flow"; see `noise_pred_fn`):
89
+
90
+ 1. "noise": noise prediction model. (Trained by predicting noise).
91
+
92
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
93
+
94
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
95
+ The "v" prediction is derived in detail in Appendix D of [1], and is used in Imagen-Video [2].
96
+
97
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
98
+ arXiv preprint arXiv:2202.00512 (2022).
99
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
100
+ arXiv preprint arXiv:2210.02303 (2022).
101
+
102
+ 4. "score": marginal score function. (Trained by denoising score matching).
103
+ Note that the score function and the noise prediction model follow a simple relationship:
104
+ ```
105
+ noise(x_t, t) = -sigma_t * score(x_t, t)
106
+ ```
107
+
108
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
109
+ 1. "uncond": unconditional sampling by DPMs.
110
+ The input `model` has the following format:
111
+ ``
112
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
113
+ ``
114
+
115
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
116
+ The input `model` has the following format:
117
+ ``
118
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
119
+ ``
120
+
121
+ The input `classifier_fn` has the following format:
122
+ ``
123
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
124
+ ``
125
+
126
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
127
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
128
+
129
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
130
+ The input `model` has the following format:
131
+ ``
132
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
133
+ ``
134
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
135
+
136
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
137
+ arXiv preprint arXiv:2207.12598 (2022).
138
+
139
+
140
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
141
+ or continuous-time labels (i.e. epsilon to T).
142
+
143
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
144
+ ``
145
+ def model_fn(x, t_continuous) -> noise:
146
+ t_input = get_model_input_time(t_continuous)
147
+ return noise_pred(model, x, t_input, **model_kwargs)
148
+ ``
149
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
150
+
151
+ ===============================================================
152
+
153
+ Args:
154
+ model: A diffusion model with the corresponding format described above.
155
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
156
+ model_type: A `str`. The parameterization type of the diffusion model.
157
+ "noise" or "x_start" or "v" or "score".
158
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
159
+ guidance_type: A `str`. The type of the guidance for sampling.
160
+ "uncond" or "classifier" or "classifier-free".
161
+ condition: A pytorch tensor. The condition for the guided sampling.
162
+ Only used for "classifier" or "classifier-free" guidance type.
163
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
164
+ Only used for "classifier-free" guidance type.
165
+ guidance_scale: A `float`. The scale for the guided sampling.
166
+ classifier_fn: A classifier function. Only used for the classifier guidance.
167
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
168
+ Returns:
169
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
170
+ """
171
+
172
+ def get_model_input_time(t_continuous):
173
+ """
174
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
175
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
176
+ For continuous-time DPMs, we just use `t_continuous`.
177
+ """
178
+ if noise_schedule.schedule == "discrete":
179
+ return (t_continuous - 1.0 / noise_schedule.total_N) * noise_schedule.total_N
180
+ elif noise_schedule.schedule == "discrete_flow":
181
+ return t_continuous * noise_schedule.total_N
182
+ else:
183
+ return t_continuous
184
+
185
+ def noise_pred_fn(x, t_continuous, cond=None):
186
+ t_input = get_model_input_time(t_continuous)
187
+ if cond is None:
188
+ output = model(x, t_input, **model_kwargs)
189
+ else:
190
+ output = model(x, t_input, cond, **model_kwargs)
191
+ if model_type == "noise":
192
+ return output
193
+ elif model_type == "x_start":
194
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
195
+ return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
196
+ elif model_type == "v":
197
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
198
+ return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
199
+ elif model_type == "score":
200
+ sigma_t = noise_schedule.marginal_std(t_continuous)
201
+ return -expand_dims(sigma_t, x.dim()) * output
202
+ elif model_type == "flow":
203
+ _, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
204
+ try:
205
+ noise = (1 - expand_dims(sigma_t, x.dim()).to(x)) * output + x
206
+ except Exception:
207
+ noise = (1 - expand_dims(sigma_t, x.dim()).to(x)) * output[0] + x
208
+ return noise
209
+
210
+ def cond_grad_fn(x, t_input):
211
+ """
212
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
213
+ """
214
+ with torch.enable_grad():
215
+ x_in = x.detach().requires_grad_(True)
216
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
217
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
218
+
219
+ def model_fn(x, t_continuous):
220
+ """
221
+ The noise prediction model function that is used for DPM-Solver.
222
+ """
223
+ guidance_tp = guidance_type
224
+ if guidance_tp == "uncond":
225
+ return noise_pred_fn(x, t_continuous)
226
+ elif guidance_tp == "classifier":
227
+ assert classifier_fn is not None
228
+ t_input = get_model_input_time(t_continuous)
229
+ cond_grad = cond_grad_fn(x, t_input)
230
+ sigma_t = noise_schedule.marginal_std(t_continuous)
231
+ noise = noise_pred_fn(x, t_continuous)
232
+ return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
233
+ elif guidance_tp == "classifier-free":
234
+ if (
235
+ guidance_scale == 1.0
236
+ or unconditional_condition is None
237
+ or not (interval_guidance[0] < t_continuous[0] < interval_guidance[1])
238
+ ):
239
+ return noise_pred_fn(x, t_continuous, cond=condition)
240
+ else:
241
+ x_in = torch.cat([x] * 2)
242
+ t_in = torch.cat([t_continuous] * 2)
243
+ c_in = torch.cat([unconditional_condition, condition])
244
+ try:
245
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
246
+ except Exception:
247
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in)[0].chunk(2)
248
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
249
+
250
+ assert model_type in ["noise", "x_start", "v", "score", "flow"]
251
+ assert guidance_type in [
252
+ "uncond",
253
+ "classifier",
254
+ "classifier-free",
255
+ ]
256
+ return model_fn
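
The "flow" branch and the classifier-free branch of `model_wrapper` reduce to two short formulas. The scalar sketch below assumes the linear path x_t = (1 - t) * x0 + t * eps with model output v = eps - x0. That assumption is consistent with `marginal_alpha` and `marginal_std` but is not spelled out in the source, and the function names are hypothetical.

```python
def velocity_to_noise(x_t, v, t):
    # Same algebra as the "flow" branch: noise = (1 - sigma_t) * output + x, with sigma_t = t
    return (1.0 - t) * v + x_t


def classifier_free_guidance(noise_uncond, noise_cond, guidance_scale):
    # Same combination as the "classifier-free" branch
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)


# Build a consistent (x_t, v) pair from a known clean sample and known noise.
x0, eps, t = 2.0, -1.0, 0.25
x_t = (1.0 - t) * x0 + t * eps   # point on the linear path
v = eps - x0                     # velocity target under the stated assumption

recovered = velocity_to_noise(x_t, v, t)
print(recovered)  # should match eps up to floating-point error

print(classifier_free_guidance(0.0, 1.0, 4.5))
```

With `guidance_scale == 1.0` the combination returns the conditional prediction unchanged, which is why the wrapper skips the doubled batch in that case.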
257
+
258
+
259
+ class DPM_Solver:
260
+ def __init__(
261
+ self,
262
+ model_fn,
263
+ noise_schedule,
264
+ algorithm_type="dpmsolver++",
265
+ correcting_x0_fn=None,
266
+ correcting_xt_fn=None,
267
+ thresholding_max_val=1.0,
268
+ dynamic_thresholding_ratio=0.995,
269
+ ):
270
+ """Construct a DPM-Solver.
271
+
272
+ We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
273
+
274
+ We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
275
+ can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
276
+ dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
277
+ DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
278
+ DPMs (such as stable-diffusion).
279
+
280
+ To support advanced algorithms in image-to-image applications, we also support corrector functions for
281
+ both x0 and xt.
282
+
283
+ Args:
284
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
285
+ ``
286
+ def model_fn(x, t_continuous):
287
+ return noise
288
+ ``
289
+ The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
290
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
291
+ algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
292
+ correcting_x0_fn: A `str` or a function with the following format:
293
+ ```
294
+ def correcting_x0_fn(x0, t):
295
+ x0_new = ...
296
+ return x0_new
297
+ ```
298
+ This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
299
+ ```
300
+ x0_pred = data_pred_model(xt, t)
301
+ if correcting_x0_fn is not None:
302
+ x0_pred = correcting_x0_fn(x0_pred, t)
303
+ xt_1 = update(x0_pred, xt, t)
304
+ ```
305
+ If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
306
+ correcting_xt_fn: A function with the following format:
307
+ ```
308
+ def correcting_xt_fn(xt, t, step):
309
+ x_new = ...
310
+ return x_new
311
+ ```
312
+ This function is to correct the intermediate samples xt at each sampling step. e.g.,
313
+ ```
314
+ xt = ...
315
+ xt = correcting_xt_fn(xt, t, step)
316
+ ```
317
+ thresholding_max_val: A `float`. The max value for thresholding.
318
+ Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
319
+ dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
320
+ Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
321
+
322
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
323
+ Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
324
+ with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
325
+ """
326
+ self.model = lambda x, t: model_fn(x, t.expand(x.shape[0]))
327
+ self.noise_schedule = noise_schedule
328
+ assert algorithm_type in ["dpmsolver", "dpmsolver++"]
329
+ self.algorithm_type = algorithm_type
330
+ if correcting_x0_fn == "dynamic_thresholding":
331
+ self.correcting_x0_fn = self.dynamic_thresholding_fn
332
+ else:
333
+ self.correcting_x0_fn = correcting_x0_fn
334
+ self.correcting_xt_fn = correcting_xt_fn
335
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
336
+ self.thresholding_max_val = thresholding_max_val
337
+ self.register_progress_bar()
338
+
339
+ def register_progress_bar(self, progress_fn=None):
340
+ """
341
+ Register a progress bar callback function
342
+
343
+ Args:
344
+ progress_fn: Callback function that takes current step and total steps as parameters
345
+ """
346
+ self.progress_fn = progress_fn if progress_fn is not None else lambda step, total: None
347
+
348
+ def update_progress(self, step, total_steps):
349
+ """
350
+ Update sampling progress
351
+
352
+ Args:
353
+ step: Current step number
354
+ total_steps: Total number of steps
355
+ """
356
+ if hasattr(self, "progress_fn"):
357
+ try:
358
+ self.progress_fn(step / total_steps, desc=f"Generating {step}/{total_steps}")
359
+ except Exception:
360
+ self.progress_fn(step, total_steps)
361
+
362
+ else:
363
+ # If no progress_fn registered, use default empty function
364
+ pass
365
+
366
+ def dynamic_thresholding_fn(self, x0, t):
367
+ """
368
+ The dynamic thresholding method.
369
+ """
370
+ dims = x0.dim()
371
+ p = self.dynamic_thresholding_ratio
372
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
373
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
374
+ x0 = torch.clamp(x0, -s, s) / s
375
+ return x0
376
+
377
+ def noise_prediction_fn(self, x, t):
378
+ """
379
+ Return the noise prediction model.
380
+ """
381
+ return self.model(x, t)
382
+
383
+ def data_prediction_fn(self, x, t):
384
+ """
385
+ Return the data prediction model (with corrector).
386
+ """
387
+ noise = self.noise_prediction_fn(x, t)
388
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
389
+ x0 = (x - sigma_t * noise) / alpha_t
390
+ if self.correcting_x0_fn is not None:
391
+ x0 = self.correcting_x0_fn(x0, t)
392
+ return x0
393
+
394
+ def model_fn(self, x, t):
395
+ """
396
+ Convert the model to the noise prediction model or the data prediction model.
397
+ """
398
+ if self.algorithm_type == "dpmsolver++":
399
+ return self.data_prediction_fn(x, t)
400
+ else:
401
+ return self.noise_prediction_fn(x, t)
402
+
403
+ def get_time_steps(self, skip_type, t_T, t_0, N, device, shift=1.0):
404
+ """Compute the intermediate time steps for sampling.
405
+
406
+ Args:
407
+ skip_type: A `str`. The type for the spacing of the time steps. The three types below are supported, plus 'time_uniform_flow' (see the matching branch in the code):
408
+ - 'logSNR': uniform logSNR for the time steps.
409
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
410
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
411
+ t_T: A `float`. The starting time of the sampling (default is T).
412
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
413
+ N: An `int`. The total number of the spacing of the time steps.
414
+ device: A torch device.
415
+ Returns:
416
+ A pytorch tensor of the time steps, with the shape (N + 1,).
417
+ """
418
+ if skip_type == "logSNR":
419
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
420
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
421
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
422
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
423
+ elif skip_type == "time_uniform":
424
+ return torch.linspace(t_T, t_0, N + 1).to(device)
425
+ elif skip_type == "time_quadratic":
426
+ t_order = 2
427
+ t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
428
+ return t
429
+ elif skip_type == "time_uniform_flow":
430
+ betas = torch.linspace(t_T, t_0, N + 1).to(device)
431
+ sigmas = 1.0 - betas
432
+ sigmas = (shift * sigmas / (1 + (shift - 1) * sigmas)).flip(dims=[0])
433
+ return sigmas
434
+ else:
435
+ raise ValueError(
436
+ f"Unsupported skip_type {skip_type}, need to be 'logSNR', 'time_uniform', 'time_quadratic' or 'time_uniform_flow'"
437
+ )
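
The `time_uniform_flow` branch of `get_time_steps` can be reproduced without torch. This sketch mirrors its tensor operations with plain lists; `time_uniform_flow` is a hypothetical function name. It only illustrates how `shift` bends a uniform grid and is not a replacement for the method.

```python
def time_uniform_flow(t_T, t_0, n, shift=1.0):
    # betas: uniform grid from t_T to t_0 with n + 1 points (like torch.linspace)
    betas = [t_T + (t_0 - t_T) * i / n for i in range(n + 1)]
    # sigmas: complement of the grid
    sigmas = [1.0 - b for b in betas]
    # shift warps the grid; shift == 1.0 leaves it unchanged
    shifted = [shift * s / (1.0 + (shift - 1.0) * s) for s in sigmas]
    # flip so the returned list is in the reverse order of `sigmas`
    return shifted[::-1]


uniform = time_uniform_flow(1.0, 0.0, 4, shift=1.0)
warped = time_uniform_flow(1.0, 0.0, 4, shift=3.0)
print(uniform)
print(warped)
```

For `shift > 1` the interior points move toward 1, so more of the grid is spent at high noise levels; the endpoints 0 and 1 are fixed points of the warp.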
438
+
439
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
440
+ """
441
+ Get the order of each step for sampling by the singlestep DPM-Solver.
442
+
443
+ We combine DPM-Solver-1, 2 and 3 to use all the function evaluations; this combination is named "DPM-Solver-fast".
444
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
445
+ - If order == 1:
446
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
447
+ - If order == 2:
448
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
449
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
450
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
451
+ - If order == 3:
452
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
453
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
454
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
455
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
456
+
457
+ ============================================
458
+ Args:
459
+ order: An `int`. The max order for the solver (1, 2 or 3).
460
+ steps: An `int`. The total number of function evaluations (NFE).
461
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
462
+ - 'logSNR': uniform logSNR for the time steps.
463
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
464
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
465
+ t_T: A `float`. The starting time of the sampling (default is T).
466
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
467
+ device: A torch device.
468
+ Returns:
469
+ orders: A list of the solver order of each step.
470
+ """
471
+ if order == 3:
472
+ K = steps // 3 + 1
473
+ if steps % 3 == 0:
474
+ orders = [3,] * (
475
+ K - 2
476
+ ) + [2, 1]
477
+ elif steps % 3 == 1:
478
+ orders = [3,] * (
479
+ K - 1
480
+ ) + [1]
481
+ else:
482
+ orders = [3,] * (
483
+ K - 1
484
+ ) + [2]
485
+ elif order == 2:
486
+ if steps % 2 == 0:
487
+ K = steps // 2
488
+ orders = [
489
+ 2,
490
+ ] * K
491
+ else:
492
+ K = steps // 2 + 1
493
+ orders = [2,] * (
494
+ K - 1
495
+ ) + [1]
496
+ elif order == 1:
497
+ K = steps
498
+ orders = [
499
+ 1,
500
+ ] * steps
501
+ else:
502
+ raise ValueError("'order' must be '1' or '2' or '3'.")
503
+ if skip_type == "logSNR":
504
+ # To reproduce the results in DPM-Solver paper
505
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
506
+ else:
507
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
508
+ torch.cumsum(
509
+ torch.tensor(
510
+ [
511
+ 0,
512
+ ]
513
+ + orders
514
+ ),
515
+ 0,
516
+ ).to(device)
517
+ ]
518
+ return timesteps_outer, orders
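
The order schedule built by `get_orders_and_timesteps_for_singlestep_solver` is easiest to understand through its invariant: the per-step orders always sum to `steps`, so the number of function evaluations is spent exactly. A standalone sketch of just the order logic, under the hypothetical name `singlestep_orders`:

```python
def singlestep_orders(steps, order):
    """Return the list of per-step solver orders for a given NFE budget."""
    if order == 3:
        k = steps // 3 + 1
        if steps % 3 == 0:
            return [3] * (k - 2) + [2, 1]
        elif steps % 3 == 1:
            return [3] * (k - 1) + [1]
        else:
            return [3] * (k - 1) + [2]
    elif order == 2:
        if steps % 2 == 0:
            return [2] * (steps // 2)
        return [2] * (steps // 2) + [1]
    elif order == 1:
        return [1] * steps
    raise ValueError("'order' must be 1, 2 or 3.")


for steps in (9, 10, 20):
    orders = singlestep_orders(steps, 3)
    # Each entry costs that many model evaluations, so the sum equals `steps`.
    print(steps, orders, sum(orders))
```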
519
+
520
+ def denoise_to_zero_fn(self, x, s):
521
+ """
522
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
523
+ """
524
+ return self.data_prediction_fn(x, s)
525
+
526
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
527
+ """
528
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
529
+
530
+ Args:
531
+ x: A pytorch tensor. The initial value at time `s`.
532
+ s: A pytorch tensor. The starting time, with the shape (1,).
533
+ t: A pytorch tensor. The ending time, with the shape (1,).
534
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
535
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
536
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
537
+ Returns:
538
+ x_t: A pytorch tensor. The approximated solution at time `t`.
539
+ """
540
+ ns = self.noise_schedule
541
+ dims = x.dim()
542
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
543
+ h = lambda_t - lambda_s
544
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
545
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
546
+ alpha_t = torch.exp(log_alpha_t)
547
+
548
+ if self.algorithm_type == "dpmsolver++":
549
+ phi_1 = torch.expm1(-h)
550
+ if model_s is None:
551
+ model_s = self.model_fn(x, s)
552
+ x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s
553
+ if return_intermediate:
554
+ return x_t, {"model_s": model_s}
555
+ else:
556
+ return x_t
557
+ else:
558
+ phi_1 = torch.expm1(h)
559
+ if model_s is None:
560
+ model_s = self.model_fn(x, s)
561
+ x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s
562
+ if return_intermediate:
563
+ return x_t, {"model_s": model_s}
564
+ else:
565
+ return x_t
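
A scalar check helps confirm the `dpmsolver++` branch of the first-order update. If the data prediction is exact and constant (the model returns the true x0), the update should land exactly on the marginal alpha_t * x0 + sigma_t * eps for the same eps. The sketch below uses the flow schedule alpha_t = 1 - t, sigma_t = t from `NoiseScheduleFlow`; the helper names are hypothetical and the "perfect model" is an assumption made only for the test.

```python
import math


def alpha(t):
    return 1.0 - t


def sigma(t):
    return t


def lam(t):
    # half log-SNR: log(alpha_t) - log(sigma_t)
    return math.log(alpha(t)) - math.log(sigma(t))


def first_update_pp(x, s, t, x0_pred):
    # Mirrors: x_t = sigma_t / sigma_s * x - alpha_t * expm1(-h) * model_s
    h = lam(t) - lam(s)
    return sigma(t) / sigma(s) * x - alpha(t) * math.expm1(-h) * x0_pred


x0, eps = 2.0, -1.0
s, t = 0.8, 0.5                          # sampling moves from higher to lower noise
x_s = alpha(s) * x0 + sigma(s) * eps     # sample on the path at time s
x_t = first_update_pp(x_s, s, t, x0)     # one solver step with a perfect x0 prediction
expected = alpha(t) * x0 + sigma(t) * eps
print(x_t, expected)
```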
566
+
567
+ def singlestep_dpm_solver_second_update(
568
+ self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpmsolver"
569
+ ):
570
+ """
571
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
572
+
573
+ Args:
574
+ x: A pytorch tensor. The initial value at time `s`.
575
+ s: A pytorch tensor. The starting time, with the shape (1,).
576
+ t: A pytorch tensor. The ending time, with the shape (1,).
577
+ r1: A `float`. The hyperparameter of the second-order solver.
578
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
579
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
580
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
581
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
582
+ The type slightly impacts the performance. We recommend the 'dpmsolver' type.
583
+ Returns:
584
+ x_t: A pytorch tensor. The approximated solution at time `t`.
585
+ """
586
+ if solver_type not in ["dpmsolver", "taylor"]:
587
+ raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
588
+ if r1 is None:
589
+ r1 = 0.5
590
+ ns = self.noise_schedule
591
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
592
+ h = lambda_t - lambda_s
593
+ lambda_s1 = lambda_s + r1 * h
594
+ s1 = ns.inverse_lambda(lambda_s1)
595
+ log_alpha_s, log_alpha_s1, log_alpha_t = (
596
+ ns.marginal_log_mean_coeff(s),
597
+ ns.marginal_log_mean_coeff(s1),
598
+ ns.marginal_log_mean_coeff(t),
599
+ )
600
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
601
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
602
+
603
+ if self.algorithm_type == "dpmsolver++":
604
+ phi_11 = torch.expm1(-r1 * h)
605
+ phi_1 = torch.expm1(-h)
606
+
607
+ if model_s is None:
608
+ model_s = self.model_fn(x, s)
609
+ x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
610
+ model_s1 = self.model_fn(x_s1, s1)
611
+ if solver_type == "dpmsolver":
612
+ x_t = (
613
+ (sigma_t / sigma_s) * x
614
+ - (alpha_t * phi_1) * model_s
615
+ - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
616
+ )
617
+ elif solver_type == "taylor":
618
+ x_t = (
619
+ (sigma_t / sigma_s) * x
620
+ - (alpha_t * phi_1) * model_s
621
+ + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s)
622
+ )
623
+ else:
624
+ phi_11 = torch.expm1(r1 * h)
625
+ phi_1 = torch.expm1(h)
626
+
627
+ if model_s is None:
628
+ model_s = self.model_fn(x, s)
629
+ x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
630
+ model_s1 = self.model_fn(x_s1, s1)
631
+ if solver_type == "dpmsolver":
632
+ x_t = (
633
+ torch.exp(log_alpha_t - log_alpha_s) * x
634
+ - (sigma_t * phi_1) * model_s
635
+ - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
636
+ )
637
+ elif solver_type == "taylor":
638
+ x_t = (
639
+ torch.exp(log_alpha_t - log_alpha_s) * x
640
+ - (sigma_t * phi_1) * model_s
641
+ - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s)
642
+ )
643
+ if return_intermediate:
644
+ return x_t, {"model_s": model_s, "model_s1": model_s1}
645
+ else:
646
+ return x_t
647
+
648
+ def singlestep_dpm_solver_third_update(
649
+ self,
650
+ x,
651
+ s,
652
+ t,
653
+ r1=1.0 / 3.0,
654
+ r2=2.0 / 3.0,
655
+ model_s=None,
656
+ model_s1=None,
657
+ return_intermediate=False,
658
+ solver_type="dpmsolver",
659
+ ):
660
+ """
661
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
662
+
663
+ Args:
664
+ x: A pytorch tensor. The initial value at time `s`.
665
+ s: A pytorch tensor. The starting time, with the shape (1,).
666
+ t: A pytorch tensor. The ending time, with the shape (1,).
667
+ r1: A `float`. The hyperparameter of the third-order solver.
668
+ r2: A `float`. The hyperparameter of the third-order solver.
669
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
670
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
671
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
672
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
673
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
674
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
675
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
676
+ Returns:
677
+ x_t: A pytorch tensor. The approximated solution at time `t`.
678
+ """
679
+ if solver_type not in ["dpmsolver", "taylor"]:
680
+ raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
681
+ if r1 is None:
682
+ r1 = 1.0 / 3.0
683
+ if r2 is None:
684
+ r2 = 2.0 / 3.0
685
+ ns = self.noise_schedule
686
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
687
+ h = lambda_t - lambda_s
688
+ lambda_s1 = lambda_s + r1 * h
689
+ lambda_s2 = lambda_s + r2 * h
690
+ s1 = ns.inverse_lambda(lambda_s1)
691
+ s2 = ns.inverse_lambda(lambda_s2)
692
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (
693
+ ns.marginal_log_mean_coeff(s),
694
+ ns.marginal_log_mean_coeff(s1),
695
+ ns.marginal_log_mean_coeff(s2),
696
+ ns.marginal_log_mean_coeff(t),
697
+ )
698
+ sigma_s, sigma_s1, sigma_s2, sigma_t = (
699
+ ns.marginal_std(s),
700
+ ns.marginal_std(s1),
701
+ ns.marginal_std(s2),
702
+ ns.marginal_std(t),
703
+ )
704
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
705
+
706
+ if self.algorithm_type == "dpmsolver++":
707
+ phi_11 = torch.expm1(-r1 * h)
708
+ phi_12 = torch.expm1(-r2 * h)
709
+ phi_1 = torch.expm1(-h)
710
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
711
+ phi_2 = phi_1 / h + 1.0
712
+ phi_3 = phi_2 / h - 0.5
713
+
714
+ if model_s is None:
715
+ model_s = self.model_fn(x, s)
716
+ if model_s1 is None:
717
+ x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
718
+ model_s1 = self.model_fn(x_s1, s1)
719
+ x_s2 = (
720
+ (sigma_s2 / sigma_s) * x
721
+ - (alpha_s2 * phi_12) * model_s
722
+ + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
723
+ )
724
+ model_s2 = self.model_fn(x_s2, s2)
725
+ if solver_type == "dpmsolver":
726
+ x_t = (
727
+ (sigma_t / sigma_s) * x
728
+ - (alpha_t * phi_1) * model_s
729
+ + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
730
+ )
731
+ elif solver_type == "taylor":
732
+ D1_0 = (1.0 / r1) * (model_s1 - model_s)
733
+ D1_1 = (1.0 / r2) * (model_s2 - model_s)
734
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
735
+ D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
736
+ x_t = (
737
+ (sigma_t / sigma_s) * x
738
+ - (alpha_t * phi_1) * model_s
739
+ + (alpha_t * phi_2) * D1
740
+ - (alpha_t * phi_3) * D2
741
+ )
742
+ else:
743
+ phi_11 = torch.expm1(r1 * h)
744
+ phi_12 = torch.expm1(r2 * h)
745
+ phi_1 = torch.expm1(h)
746
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
747
+ phi_2 = phi_1 / h - 1.0
748
+ phi_3 = phi_2 / h - 0.5
749
+
750
+ if model_s is None:
751
+ model_s = self.model_fn(x, s)
752
+ if model_s1 is None:
753
+ x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
754
+ model_s1 = self.model_fn(x_s1, s1)
755
+ x_s2 = (
756
+ (torch.exp(log_alpha_s2 - log_alpha_s)) * x
757
+ - (sigma_s2 * phi_12) * model_s
758
+ - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
759
+ )
760
+ model_s2 = self.model_fn(x_s2, s2)
761
+ if solver_type == "dpmsolver":
762
+ x_t = (
763
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
764
+ - (sigma_t * phi_1) * model_s
765
+ - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
766
+ )
767
+ elif solver_type == "taylor":
768
+ D1_0 = (1.0 / r1) * (model_s1 - model_s)
769
+ D1_1 = (1.0 / r2) * (model_s2 - model_s)
770
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
771
+ D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
772
+ x_t = (
773
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
774
+ - (sigma_t * phi_1) * model_s
775
+ - (sigma_t * phi_2) * D1
776
+ - (sigma_t * phi_3) * D2
777
+ )
778
+
779
+ if return_intermediate:
780
+ return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2}
781
+ else:
782
+ return x_t
783
+
784
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
785
+ """
786
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
787
+
788
+ Args:
789
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
790
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
791
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
792
+ t: A pytorch tensor. The ending time, with the shape (1,).
793
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
794
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
795
+ Returns:
796
+ x_t: A pytorch tensor. The approximated solution at time `t`.
797
+ """
798
+ if solver_type not in ["dpmsolver", "taylor"]:
799
+ raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
800
+ ns = self.noise_schedule
801
+ model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
802
+ t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
803
+ lambda_prev_1, lambda_prev_0, lambda_t = (
804
+ ns.marginal_lambda(t_prev_1),
805
+ ns.marginal_lambda(t_prev_0),
806
+ ns.marginal_lambda(t),
807
+ )
808
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
809
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
810
+ alpha_t = torch.exp(log_alpha_t)
811
+
812
+ h_0 = lambda_prev_0 - lambda_prev_1
813
+ h = lambda_t - lambda_prev_0
814
+ r0 = h_0 / h
815
+ D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
816
+ if self.algorithm_type == "dpmsolver++":
817
+ phi_1 = torch.expm1(-h)
818
+ if solver_type == "dpmsolver":
819
+ x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0
820
+ elif solver_type == "taylor":
821
+ x_t = (
822
+ (sigma_t / sigma_prev_0) * x
823
+ - (alpha_t * phi_1) * model_prev_0
824
+ + (alpha_t * (phi_1 / h + 1.0)) * D1_0
825
+ )
826
+ else:
827
+ phi_1 = torch.expm1(h)
828
+ if solver_type == "dpmsolver":
829
+ x_t = (
830
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
831
+ - (sigma_t * phi_1) * model_prev_0
832
+ - 0.5 * (sigma_t * phi_1) * D1_0
833
+ )
834
+ elif solver_type == "taylor":
835
+ x_t = (
836
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
837
+ - (sigma_t * phi_1) * model_prev_0
838
+ - (sigma_t * (phi_1 / h - 1.0)) * D1_0
839
+ )
840
+ return x_t
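As a scalar illustration of the `dpmsolver++` / `dpmsolver` branch of the second-order multistep update above, here is a minimal sketch. It assumes the schedule quantities (`alpha_t`, `sigma_t`, `sigma_prev_0`) and the logSNR values are already available as plain floats; the function name `multistep_second_update_scalar` is hypothetical and does not exist in this file.

```python
import math


def multistep_second_update_scalar(
    x, model_prev_1, model_prev_0,
    lambda_prev_1, lambda_prev_0, lambda_t,
    sigma_prev_0, sigma_t, alpha_t,
):
    """Scalar sketch of the 'dpmsolver'-type update for algorithm_type == 'dpmsolver++'."""
    h_0 = lambda_prev_0 - lambda_prev_1  # previous step size in logSNR
    h = lambda_t - lambda_prev_0         # current step size in logSNR
    r0 = h_0 / h
    # Finite difference of the two stored model outputs, rescaled by r0.
    d1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
    phi_1 = math.expm1(-h)
    return (
        (sigma_t / sigma_prev_0) * x
        - (alpha_t * phi_1) * model_prev_0
        - 0.5 * (alpha_t * phi_1) * d1_0
    )
```

When the two stored model outputs are equal, `d1_0` is zero and only the first two terms of the update remain.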
841
+
842
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
843
+ """
844
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
845
+
846
+ Args:
847
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
848
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
849
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
850
+ t: A pytorch tensor. The ending time, with the shape (1,).
851
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
852
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
853
+ Returns:
854
+ x_t: A pytorch tensor. The approximated solution at time `t`.
855
+ """
856
+ ns = self.noise_schedule
857
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
858
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
859
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
860
+ ns.marginal_lambda(t_prev_2),
861
+ ns.marginal_lambda(t_prev_1),
862
+ ns.marginal_lambda(t_prev_0),
863
+ ns.marginal_lambda(t),
864
+ )
865
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
866
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
867
+ alpha_t = torch.exp(log_alpha_t)
868
+
869
+ h_1 = lambda_prev_1 - lambda_prev_2
870
+ h_0 = lambda_prev_0 - lambda_prev_1
871
+ h = lambda_t - lambda_prev_0
872
+ r0, r1 = h_0 / h, h_1 / h
873
+ D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
874
+ D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2)
875
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
876
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
877
+ if self.algorithm_type == "dpmsolver++":
878
+ phi_1 = torch.expm1(-h)
879
+ phi_2 = phi_1 / h + 1.0
880
+ phi_3 = phi_2 / h - 0.5
881
+ x_t = (
882
+ (sigma_t / sigma_prev_0) * x
883
+ - (alpha_t * phi_1) * model_prev_0
884
+ + (alpha_t * phi_2) * D1
885
+ - (alpha_t * phi_3) * D2
886
+ )
887
+ else:
888
+ phi_1 = torch.expm1(h)
889
+ phi_2 = phi_1 / h - 1.0
890
+ phi_3 = phi_2 / h - 0.5
891
+ x_t = (
892
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
893
+ - (sigma_t * phi_1) * model_prev_0
894
+ - (sigma_t * phi_2) * D1
895
+ - (sigma_t * phi_3) * D2
896
+ )
897
+ return x_t
898
+
899
+ def singlestep_dpm_solver_update(
900
+ self, x, s, t, order, return_intermediate=False, solver_type="dpmsolver", r1=None, r2=None
901
+ ):
902
+ """
903
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
904
+
905
+ Args:
906
+ x: A pytorch tensor. The initial value at time `s`.
907
+ s: A pytorch tensor. The starting time, with the shape (1,).
908
+ t: A pytorch tensor. The ending time, with the shape (1,).
909
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
910
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
911
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
912
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
913
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
914
+ r2: A `float`. The hyperparameter of the third-order solver.
915
+ Returns:
916
+ x_t: A pytorch tensor. The approximated solution at time `t`.
917
+ """
918
+ if order == 1:
919
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
920
+ elif order == 2:
921
+ return self.singlestep_dpm_solver_second_update(
922
+ x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1
923
+ )
924
+ elif order == 3:
925
+ return self.singlestep_dpm_solver_third_update(
926
+ x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2
927
+ )
928
+ else:
929
+ raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
930
+
931
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
932
+ """
933
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
934
+
935
+ Args:
936
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
937
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
938
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
939
+ t: A pytorch tensor. The ending time, with the shape (1,).
940
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
941
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
942
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
943
+ Returns:
944
+ x_t: A pytorch tensor. The approximated solution at time `t`.
945
+ """
946
+ if order == 1:
947
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
948
+ elif order == 2:
949
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
950
+ elif order == 3:
951
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
952
+ else:
953
+ raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
954
+
955
+ def dpm_solver_adaptive(
956
+ self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpmsolver"
957
+ ):
958
+ """
959
+ The adaptive step size solver based on singlestep DPM-Solver.
960
+
961
+ Args:
962
+ x: A pytorch tensor. The initial value at time `t_T`.
963
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
964
+ t_T: A `float`. The starting time of the sampling (default is T).
965
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
966
+ h_init: A `float`. The initial step size (for logSNR).
967
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
968
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
969
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
970
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
971
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
972
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
973
+ The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
974
+ Returns:
975
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
976
+
977
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
978
+ """
979
+ ns = self.noise_schedule
980
+ s = t_T * torch.ones((1,)).to(x)
981
+ lambda_s = ns.marginal_lambda(s)
982
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
983
+ h = h_init * torch.ones_like(s).to(x)
984
+ x_prev = x
985
+ nfe = 0
986
+ if order == 2:
987
+ r1 = 0.5
988
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
989
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
990
+ x, s, t, r1=r1, solver_type=solver_type, **kwargs
991
+ )
992
+ elif order == 3:
993
+ r1, r2 = 1.0 / 3.0, 2.0 / 3.0
994
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
995
+ x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
996
+ )
997
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
998
+ x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
999
+ )
1000
+ else:
1001
+ raise ValueError(f"For adaptive step size solver, order must be 2 or 3, got {order}")
1002
+ while torch.abs(s - t_0).mean() > t_err:
1003
+ t = ns.inverse_lambda(lambda_s + h)
1004
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
1005
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
1006
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
1007
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
1008
+ E = norm_fn((x_higher - x_lower) / delta).max()
1009
+ if torch.all(E <= 1.0):
1010
+ x = x_higher
1011
+ s = t
1012
+ x_prev = x_lower
1013
+ lambda_s = ns.marginal_lambda(s)
1014
+ h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s)
1015
+ nfe += order
1016
+ print("adaptive solver nfe", nfe)
1017
+ return x
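The accept/reject test and step-size proposal used by the adaptive solver above can be sketched without tensors. This is a simplified illustration for a single flattened sample (the code above additionally takes the maximum over the batch); the names `error_norm` and `next_step_size` are hypothetical, and the explicit zero-error guard is an assumption added because plain Python raises on `0.0 ** negative`, whereas the tensor code yields `inf` and falls back to the remaining distance.

```python
import math


def error_norm(x_higher, x_lower, x_prev, atol=0.0078, rtol=0.05):
    """Scaled RMS difference between the higher- and lower-order estimates."""
    scaled = []
    for hi, lo, prev in zip(x_higher, x_lower, x_prev):
        # Per-element tolerance: absolute floor, or relative to the larger magnitude.
        delta = max(atol, rtol * max(abs(lo), abs(prev)))
        scaled.append((hi - lo) / delta)
    return math.sqrt(sum(v * v for v in scaled) / len(scaled))


def next_step_size(h, error, order, remaining, theta=0.9):
    """Propose the next logSNR step, capped by the remaining distance lambda_0 - lambda_s."""
    if error == 0.0:
        # Assumed guard: with zero error the proposal is unbounded, so take the cap.
        return remaining
    return min(theta * h * error ** (-1.0 / order), remaining)


# A step is accepted when the scaled error is at most 1.
accepted = error_norm([1.01, 2.0], [1.0, 2.0], [1.0, 2.0]) <= 1.0
```

As in the loop above, a new step size is proposed after every trial step, whether or not the step was accepted.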
1018
+
1019
+ def add_noise(self, x, t, noise=None):
1020
+ """
1021
+ Compute the noised input xt = alpha_t * x + sigma_t * noise.
1022
+
1023
+ Args:
1024
+ x: A `torch.Tensor` with shape `(batch_size, *shape)`.
1025
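+ t: A `torch.Tensor` with shape `(t_size,)`.
+ noise: An optional `torch.Tensor` with shape `(t_size, batch_size, *shape)`. If None, standard Gaussian noise is sampled.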
1026
+ Returns:
1027
+ xt with shape `(t_size, batch_size, *shape)`; the leading dimension is squeezed when `t_size` == 1.
1028
+ """
1029
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
1030
+ if noise is None:
1031
+ noise = torch.randn((t.shape[0], *x.shape), device=x.device)
1032
+ x = x.reshape((-1, *x.shape))
1033
+ xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
1034
+ if t.shape[0] == 1:
1035
+ return xt.squeeze(0)
1036
+ else:
1037
+ return xt
1038
+
1039
+ def inverse(
1040
+ self,
1041
+ x,
1042
+ steps=20,
1043
+ t_start=None,
1044
+ t_end=None,
1045
+ order=2,
1046
+ skip_type="time_uniform",
1047
+ method="multistep",
1048
+ lower_order_final=True,
1049
+ denoise_to_zero=False,
1050
+ solver_type="dpmsolver",
1051
+ atol=0.0078,
1052
+ rtol=0.05,
1053
+ return_intermediate=False,
1054
+ ):
1055
+ """
1056
+ Invert the sample `x` from time `t_start` to `t_end` by DPM-Solver.
1057
+ For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
1058
+ """
1059
+ t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
1060
+ t_T = self.noise_schedule.T if t_end is None else t_end
1061
+ assert (
1062
+ t_0 > 0 and t_T > 0
1063
+ ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1064
+ return self.sample(
1065
+ x,
1066
+ steps=steps,
1067
+ t_start=t_0,
1068
+ t_end=t_T,
1069
+ order=order,
1070
+ skip_type=skip_type,
1071
+ method=method,
1072
+ lower_order_final=lower_order_final,
1073
+ denoise_to_zero=denoise_to_zero,
1074
+ solver_type=solver_type,
1075
+ atol=atol,
1076
+ rtol=rtol,
1077
+ return_intermediate=return_intermediate,
1078
+ )
1079
+
1080
+ def sample(
1081
+ self,
1082
+ x,
1083
+ steps=20,
1084
+ t_start=None,
1085
+ t_end=None,
1086
+ order=2,
1087
+ skip_type="time_uniform",
1088
+ method="multistep",
1089
+ lower_order_final=True,
1090
+ denoise_to_zero=False,
1091
+ solver_type="dpmsolver",
1092
+ atol=0.0078,
1093
+ rtol=0.05,
1094
+ return_intermediate=False,
1095
+ flow_shift=1.0,
1096
+ ):
1097
+ """
1098
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
1099
+
1100
+ =====================================================
1101
+
1102
+ We support the following algorithms for both noise prediction model and data prediction model:
1103
+ - 'singlestep':
1104
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
1105
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
1106
+ The total number of function evaluations (NFE) == `steps`.
1107
+ Given a fixed NFE == `steps`, the sampling procedure is:
1108
+ - If `order` == 1:
1109
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
1110
+ - If `order` == 2:
1111
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
1112
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
1113
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1114
+ - If `order` == 3:
1115
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
1116
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1117
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
1118
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
1119
+ - 'multistep':
1120
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
1121
+ We initialize the first `order` values by lower order multistep solvers.
1122
+ Given a fixed NFE == `steps`, the sampling procedure is:
1123
+ Denote K = steps.
1124
+ - If `order` == 1:
1125
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
1126
+ - If `order` == 2:
1127
+ - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
1128
+ - If `order` == 3:
1129
+ - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
1130
+ - 'singlestep_fixed':
1131
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1132
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1133
+ - 'adaptive':
1134
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1135
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1136
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
1137
+ (NFE) and the sample quality.
1138
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1139
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1140
+
1141
+ =====================================================
1142
+
1143
+ Some advice on choosing the algorithm:
1144
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1145
+ Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
1146
+ e.g., DPM-Solver:
1147
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
1148
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1149
+ skip_type='time_uniform', method='singlestep')
1150
+ e.g., DPM-Solver++:
1151
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1152
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1153
+ skip_type='time_uniform', method='singlestep')
1154
+ - For **guided sampling with large guidance scale** by DPMs:
1155
+ Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
1156
+ e.g.
1157
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1158
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1159
+ skip_type='time_uniform', method='multistep')
1160
+
1161
+ We support three types of `skip_type`:
1162
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
1163
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
1164
+ - 'time_quadratic': quadratic time for the time steps.
1165
+
1166
+ =====================================================
1167
+ Args:
1168
+ x: A pytorch tensor. The initial value at time `t_start`
1169
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1170
+ steps: A `int`. The total number of function evaluations (NFE).
1171
+ t_start: A `float`. The starting time of the sampling.
1172
+ If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
1173
+ t_end: A `float`. The ending time of the sampling.
1174
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1175
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1176
+ For discrete-time DPMs:
1177
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1178
+ For continuous-time DPMs:
1179
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1180
+ order: A `int`. The order of DPM-Solver.
1181
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1182
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1183
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1184
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1185
+
1186
+ This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+ score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID
+ when sampling by diffusion SDEs for low-resolution images
+ (such as CIFAR-10). However, we observed that it does not matter for
+ high-resolution images. As it needs an additional NFE, we do not recommend
+ it for high-resolution images.
1192
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1193
+ Only valid for `method=multistep`; in this version it applies for any `steps`. We empirically find that
1194
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1195
+ (especially for steps <= 10). So we recommend setting it to `True`.
1196
+ solver_type: A `str`. The Taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
1197
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1198
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1199
+ return_intermediate: A `bool`. Whether to save the xt at each step.
1200
+ When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0.
1201
+ Returns:
1202
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1203
+
1204
+ """
1205
+ t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
1206
+ t_T = self.noise_schedule.T if t_start is None else t_start
1207
+ assert (
1208
+ t_0 > 0 and t_T > 0
1209
+ ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1210
+ if return_intermediate:
1211
+ assert method in [
1212
+ "multistep",
1213
+ "singlestep",
1214
+ "singlestep_fixed",
1215
+ ], "Cannot use adaptive solver when saving intermediate values"
1216
+ if self.correcting_xt_fn is not None:
1217
+ assert method in [
1218
+ "multistep",
1219
+ "singlestep",
1220
+ "singlestep_fixed",
1221
+ ], "Cannot use adaptive solver when correcting_xt_fn is not None"
1222
+ device = x.device
1223
+ intermediates = []
1224
+ with torch.no_grad():
1225
+ if method == "adaptive":
1226
+ x = self.dpm_solver_adaptive(
1227
+ x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type
1228
+ )
1229
+ elif method == "multistep":
1230
+ assert steps >= order
1231
+ timesteps = self.get_time_steps(
1232
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device, shift=flow_shift
1233
+ )
1234
+ assert timesteps.shape[0] - 1 == steps
1235
+ # Init the initial values.
1236
+ step = 0
1237
+ t = timesteps[step]
1238
+ t_prev_list = [t]
1239
+ model_prev_list = [self.model_fn(x, t)]
1240
+ if self.correcting_xt_fn is not None:
1241
+ x = self.correcting_xt_fn(x, t, step)
1242
+ if return_intermediate:
1243
+ intermediates.append(x)
1244
+ self.update_progress(step + 1, len(timesteps))
1245
+ # Init the first `order` values by lower order multistep DPM-Solver.
1246
+ for step in range(1, order):
1247
+ t = timesteps[step]
1248
+ x = self.multistep_dpm_solver_update(
1249
+ x, model_prev_list, t_prev_list, t, step, solver_type=solver_type
1250
+ )
1251
+ if self.correcting_xt_fn is not None:
1252
+ x = self.correcting_xt_fn(x, t, step)
1253
+ if return_intermediate:
1254
+ intermediates.append(x)
1255
+ t_prev_list.append(t)
1256
+ model_prev_list.append(self.model_fn(x, t))
1257
+ # update progress bar
1258
+ self.update_progress(step + 1, len(timesteps))
1259
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1260
+ for step in tqdm(range(order, steps + 1), disable=os.getenv("DPM_TQDM", "False") == "True"):
1261
+ t = timesteps[step]
1262
+ # Lower the order over the final steps whenever `lower_order_final` is set; the disabled condition below also required steps < 10.
1263
+ # if lower_order_final and steps < 10:
1264
+ if lower_order_final: # recommended by Shuchen Xue
1265
+ step_order = min(order, steps + 1 - step)
1266
+ else:
1267
+ step_order = order
1268
+ x = self.multistep_dpm_solver_update(
1269
+ x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type
1270
+ )
1271
+ if self.correcting_xt_fn is not None:
1272
+ x = self.correcting_xt_fn(x, t, step)
1273
+ if return_intermediate:
1274
+ intermediates.append(x)
1275
+ for i in range(order - 1):
1276
+ t_prev_list[i] = t_prev_list[i + 1]
1277
+ model_prev_list[i] = model_prev_list[i + 1]
1278
+ t_prev_list[-1] = t
1279
+ # We do not need to evaluate the final model value.
1280
+ if step < steps:
1281
+ model_prev_list[-1] = self.model_fn(x, t)
1282
+ # update progress bar
1283
+ self.update_progress(step + 1, len(timesteps))
1284
+ elif method in ["singlestep", "singlestep_fixed"]:
1285
+ if method == "singlestep":
1286
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(
1287
+ steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device
1288
+ )
1289
+ elif method == "singlestep_fixed":
1290
+ K = steps // order
1291
+ orders = [
1292
+ order,
1293
+ ] * K
1294
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1295
+ for step, order in enumerate(orders):
1296
+ s, t = timesteps_outer[step], timesteps_outer[step + 1]
1297
+ timesteps_inner = self.get_time_steps(
1298
+ skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device
1299
+ )
1300
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1301
+ h = lambda_inner[-1] - lambda_inner[0]
1302
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1303
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1304
+ x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
1305
+ if self.correcting_xt_fn is not None:
1306
+ x = self.correcting_xt_fn(x, t, step)
1307
+ if return_intermediate:
1308
+ intermediates.append(x)
1309
+ self.update_progress(step + 1, len(timesteps_outer))
1310
+ else:
1311
+ raise ValueError(f"Got wrong method {method}")
1312
+ if denoise_to_zero:
1313
+ t = torch.ones((1,)).to(device) * t_0
1314
+ x = self.denoise_to_zero_fn(x, t)
1315
+ if self.correcting_xt_fn is not None:
1316
+ x = self.correcting_xt_fn(x, t, step + 1)
1317
+ if return_intermediate:
1318
+ intermediates.append(x)
1319
+ if return_intermediate:
1320
+ return x, intermediates
1321
+ else:
1322
+ return x
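The order schedules described in the `sample` docstring can be written out explicitly. The sketch below is an illustration based only on that docstring and on the multistep loop above; `singlestep_orders` and `multistep_orders` are hypothetical helper names, not functions defined in this file.

```python
def singlestep_orders(steps, order):
    """Per-step solver orders for method='singlestep', as listed in the docstring."""
    if order == 1:
        return [1] * steps
    if order == 2:
        if steps % 2 == 0:
            return [2] * (steps // 2)
        return [2] * (steps // 2) + [1]
    if order == 3:
        k = steps // 3 + 1
        if steps % 3 == 0:
            return [3] * (k - 2) + [2, 1]
        if steps % 3 == 1:
            return [3] * (k - 1) + [1]
        return [3] * (k - 1) + [2]
    raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")


def multistep_orders(steps, order, lower_order_final=True):
    """Per-step solver orders for method='multistep' (requires steps >= order)."""
    # Warm-up: the first (order - 1) steps use orders 1, 2, ..., order - 1.
    orders = list(range(1, order))
    for step in range(order, steps + 1):
        if lower_order_final:
            orders.append(min(order, steps + 1 - step))
        else:
            orders.append(order)
    return orders
```

For example, `multistep_orders(5, 3)` gives `[1, 2, 3, 2, 1]`: the order ramps up during warm-up and back down over the final steps. In the singlestep case the orders always sum to `steps`, which is how the NFE budget is used up exactly.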
1323
+
1324
+
1325
+ #############################################################
1326
+ # other utility functions
1327
+ #############################################################
1328
+
1329
+
1330
+ def interpolate_fn(x, xp, yp):
1331
+ """
1332
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1333
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1334
+ The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
1335
+
1336
+ Args:
1337
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1338
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1339
+ yp: PyTorch tensor with shape [C, K].
1340
+ Returns:
1341
+ The function values f(x), with shape [N, C].
1342
+ """
1343
+ N, K = x.shape[0], xp.shape[1]
1344
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1345
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1346
+ x_idx = torch.argmin(x_indices, dim=2)
1347
+ cand_start_idx = x_idx - 1
1348
+ start_idx = torch.where(
1349
+ torch.eq(x_idx, 0),
1350
+ torch.tensor(1, device=x.device),
1351
+ torch.where(
1352
+ torch.eq(x_idx, K),
1353
+ torch.tensor(K - 2, device=x.device),
1354
+ cand_start_idx,
1355
+ ),
1356
+ )
1357
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1358
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1359
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1360
+ start_idx2 = torch.where(
1361
+ torch.eq(x_idx, 0),
1362
+ torch.tensor(0, device=x.device),
1363
+ torch.where(
1364
+ torch.eq(x_idx, K),
1365
+ torch.tensor(K - 2, device=x.device),
1366
+ cand_start_idx,
1367
+ ),
1368
+ )
1369
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1370
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1371
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1372
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1373
+ return cand
1374
+
1375
+
1376
+ def expand_dims(v, dims):
1377
+ """
1378
+ Expand the tensor `v` to the dim `dims`.
1379
+
1380
+ Args:
1381
+ `v`: a PyTorch tensor with shape [N].
1382
+ `dim`: a `int`.
1383
+ Returns:
1384
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1385
+ """
1386
+ return v[(...,) + (None,) * (dims - 1)]
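The batched `interpolate_fn` above can be hard to follow because of the sort-and-gather indexing. The sketch below is a scalar, pure-Python reference for the same idea: piecewise linear interpolation that reuses the outermost segment for inputs outside the keypoint range. The name `interp_scalar` is hypothetical and not part of this commit, and the sketch assumes `xp` is sorted in strictly increasing order.

```python
from bisect import bisect_right


def interp_scalar(x, xp, yp):
    """Scalar reference for piecewise linear interpolation with linear extrapolation.

    Hypothetical helper for illustration only; assumes xp is strictly increasing
    and has at least two keypoints.
    """
    k = len(xp)
    # Index of the left keypoint of the segment containing x, clamped to
    # [0, k - 2] so that out-of-range inputs reuse the first or last segment.
    i = min(max(bisect_right(xp, x) - 1, 0), k - 2)
    x0, x1 = xp[i], xp[i + 1]
    y0, y1 = yp[i], yp[i + 1]
    return y0 + (x - x0) * (y1 - y0) / (x1 - x0)


xp = [0.0, 1.0, 2.0]
yp = [0.0, 10.0, 40.0]
inside = interp_scalar(0.5, xp, yp)   # on the first segment
above = interp_scalar(3.0, xp, yp)    # extrapolated along the last segment
below = interp_scalar(-1.0, xp, yp)   # extrapolated along the first segment
```

Comparing the tensor version against such a scalar reference on a few points is one simple way to sanity-check the indexing logic.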
transport/integrators.py ADDED
@@ -0,0 +1,122 @@
+ import torch as th
+ from torchdiffeq import odeint
+ from .utils import time_shift, get_lin_function
+
+ class sde:
+     """SDE solver class"""
+
+     def __init__(
+         self,
+         drift,
+         diffusion,
+         *,
+         t0,
+         t1,
+         num_steps,
+         sampler_type,
+     ):
+         assert t0 < t1, "SDE sampler has to be in forward time"
+
+         self.num_timesteps = num_steps
+         self.t = th.linspace(t0, t1, num_steps)
+         self.dt = self.t[1] - self.t[0]
+         self.drift = drift
+         self.diffusion = diffusion
+         self.sampler_type = sampler_type
+
+     def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
+         w_cur = th.randn(x.size()).to(x)
+         t = th.ones(x.size(0)).to(x) * t
+         dw = w_cur * th.sqrt(self.dt)
+         drift = self.drift(x, t, model, **model_kwargs)
+         diffusion = self.diffusion(x, t)
+         mean_x = x + drift * self.dt
+         x = mean_x + th.sqrt(2 * diffusion) * dw
+         return x, mean_x
+
+     def __Heun_step(self, x, _, t, model, **model_kwargs):
+         w_cur = th.randn(x.size()).to(x)
+         dw = w_cur * th.sqrt(self.dt)
+         t_cur = th.ones(x.size(0)).to(x) * t
+         diffusion = self.diffusion(x, t_cur)
+         xhat = x + th.sqrt(2 * diffusion) * dw
+         K1 = self.drift(xhat, t_cur, model, **model_kwargs)
+         xp = xhat + self.dt * K1
+         K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
+         return (
+             xhat + 0.5 * self.dt * (K1 + K2),
+             xhat,
+         )  # at last time point we do not perform the heun step
+
+     def __forward_fn(self):
+         """TODO: generalize here by adding all private functions ending with steps to it"""
+         sampler_dict = {
+             "Euler": self.__Euler_Maruyama_step,
+             "Heun": self.__Heun_step,
+         }
+
+         try:
+             sampler = sampler_dict[self.sampler_type]
+         except KeyError:
+             raise NotImplementedError("Sampler type not implemented.")
+
+         return sampler
+
+     def sample(self, init, model, **model_kwargs):
+         """forward loop of sde"""
+         x = init
+         mean_x = init
+         samples = []
+         sampler = self.__forward_fn()
+         for ti in self.t[:-1]:
+             with th.no_grad():
+                 x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
+                 samples.append(x)
+
+         return samples
+
+
+ class ode:
+     """ODE solver class"""
+
+     def __init__(
+         self,
+         drift,
+         *,
+         t0,
+         t1,
+         sampler_type,
+         num_steps,
+         atol,
+         rtol,
+         do_shift=False,
+         time_shifting_factor=None,
+     ):
+         assert t0 < t1, "ODE sampler has to be in forward time"
+
+         self.drift = drift
+         self.do_shift = do_shift
+         self.t = th.linspace(t0, t1, num_steps)
+         if time_shifting_factor:
+             self.t = self.t / (self.t + time_shifting_factor - time_shifting_factor * self.t)
+         self.atol = atol
+         self.rtol = rtol
+         self.sampler_type = sampler_type
+
+     def sample(self, x, model, **model_kwargs):
+         # `x` may be a tuple of tensors (e.g. state and log-density for likelihood evaluation),
+         # so cast each element instead of calling `.float()` on the tuple itself
+         x = tuple(xi.float() for xi in x) if isinstance(x, tuple) else x.float()
+         device = x[0].device if isinstance(x, tuple) else x.device
+
+         def _fn(t, x):
+             t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
+             model_output = self.drift(x, t, model, **model_kwargs)
+             if isinstance(model_output, tuple):
+                 return tuple(out.float() for out in model_output)
+             return model_output.float()
+
+         t = self.t.to(device)
+         if self.do_shift:
+             mu = get_lin_function(y1=0.5, y2=1.15)(x.shape[1])
+             t = time_shift(mu, 1.0, t)
+         atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
+         rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
+         samples = odeint(_fn, x, t, method=self.sampler_type, atol=atol, rtol=rtol)
+         return samples
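To make the update rule in `sde.__Euler_Maruyama_step` concrete, here is a minimal scalar sketch in pure Python. It follows the same convention as the class above: the noise increment is `sqrt(dt) * N(0, 1)` and the diffusion coefficient enters as `sqrt(2 * diffusion)`. The function name `euler_maruyama` and the callable signatures `drift(x, t)` and `diffusion(x, t)` are assumptions for illustration, not the API of this commit.

```python
import math
import random


def euler_maruyama(x0, drift, diffusion, t0, t1, num_steps, rng):
    """Scalar Euler-Maruyama integration (illustrative sketch only).

    Mirrors th.linspace(t0, t1, num_steps): num_steps grid points give
    num_steps - 1 integration intervals of equal width dt.
    """
    dt = (t1 - t0) / (num_steps - 1)
    x, t = x0, t0
    path = []
    for _ in range(num_steps - 1):
        dw = rng.gauss(0.0, 1.0) * math.sqrt(dt)              # Brownian increment
        mean_x = x + drift(x, t) * dt                         # deterministic part of the step
        x = mean_x + math.sqrt(2.0 * diffusion(x, t)) * dw    # add scaled noise
        t += dt
        path.append(x)
    return path


# With zero diffusion the scheme reduces to explicit Euler for dx/dt = -x.
path = euler_maruyama(
    1.0, lambda x, t: -x, lambda x, t: 0.0, 0.0, 1.0, 11, random.Random(0)
)
```

With `num_steps` grid points the loop produces `num_steps - 1` states, which is why `Sampler.sample_sde` appends one extra last-step result before asserting the list length.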
transport/path.py ADDED
@@ -0,0 +1,201 @@
+ import numpy as np
+ import torch as th
+
+
+ def expand_t_like_x(t, x):
+     """Function to reshape time t to broadcastable dimension of x
+     Args:
+       t: [batch_dim,], time vector
+       x: [batch_dim,...], data point
+     """
+     dims = [1] * len(x[0].size())
+     t = t.view(t.size(0), *dims)
+     return t
+
+
+ #################### Coupling Plans ####################
+
+
+ class ICPlan:
+     """Linear Coupling Plan"""
+
+     def __init__(self, sigma=0.0):
+         self.sigma = sigma
+
+     def compute_alpha_t(self, t):
+         """Compute the data coefficient along the path"""
+         return t, 1
+
+     def compute_sigma_t(self, t):
+         """Compute the noise coefficient along the path"""
+         return 1 - t, -1
+
+     def compute_d_alpha_alpha_ratio_t(self, t):
+         """Compute the ratio between d_alpha and alpha"""
+         return 1 / t
+
+     def compute_drift(self, x, t):
+         """We always output sde according to score parametrization;"""
+         t = expand_t_like_x(t, x)
+         alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
+         sigma_t, d_sigma_t = self.compute_sigma_t(t)
+         drift = alpha_ratio * x
+         diffusion = alpha_ratio * (sigma_t**2) - sigma_t * d_sigma_t
+
+         return -drift, diffusion
+
+     def compute_diffusion(self, x, t, form="constant", norm=1.0):
+         """Compute the diffusion term of the SDE
+         Args:
+           x: [batch_dim, ...], data point
+           t: [batch_dim,], time vector
+           form: str, form of the diffusion term
+           norm: float, norm of the diffusion term
+         """
+         t = expand_t_like_x(t, x)
+         choices = {
+             "constant": norm,
+             "SBDM": norm * self.compute_drift(x, t)[1],
+             "sigma": norm * self.compute_sigma_t(t)[0],
+             "linear": norm * (1 - t),
+             "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
+             "increasing-decreasing": norm * th.sin(np.pi * t) ** 2,
+         }
+
+         try:
+             diffusion = choices[form]
+         except KeyError:
+             raise NotImplementedError(f"Diffusion form {form} not implemented")
+
+         return diffusion
+
+     def get_score_from_velocity(self, velocity, x, t):
+         """Wrapper function: transform velocity prediction model to score
+         Args:
+             velocity: [batch_dim, ...] shaped tensor; velocity model output
+             x: [batch_dim, ...] shaped tensor; x_t data point
+             t: [batch_dim,] time tensor
+         """
+         t = expand_t_like_x(t, x)
+         alpha_t, d_alpha_t = self.compute_alpha_t(t)
+         sigma_t, d_sigma_t = self.compute_sigma_t(t)
+         mean = x
+         reverse_alpha_ratio = alpha_t / d_alpha_t
+         var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
+         score = (reverse_alpha_ratio * velocity - mean) / var
+         return score
+
+     def get_noise_from_velocity(self, velocity, x, t):
+         """Wrapper function: transform velocity prediction model to denoiser
+         Args:
+             velocity: [batch_dim, ...] shaped tensor; velocity model output
+             x: [batch_dim, ...] shaped tensor; x_t data point
+             t: [batch_dim,] time tensor
+         """
+         t = expand_t_like_x(t, x)
+         alpha_t, d_alpha_t = self.compute_alpha_t(t)
+         sigma_t, d_sigma_t = self.compute_sigma_t(t)
+         mean = x
+         reverse_alpha_ratio = alpha_t / d_alpha_t
+         var = reverse_alpha_ratio * d_sigma_t - sigma_t
+         noise = (reverse_alpha_ratio * velocity - mean) / var
+         return noise
+
+     def get_velocity_from_score(self, score, x, t):
+         """Wrapper function: transform score prediction model to velocity
+         Args:
+             score: [batch_dim, ...] shaped tensor; score model output
+             x: [batch_dim, ...] shaped tensor; x_t data point
+             t: [batch_dim,] time tensor
+         """
+         t = expand_t_like_x(t, x)
+         drift, var = self.compute_drift(x, t)
+         velocity = var * score - drift
+         return velocity
+
+     def compute_mu_t(self, t, x0, x1):
+         """Compute the mean of time-dependent density p_t"""
+         t = expand_t_like_x(t, x1)
+         alpha_t, _ = self.compute_alpha_t(t)
+         sigma_t, _ = self.compute_sigma_t(t)
+         if isinstance(x1, (list, tuple)):
+             return [alpha_t[i] * x1[i] + sigma_t[i] * x0[i] for i in range(len(x1))]
+         else:
+             return alpha_t * x1 + sigma_t * x0
+
+     def compute_xt(self, t, x0, x1):
+         """Compute xt on the path between x0 and x1 at time t (deterministic given t, x0 and x1)"""
+         xt = self.compute_mu_t(t, x0, x1)
+         return xt
+
+     def compute_ut(self, t, x0, x1, xt):
+         """Compute the vector field corresponding to p_t"""
+         t = expand_t_like_x(t, x1)
+         _, d_alpha_t = self.compute_alpha_t(t)
+         _, d_sigma_t = self.compute_sigma_t(t)
+         if isinstance(x1, (list, tuple)):
+             return [d_alpha_t * x1[i] + d_sigma_t * x0[i] for i in range(len(x1))]
+         else:
+             return d_alpha_t * x1 + d_sigma_t * x0
+
+     def plan(self, t, x0, x1):
+         xt = self.compute_xt(t, x0, x1)
+         ut = self.compute_ut(t, x0, x1, xt)
+         return t, xt, ut
+
+
+ class VPCPlan(ICPlan):
+     """class for VP path flow matching"""
+
+     def __init__(self, sigma_min=0.1, sigma_max=20.0):
+         self.sigma_min = sigma_min
+         self.sigma_max = sigma_max
+         self.log_mean_coeff = (
+             lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
+         )
+         self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
+
+     def compute_alpha_t(self, t):
+         """Compute coefficient of x1"""
+         alpha_t = self.log_mean_coeff(t)
+         alpha_t = th.exp(alpha_t)
+         d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
+         return alpha_t, d_alpha_t
+
+     def compute_sigma_t(self, t):
+         """Compute coefficient of x0"""
+         p_sigma_t = 2 * self.log_mean_coeff(t)
+         sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
+         d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
+         return sigma_t, d_sigma_t
+
+     def compute_d_alpha_alpha_ratio_t(self, t):
+         """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
+         return self.d_log_mean_coeff(t)
+
+     def compute_drift(self, x, t):
+         """Compute the drift term of the SDE"""
+         t = expand_t_like_x(t, x)
+         beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
+         return -0.5 * beta_t * x, beta_t / 2
+
+
+ class GVPCPlan(ICPlan):
+     def __init__(self, sigma=0.0):
+         super().__init__(sigma)
+
+     def compute_alpha_t(self, t):
+         """Compute coefficient of x1"""
+         alpha_t = th.sin(t * np.pi / 2)
+         d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
+         return alpha_t, d_alpha_t
+
+     def compute_sigma_t(self, t):
+         """Compute coefficient of x0"""
+         sigma_t = th.cos(t * np.pi / 2)
+         d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
+         return sigma_t, d_sigma_t
+
+     def compute_d_alpha_alpha_ratio_t(self, t):
+         """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
+         return np.pi / (2 * th.tan(t * np.pi / 2))
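The coupling plans above all follow the same pattern: x_t = alpha_t * x1 + sigma_t * x0 with target velocity u_t = d_alpha_t * x1 + d_sigma_t * x0, and `get_score_from_velocity` inverts that relation. The scalar sketch below restates the formulas from `ICPlan` and `GVPCPlan` in pure Python and checks that the recovered score equals -x0 / sigma_t, which is the score of the conditional Gaussian when x1 is held fixed. The function names are hypothetical and exist only for this illustration.

```python
import math


def linear_path(t):
    # (alpha_t, d_alpha_t), (sigma_t, d_sigma_t) as in ICPlan
    return (t, 1.0), (1.0 - t, -1.0)


def gvp_path(t):
    # Trigonometric coefficients as in GVPCPlan
    half_pi = 0.5 * math.pi
    alpha = (math.sin(half_pi * t), half_pi * math.cos(half_pi * t))
    sigma = (math.cos(half_pi * t), -half_pi * math.sin(half_pi * t))
    return alpha, sigma


def plan(path, t, x0, x1):
    """Return the interpolated point x_t and the target velocity u_t."""
    (a, da), (s, ds) = path(t)
    return a * x1 + s * x0, da * x1 + ds * x0


def score_from_velocity(path, velocity, xt, t):
    """Same algebra as ICPlan.get_score_from_velocity, for scalars."""
    (a, da), (s, ds) = path(t)
    ratio = a / da
    var = s**2 - ratio * ds * s
    return (ratio * velocity - xt) / var


x0, x1, t = 2.0, -1.0, 0.25                # noise sample, data sample, time
xt, ut = plan(linear_path, t, x0, x1)
score = score_from_velocity(linear_path, ut, xt, t)
```

The same check passes for `gvp_path`, which suggests the conversion depends only on the coefficient functions and not on the specific path.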
transport/transport.py ADDED
@@ -0,0 +1,490 @@
+ import enum
+ import math
+ from typing import Callable
+
+ import numpy as np
+ import torch as th
+
+ from . import path
+ from .integrators import ode, sde
+ from .utils import mean_flat, expand_dims
+ from .dpm_solver import NoiseScheduleFlow, model_wrapper, DPM_Solver
+
+
+ class ModelType(enum.Enum):
+     """
+     Which type of output the model predicts.
+     """
+
+     NOISE = enum.auto()  # the model predicts epsilon
+     SCORE = enum.auto()  # the model predicts \nabla \log p(x)
+     VELOCITY = enum.auto()  # the model predicts v(x)
+
+
+ class PathType(enum.Enum):
+     """
+     Which type of path to use.
+     """
+
+     LINEAR = enum.auto()
+     GVP = enum.auto()
+     VP = enum.auto()
+
+
+ class WeightType(enum.Enum):
+     """
+     Which type of weighting to use.
+     """
+
+     NONE = enum.auto()
+     VELOCITY = enum.auto()
+     LIKELIHOOD = enum.auto()
+
+
+ class Transport:
+     def __init__(self, *, model_type, path_type, loss_type, train_eps, sample_eps, snr_type, do_shift, seq_len):
+         path_options = {
+             PathType.LINEAR: path.ICPlan,
+             PathType.GVP: path.GVPCPlan,
+             PathType.VP: path.VPCPlan,
+         }
+
+         self.loss_type = loss_type
+         self.model_type = model_type
+         self.path_sampler = path_options[path_type]()
+         self.train_eps = train_eps
+         self.sample_eps = sample_eps
+
+         self.snr_type = snr_type
+         self.do_shift = do_shift
+         self.seq_len = seq_len
+
+     def prior_logp(self, z):
+         """
+         Standard multivariate normal prior
+         Assume z is batched
+         """
+         shape = th.tensor(z.size())
+         N = th.prod(shape[1:])
+         _fn = lambda x: -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0
+         return th.vmap(_fn)(z)
+
+     def check_interval(
+         self,
+         train_eps,
+         sample_eps,
+         *,
+         diffusion_form="SBDM",
+         sde=False,
+         reverse=False,
+         eval=False,
+         last_step_size=0.0,
+     ):
+         t0 = 0
+         t1 = 1
+         eps = train_eps if not eval else sample_eps
+         if type(self.path_sampler) in [path.VPCPlan]:
+             t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
+
+         elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and (
+             self.model_type != ModelType.VELOCITY or sde
+         ):  # avoid numerical issue by taking a first semi-implicit step
+             t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
+             t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
+
+         if reverse:
+             t0, t1 = 1 - t0, 1 - t1
+
+         return t0, t1
+
+     def sample(self, x1):
+         """Sampling x0 & t based on shape of x1 (if needed)
+         Args:
+           x1 - data point; [batch, *dim]
+         """
+         if isinstance(x1, (list, tuple)):
+             x0 = [th.randn_like(img_start) for img_start in x1]
+         else:
+             x0 = th.randn_like(x1)
+         t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
+
+         if self.snr_type.startswith("uniform"):
+             assert t0 == 0.0 and t1 == 1.0, "not implemented."
+             if "_" in self.snr_type:
+                 _, t0, t1 = self.snr_type.split("_")
+                 t0, t1 = float(t0), float(t1)
+             t = th.rand((len(x1),)) * (t1 - t0) + t0
+         elif self.snr_type == "lognorm":
+             u = th.normal(mean=0.0, std=1.0, size=(len(x1),))
+             t = 1 / (1 + th.exp(-u)) * (t1 - t0) + t0
+         else:
+             raise NotImplementedError("Not implemented snr_type %s" % self.snr_type)
+
+         if self.do_shift:
+             base_shift: float = 0.5
+             max_shift: float = 1.15
+             mu = self.get_lin_function(y1=base_shift, y2=max_shift)(self.seq_len)
+             t = self.time_shift(mu, 1.0, t)
+         t = t.to(x1[0])
+         return t, x0, x1
+
+     def time_shift(self, mu: float, sigma: float, t: th.Tensor):
+         # the following implementation was originally written for t=0: clean / t=1: noise
+         # Since we adopt the reverse, the 1-t operations are needed
+         t = 1 - t
+         t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+         t = 1 - t
+         return t
+
+     def get_lin_function(
+         self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+     ) -> Callable[[float], float]:
+         m = (y2 - y1) / (x2 - x1)
+         b = y1 - m * x1
+         return lambda x: m * x + b
+
+     def training_losses(self, model, x1, model_kwargs=None):
+         """Loss for training the score model
+         Args:
+         - model: backbone model; could be score, noise, or velocity
+         - x1: datapoint
+         - model_kwargs: additional arguments for the model
+         """
+         if model_kwargs is None:
+             model_kwargs = {}
+         t, x0, x1 = self.sample(x1)
+         t, xt, ut = self.path_sampler.plan(t, x0, x1)
+         if "cond" in model_kwargs:
+             conds = model_kwargs.pop("cond")
+             xt = [th.cat([x, cond], dim=0) if cond is not None else x for x, cond in zip(xt, conds)]
+         model_output = model(xt, t, **model_kwargs)
+         B = len(x0)
+
+         terms = {}
+         # terms['pred'] = model_output
+         if self.model_type == ModelType.VELOCITY:
+             if isinstance(x1, (list, tuple)):
+                 assert len(model_output) == len(ut) == len(x1)
+                 for i in range(B):
+                     assert (
+                         model_output[i].shape == ut[i].shape == x1[i].shape
+                     ), f"{model_output[i].shape} {ut[i].shape} {x1[i].shape}"
+                 terms["task_loss"] = th.stack(
+                     [((ut[i] - model_output[i]) ** 2).mean() for i in range(B)],
+                     dim=0,
+                 )
+             else:
+                 terms["task_loss"] = mean_flat(((model_output - ut) ** 2))
+         else:
+             raise NotImplementedError
+
+         terms["loss"] = terms["task_loss"]
+         terms["task_loss"] = terms["task_loss"].clone().detach()
+         terms["t"] = t
+         return terms
+
+     def get_drift(self):
+         """member function for obtaining the drift of the probability flow ODE"""
+
+         def score_ode(x, t, model, **model_kwargs):
+             drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
+             model_output = model(x, t, **model_kwargs)
+             return -drift_mean + drift_var * model_output  # by change of variable
+
+         def noise_ode(x, t, model, **model_kwargs):
+             drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
+             sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
+             model_output = model(x, t, **model_kwargs)
+             score = model_output / -sigma_t
+             return -drift_mean + drift_var * score
+
+         def velocity_ode(x, t, model, **model_kwargs):
+             model_output = model(x, t, **model_kwargs)
+             return model_output
+
+         if self.model_type == ModelType.NOISE:
+             drift_fn = noise_ode
+         elif self.model_type == ModelType.SCORE:
+             drift_fn = score_ode
+         else:
+             drift_fn = velocity_ode
+
+         def body_fn(x, t, model, **model_kwargs):
+             model_output = drift_fn(x, t, model, **model_kwargs)
+             assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
+             return model_output
+
+         return body_fn
+
+     def get_score(
+         self,
+     ):
+         """member function for obtaining score of
+         x_t = alpha_t * x + sigma_t * eps"""
+         if self.model_type == ModelType.NOISE:
+             score_fn = (
+                 lambda x, t, model, **kwargs: model(x, t, **kwargs)
+                 / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
+             )
+         elif self.model_type == ModelType.SCORE:
+             score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
+         elif self.model_type == ModelType.VELOCITY:
+             score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(
+                 model(x, t, **kwargs), x, t
+             )
+         else:
+             raise NotImplementedError()
+
+         return score_fn
+
+
+ class Sampler:
+     """Sampler class for the transport model"""
+
+     def __init__(
+         self,
+         transport,
+     ):
+         """Constructor for a general sampler; supporting different sampling methods
+         Args:
+         - transport: a Transport object specifying model prediction & interpolant type
+         """
+
+         self.transport = transport
+         self.drift = self.transport.get_drift()
+         self.score = self.transport.get_score()
+
+     def __get_sde_diffusion_and_drift(
+         self,
+         *,
+         diffusion_form="SBDM",
+         diffusion_norm=1.0,
+     ):
+         def diffusion_fn(x, t):
+             diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
+             return diffusion
+
+         sde_drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(
+             x, t, model, **kwargs
+         )
+
+         sde_diffusion = diffusion_fn
+
+         return sde_drift, sde_diffusion
+
+     def __get_last_step(
+         self,
+         sde_drift,
+         *,
+         last_step,
+         last_step_size,
+     ):
+         """Get the last step function of the SDE solver"""
+
+         if last_step is None:
+             last_step_fn = lambda x, t, model, **model_kwargs: x
+         elif last_step == "Mean":
+             last_step_fn = (
+                 lambda x, t, model, **model_kwargs: x + sde_drift(x, t, model, **model_kwargs) * last_step_size
+             )
+         elif last_step == "Tweedie":
+             alpha = self.transport.path_sampler.compute_alpha_t  # simple aliasing; the original name was too long
+             sigma = self.transport.path_sampler.compute_sigma_t
+             last_step_fn = lambda x, t, model, **model_kwargs: x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][
+                 0
+             ] * self.score(x, t, model, **model_kwargs)
+         elif last_step == "Euler":
+             last_step_fn = (
+                 lambda x, t, model, **model_kwargs: x + self.drift(x, t, model, **model_kwargs) * last_step_size
+             )
+         else:
+             raise NotImplementedError()
+
+         return last_step_fn
+
+     def sample_sde(
+         self,
+         *,
+         sampling_method="Euler",
+         diffusion_form="SBDM",
+         diffusion_norm=1.0,
+         last_step="Mean",
+         last_step_size=0.04,
+         num_steps=250,
+     ):
+         """returns a sampling function with given SDE settings
+         Args:
+         - sampling_method: type of sampler used in solving the SDE; defaults to Euler-Maruyama
+         - diffusion_form: function form of diffusion coefficient; defaults to matching SBDM
+         - diffusion_norm: function magnitude of diffusion coefficient; defaults to 1
+         - last_step: type of the last step; defaults to "Mean" (pass None for identity)
+         - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
+         - num_steps: total number of integration steps of the SDE
+         """
+
+         if last_step is None:
+             last_step_size = 0.0
+
+         sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
+             diffusion_form=diffusion_form,
+             diffusion_norm=diffusion_norm,
+         )
+
+         t0, t1 = self.transport.check_interval(
+             self.transport.train_eps,
+             self.transport.sample_eps,
+             diffusion_form=diffusion_form,
+             sde=True,
+             eval=True,
+             reverse=False,
+             last_step_size=last_step_size,
+         )
+
+         _sde = sde(
+             sde_drift,
+             sde_diffusion,
+             t0=t0,
+             t1=t1,
+             num_steps=num_steps,
+             sampler_type=sampling_method,
+         )
+
+         last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
+
+         def _sample(init, model, **model_kwargs):
+             xs = _sde.sample(init, model, **model_kwargs)
+             ts = th.ones(init.size(0), device=init.device) * t1
+             x = last_step_fn(xs[-1], ts, model, **model_kwargs)
+             xs.append(x)
+
+             assert len(xs) == num_steps, "Samples does not match the number of steps"
+
+             return xs
+
+         return _sample
+
+     def sample_dpm(
+         self,
+         model,
+         model_kwargs=None,
+     ):
+         # guard against the default `None`, which cannot be unpacked with `**`
+         if model_kwargs is None:
+             model_kwargs = {}
+
+         noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")
+
+         def noise_pred_fn(x, t_continuous):
+             output = model(x, 1 - t_continuous, **model_kwargs)
+             _, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+             try:
+                 noise = x - (1 - expand_dims(sigma_t, x.dim()).to(x)) * output
+             except Exception:
+                 # fall back to the first element when the model returns a tuple
+                 noise = x - (1 - expand_dims(sigma_t, x.dim()).to(x)) * output[0]
+             return noise
+
+         return DPM_Solver(noise_pred_fn, noise_schedule, algorithm_type="dpmsolver++").sample
+
+
+     def sample_ode(
+         self,
+         *,
+         sampling_method="dopri5",
+         num_steps=50,
+         atol=1e-6,
+         rtol=1e-3,
+         reverse=False,
+         do_shift=False,
+         time_shifting_factor=None,
+     ):
+         """returns a sampling function with given ODE settings
+         Args:
+         - sampling_method: type of sampler used in solving the ODE; defaults to Dopri5
+         - num_steps:
+             - fixed solver (Euler, Heun): the actual number of integration steps performed
+             - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
+         - atol: absolute error tolerance for the solver
+         - rtol: relative error tolerance for the solver
+         """
+
+         # for flux
+         drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs)
+
+         t0, t1 = self.transport.check_interval(
+             self.transport.train_eps,
+             self.transport.sample_eps,
+             sde=False,
+             eval=True,
+             reverse=reverse,
+             last_step_size=0.0,
+         )
+
+         _ode = ode(
+             drift=drift,
+             t0=t0,
+             t1=t1,
+             sampler_type=sampling_method,
+             num_steps=num_steps,
+             atol=atol,
+             rtol=rtol,
+             do_shift=do_shift,
+             time_shifting_factor=time_shifting_factor,
+         )
+
+         return _ode.sample
+
+     def sample_ode_likelihood(
+         self,
+         *,
+         sampling_method="dopri5",
+         num_steps=50,
+         atol=1e-6,
+         rtol=1e-3,
+     ):
+         """returns a sampling function for calculating likelihood with given ODE settings
+         Args:
+         - sampling_method: type of sampler used in solving the ODE; defaults to Dopri5
+         - num_steps:
+             - fixed solver (Euler, Heun): the actual number of integration steps performed
+             - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
+         - atol: absolute error tolerance for the solver
+         - rtol: relative error tolerance for the solver
+         """
+
+         def _likelihood_drift(x, t, model, **model_kwargs):
+             x, _ = x
+             eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
+             t = th.ones_like(t) * (1 - t)
+             with th.enable_grad():
+                 x.requires_grad = True
+                 grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
+                 logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
+                 drift = self.drift(x, t, model, **model_kwargs)
+             return (-drift, logp_grad)
+
+         t0, t1 = self.transport.check_interval(
+             self.transport.train_eps,
+             self.transport.sample_eps,
+             sde=False,
+             eval=True,
+             reverse=False,
+             last_step_size=0.0,
+         )
+
+         _ode = ode(
+             drift=_likelihood_drift,
+             t0=t0,
+             t1=t1,
+             sampler_type=sampling_method,
+             num_steps=num_steps,
+             atol=atol,
+             rtol=rtol,
+         )
+
+         def _sample_fn(x, model, **model_kwargs):
+             init_logp = th.zeros(x.size(0)).to(x)
+             input = (x, init_logp)
+             drift, delta_logp = _ode.sample(input, model, **model_kwargs)
+             drift, delta_logp = drift[-1], delta_logp[-1]
+             prior_logp = self.transport.prior_logp(drift)
+             logp = prior_logp - delta_logp
+             return logp, drift
+
+         return _sample_fn
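`Sampler.sample_ode_likelihood` estimates the divergence of the drift with random sign vectors (`eps` with entries in {-1, +1}) rather than computing the full Jacobian trace. The pure-Python sketch below illustrates that estimator on a linear map, where the divergence is simply the trace of the matrix. The names `hutchinson_trace` and `matvec` are hypothetical and introduced only for this illustration; the commit itself obtains the Jacobian-vector product through `th.autograd.grad`.

```python
import random


def hutchinson_trace(matvec, dim, num_probes, rng):
    """Estimate trace(A) as the average of eps^T (A eps) over random sign vectors.

    Illustrative sketch: `matvec(v)` is assumed to return A @ v as a list.
    """
    total = 0.0
    for _ in range(num_probes):
        eps = [rng.choice((-1.0, 1.0)) for _ in range(dim)]   # Rademacher probe
        a_eps = matvec(eps)
        total += sum(e * a for e, a in zip(eps, a_eps))
    return total / num_probes


diag = [1.0, 2.0, 3.0]


def diag_matvec(v):
    # Multiply by the diagonal matrix diag(1, 2, 3)
    return [d * x for d, x in zip(diag, v)]


estimate = hutchinson_trace(diag_matvec, 3, 4, random.Random(0))
```

For a diagonal matrix every probe returns the exact trace because each squared sign equals 1; for a general matrix the off-diagonal terms only cancel on average, so the estimate is noisy with a single probe per step, as used in `_likelihood_drift`.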
transport/utils.py ADDED
@@ -0,0 +1,56 @@
+ import torch as th
+ import math
+
+ class EasyDict:
+     def __init__(self, sub_dict):
+         for k, v in sub_dict.items():
+             setattr(self, k, v)
+
+     def __getitem__(self, key):
+         return getattr(self, key)
+
+
+ def mean_flat(x):
+     """
+     Take the mean over all non-batch dimensions.
+     """
+     return th.mean(x, dim=list(range(1, len(x.size()))))
+
+
+ def log_state(state):
+     result = []
+
+     sorted_state = dict(sorted(state.items()))
+     for key, value in sorted_state.items():
+         # Check if the value is an instance of a class
+         if "<object" in str(value) or "object at" in str(value):
+             result.append(f"{key}: [{value.__class__.__name__}]")
+         else:
+             result.append(f"{key}: {value}")
+
+     return "\n".join(result)
+
+ def time_shift(mu: float, sigma: float, t: th.Tensor):
+     # the following implementation was originally written for t=0: clean / t=1: noise
+     # Since we adopt the reverse, the 1-t operations are needed
+     t = 1 - t
+     t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+     t = 1 - t
+     return t
+
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
+     m = (y2 - y1) / (x2 - x1)
+     b = y1 - m * x1
+     return lambda x: m * x + b
+
+ def expand_dims(v, dims):
+     """
+     Expand the tensor `v` to `dims` dimensions.
+
+     Args:
+         `v`: a PyTorch tensor with shape [N].
+         `dims`: an `int`, the total number of dimensions of the result.
+     Returns:
+         a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+     """
+     return v[(...,) + (None,) * (dims - 1)]
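As a worked example of `get_lin_function` and `time_shift`, the sketch below evaluates both on plain floats. The `_scalar` names are hypothetical; the formulas are copied from the functions above, and the inputs are assumed to lie strictly inside (0, 1), since t = 1 would divide by zero after the `1 - t` flip.

```python
import math


def get_lin_function_scalar(x1=256, y1=0.5, x2=4096, y2=1.15):
    # Line through (x1, y1) and (x2, y2); maps sequence length to the shift mu
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b


def time_shift_scalar(mu, sigma, t):
    # Same formula as time_shift, applied to a float t in (0, 1)
    t = 1.0 - t
    t = math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0) ** sigma)
    return 1.0 - t


mu = get_lin_function_scalar()(1024)        # shift for a sequence length of 1024
grid = [0.25, 0.5, 0.75]
shifted = [time_shift_scalar(mu, 1.0, t) for t in grid]
```

With the default endpoints the shift grows linearly with sequence length, and a positive `mu` moves every timestep toward t = 0, which is the noise end in this codebase's convention. With `mu = 0` and `sigma = 1` the map is the identity.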