barakmeiri committed on
Commit
2dddb52
·
verified ·
1 Parent(s): 12c346b

Upload 4 files

Browse files
RealTimeEditingNotebook.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/config.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code is based on ReNoise https://github.com/garibida/ReNoise-Inversion
2
+
3
+ from dataclasses import dataclass
4
+
5
+
6
@dataclass
class RunConfig:
    """Plain container for the numeric settings of a run.

    All four fields have defaults, so ``RunConfig()`` is a valid
    configuration. How each value is consumed is decided by the caller;
    nothing in this class interprets them.
    """

    # Step count used for inference (defaults to 4).
    num_inference_steps: int = 4

    # Step count used for inversion (defaults to 100).
    num_inversion_steps: int = 100

    # Guidance scale (defaults to 0.0).
    guidance_scale: float = 0.0

    # NOTE(review): presumably a fraction in [0, 1] bounding the inversion
    # range -- confirm against the code that reads it.
    inversion_max_step: float = 1.0

    def __post_init__(self):
        """Dataclass post-init hook; deliberately performs no work."""
        return None
src/euler_scheduler.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code is based on ReNoise https://github.com/garibida/ReNoise-Inversion
2
+
3
+ from diffusers import EulerAncestralDiscreteScheduler
4
+ from diffusers.utils import BaseOutput
5
+ import torch
6
+ from typing import List, Optional, Tuple, Union
7
+ import numpy as np
8
+
9
+ from src.eunms import Epsilon_Update_Type
10
+
11
class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
    """
    Container returned by the scheduler's step-style functions.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample computed for the previous timestep, `(x_{t-1})`. Feed this back in as the next model
            input of the denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The denoised estimate `(x_{0})` derived from the model output at the current timestep. Optional;
            useful for previewing progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None
26
+
27
class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
    """
    Euler-ancestral scheduler whose per-step noise is read from a stored list.

    Instead of sampling fresh Gaussian noise at every step, each step uses
    `self.noise_list[self.step_index]`, which must be provided through
    `set_noise_list` before any stepping method is called. The class offers a
    forward step (`step`), a forward step that also rewrites the stored noise so
    that a given target sample is reached (`step_and_update_noise`), and an
    inverse step (`inv_step`).
    """

    def set_noise_list(self, noise_list):
        """Store the per-step noise tensors, indexed by scheduler step index."""
        self.noise_list = noise_list

    def get_noise_to_remove(self):
        """Return the stored noise of the current step scaled by that step's `sigma_up`."""
        sigma_up, _ = self._sigma_up_down()
        return self.noise_list[self.step_index] * sigma_up

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Re-synchronise the step index with `timestep`, then defer to the parent scaling.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`torch.FloatTensor`):
                The current timestep in the diffusion chain. `.view` is called on it, so a tensor is required.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample, as produced by `EulerAncestralDiscreteScheduler.scale_model_input`.
        """
        # The step index is re-derived from the timestep on every call rather than
        # relying on the counter kept by the previous step.
        self._init_step_index(timestep.view(1))
        return EulerAncestralDiscreteScheduler.scale_model_input(self, sample, timestep)

    # ------------------------------------------------------------------ helpers

    def _validate_and_sync_step(self, timestep):
        """Reject integer timesteps, warn if input was not scaled, sync the step index; return sigma."""
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            # Imported locally: this module defines no module-level `logger`, so the
            # previous bare `logger.warning(...)` raised NameError instead of warning.
            import logging

            logging.getLogger(__name__).warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        self._init_step_index(timestep.view(1))
        return self.sigmas[self.step_index]

    def _pred_x0(self, model_output, sample, sigma):
        """Compute the predicted original sample (x_0) for the configured prediction type."""
        prediction_type = self.config.prediction_type
        if prediction_type == "epsilon":
            return sample - sigma * model_output
        if prediction_type == "v_prediction":
            # * c_out + input * c_skip
            return model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        if prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
        )

    def _sigma_up_down(self):
        """Return `(sigma_up, sigma_down)` of the ancestral split for the current step index."""
        sigma_from = self.sigmas[self.step_index]
        sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
        return sigma_up, sigma_down

    def _finalize_step(self, prev_sample, pred_original_sample, model_output, return_dict):
        """Cast back to the model dtype, advance the step index and package the result."""
        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(
            prev_sample=prev_sample, pred_original_sample=pred_original_sample
        )

    # ----------------------------------------------------------- public stepping

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Advance one denoising step, adding the stored noise of this step instead of random noise.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model. It is used as the ODE derivative as-is.
            timestep (`torch.FloatTensor`):
                The current discrete timestep in the diffusion chain. Integer timesteps are rejected.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                Accepted for interface compatibility; not used, since the noise comes from `noise_list`.
            return_dict (`bool`):
                Whether or not to return an `EulerAncestralDiscreteSchedulerOutput` or a tuple.

        Returns:
            `EulerAncestralDiscreteSchedulerOutput` or `tuple`:
                If return_dict is `True`, an `EulerAncestralDiscreteSchedulerOutput` is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer / integer tensor, or the prediction type is unknown.
            NotImplementedError: If the configured prediction type is `sample`.
        """
        sigma = self._validate_and_sync_step(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = self._pred_x0(model_output, sample, sigma)

        sigma_up, sigma_down = self._sigma_up_down()

        # 2. Euler step; the derivative is the raw model output rather than
        #    (sample - pred_original_sample) / sigma.
        dt = sigma_down - sigma
        prev_sample = sample + model_output * dt

        # 3. Ancestral noise: taken from the stored list, not sampled.
        prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up

        return self._finalize_step(prev_sample, pred_original_sample, model_output, return_dict)

    def step_and_update_noise(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        expected_prev_sample: torch.FloatTensor,
        update_epsilon_type=Epsilon_Update_Type.OVERRIDE,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Advance one denoising step after adjusting this step's stored noise towards a target sample.

        When `sigma_up > 0`, the noise that would make the step land exactly on `expected_prev_sample` is
        computed. With `Epsilon_Update_Type.OVERRIDE` it replaces the stored noise; with any other value the
        stored noise is moved towards it by a fixed number of gradient steps.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model. It is used as the ODE derivative as-is.
            timestep (`torch.FloatTensor`):
                The current discrete timestep in the diffusion chain. Integer timesteps are rejected.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            expected_prev_sample (`torch.FloatTensor`):
                The sample this step should produce.
            update_epsilon_type:
                How the stored noise is updated; defaults to `Epsilon_Update_Type.OVERRIDE`.
            generator (`torch.Generator`, *optional*):
                Accepted for interface compatibility; not used.
            return_dict (`bool`):
                Whether or not to return an `EulerAncestralDiscreteSchedulerOutput` or a tuple.

        Returns:
            `EulerAncestralDiscreteSchedulerOutput` or `tuple`:
                If return_dict is `True`, an `EulerAncestralDiscreteSchedulerOutput` is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer / integer tensor, or the prediction type is unknown.
            NotImplementedError: If the configured prediction type is `sample`.
        """
        sigma = self._validate_and_sync_step(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = self._pred_x0(model_output, sample, sigma)

        sigma_up, sigma_down = self._sigma_up_down()

        # 2. Euler step; the derivative is the raw model output.
        dt = sigma_down - sigma
        prev_sample = sample + model_output * dt

        # 3. Update the stored noise. Skipped when sigma_up is zero because the
        #    required noise would involve a division by zero.
        if sigma_up > 0:
            req_noise = (expected_prev_sample - prev_sample) / sigma_up
            if update_epsilon_type == Epsilon_Update_Type.OVERRIDE:
                self.noise_list[self.step_index] = req_noise
            else:
                # 10 gradient-descent iterations on ||noise - req_noise|| with step
                # size 1.8. The gradient of an L2 norm is a unit-length direction,
                # so each iteration shifts the stored noise by 1.8 in L2 distance.
                for _ in range(10):
                    n = self.noise_list[self.step_index].detach().clone().requires_grad_(True)
                    loss = torch.norm(n - req_noise.detach())
                    loss.backward()
                    self.noise_list[self.step_index] -= n.grad.detach() * 1.8

        prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up

        return self._finalize_step(prev_sample, pred_original_sample, model_output, return_dict)

    def inv_step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Perform one inversion step: subtract both the Euler update and the stored noise contribution.

        Unlike `step`, `sigma_up` uses the absolute value of `sigma_from**2 - sigma_to**2` and `sigma_down`
        is `sigma_to**2 / sigma_from`.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model. It is used as the ODE derivative as-is.
            timestep (`torch.FloatTensor`):
                The current discrete timestep in the diffusion chain. Integer timesteps are rejected.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                Accepted for interface compatibility; not used.
            return_dict (`bool`):
                Whether or not to return an `EulerAncestralDiscreteSchedulerOutput` or a tuple.

        Returns:
            `EulerAncestralDiscreteSchedulerOutput` or `tuple`:
                If return_dict is `True`, an `EulerAncestralDiscreteSchedulerOutput` is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer / integer tensor, or the prediction type is unknown.
            NotImplementedError: If the configured prediction type is `sample`.
        """
        sigma = self._validate_and_sync_step(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = self._pred_x0(model_output, sample, sigma)

        sigma_from = self.sigmas[self.step_index]
        sigma_to = self.sigmas[self.step_index + 1]
        # abs() keeps the square root real when sigma_to exceeds sigma_from.
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
        sigma_down = sigma_to**2 / sigma_from

        # 2. Inverse Euler step; the derivative is the raw model output.
        dt = sigma_down - sigma
        prev_sample = sample - model_output * dt

        # 3. Remove the stored noise contribution of this step.
        prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up

        return self._finalize_step(prev_sample, pred_original_sample, model_output, return_dict)

    def get_all_sigmas(self) -> torch.FloatTensor:
        """
        Build the sigma table for every entry of `alphas_cumprod`.

        Returns:
            `torch.FloatTensor`:
                `sqrt((1 - alphas_cumprod) / alphas_cumprod)` in reversed order, followed by a trailing `0.0`,
                as a float32 tensor.
        """
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        return torch.from_numpy(sigmas)

    def add_noise_off_schedule(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Add `noise` scaled by the sigma looked up in the full sigma table (see `get_all_sigmas`).

        Args:
            original_samples (`torch.FloatTensor`):
                Samples to which noise is added.
            noise (`torch.FloatTensor`):
                Noise to add; multiplied by the selected sigma.
            timesteps (`torch.FloatTensor`):
                A single-element tensor (`.item()` is called on it) holding the timestep.

        Returns:
            `torch.FloatTensor`:
                `original_samples + noise * sigma`.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.get_all_sigmas()
        sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            timesteps = timesteps.to(original_samples.device)

        # NOTE(review): the hard-coded 1000 presumably mirrors the number of training
        # timesteps of the sigma table -- confirm before using another schedule length.
        step_indices = 1000 - int(timesteps.item())

        # Append trailing singleton axes so sigma broadcasts against the samples.
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples
417
+
418
+ # def update_noise_for_friendly_inversion(
419
+ # self,
420
+ # model_output: torch.FloatTensor,
421
+ # timestep: Union[float, torch.FloatTensor],
422
+ # z_t: torch.FloatTensor,
423
+ # z_tp1: torch.FloatTensor,
424
+ # return_dict: bool = True,
425
+ # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
426
+ # if (
427
+ # isinstance(timestep, int)
428
+ # or isinstance(timestep, torch.IntTensor)
429
+ # or isinstance(timestep, torch.LongTensor)
430
+ # ):
431
+ # raise ValueError(
432
+ # (
433
+ # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
434
+ # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
435
+ # " one of the `scheduler.timesteps` as a timestep."
436
+ # ),
437
+ # )
438
+
439
+ # if not self.is_scale_input_called:
440
+ # logger.warning(
441
+ # "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
442
+ # "See `StableDiffusionPipeline` for a usage example."
443
+ # )
444
+
445
+ # self._init_step_index(timestep.view((1)))
446
+
447
+ # sigma = self.sigmas[self.step_index]
448
+
449
+ # sigma_from = self.sigmas[self.step_index]
450
+ # sigma_to = self.sigmas[self.step_index+1]
451
+ # # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
452
+ # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
453
+ # # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
454
+ # sigma_down = sigma_to**2 / sigma_from
455
+
456
+ # # 2. Conv = (sample - pred_original_sample) / sigma
457
+ # derivative = model_output
458
+
459
+ # dt = sigma_down - sigma
460
+ # # dt = sigma_down - sigma_from
461
+
462
+ # prev_sample = z_t - derivative * dt
463
+
464
+ # if sigma_up > 0:
465
+ # self.noise_list[self.step_index] = (prev_sample - z_tp1) / sigma_up
466
+
467
+ # prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up
468
+
469
+
470
+ # if not return_dict:
471
+ # return (prev_sample,)
472
+
473
+ # return EulerAncestralDiscreteSchedulerOutput(
474
+ # prev_sample=prev_sample, pred_original_sample=None
475
+ # )
476
+
477
+
478
+ # def step_friendly_inversion(
479
+ # self,
480
+ # model_output: torch.FloatTensor,
481
+ # timestep: Union[float, torch.FloatTensor],
482
+ # sample: torch.FloatTensor,
483
+ # generator: Optional[torch.Generator] = None,
484
+ # return_dict: bool = True,
485
+ # expected_next_sample: torch.FloatTensor = None,
486
+ # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
487
+ # """
488
+ # Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
489
+ # process from the learned model outputs (most often the predicted noise).
490
+
491
+ # Args:
492
+ # model_output (`torch.FloatTensor`):
493
+ # The direct output from learned diffusion model.
494
+ # timestep (`float`):
495
+ # The current discrete timestep in the diffusion chain.
496
+ # sample (`torch.FloatTensor`):
497
+ # A current instance of a sample created by the diffusion process.
498
+ # generator (`torch.Generator`, *optional*):
499
+ # A random number generator.
500
+ # return_dict (`bool`):
501
+ # Whether or not to return a
502
+ # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
503
+
504
+ # Returns:
505
+ # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
506
+ # If return_dict is `True`,
507
+ # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
508
+ # otherwise a tuple is returned where the first element is the sample tensor.
509
+
510
+ # """
511
+
512
+ # if (
513
+ # isinstance(timestep, int)
514
+ # or isinstance(timestep, torch.IntTensor)
515
+ # or isinstance(timestep, torch.LongTensor)
516
+ # ):
517
+ # raise ValueError(
518
+ # (
519
+ # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
520
+ # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
521
+ # " one of the `scheduler.timesteps` as a timestep."
522
+ # ),
523
+ # )
524
+
525
+ # if not self.is_scale_input_called:
526
+ # logger.warning(
527
+ # "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
528
+ # "See `StableDiffusionPipeline` for a usage example."
529
+ # )
530
+
531
+ # self._init_step_index(timestep.view((1)))
532
+
533
+ # sigma = self.sigmas[self.step_index]
534
+
535
+ # # Upcast to avoid precision issues when computing prev_sample
536
+ # sample = sample.to(torch.float32)
537
+
538
+ # # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
539
+ # if self.config.prediction_type == "epsilon":
540
+ # pred_original_sample = sample - sigma * model_output
541
+ # elif self.config.prediction_type == "v_prediction":
542
+ # # * c_out + input * c_skip
543
+ # pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
544
+ # elif self.config.prediction_type == "sample":
545
+ # raise NotImplementedError("prediction_type not implemented yet: sample")
546
+ # else:
547
+ # raise ValueError(
548
+ # f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
549
+ # )
550
+
551
+ # sigma_from = self.sigmas[self.step_index]
552
+ # sigma_to = self.sigmas[self.step_index + 1]
553
+ # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
554
+ # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
555
+
556
+ # # 2. Convert to an ODE derivative
557
+ # # derivative = (sample - pred_original_sample) / sigma
558
+ # derivative = model_output
559
+
560
+ # dt = sigma_down - sigma
561
+
562
+ # prev_sample = sample + derivative * dt
563
+
564
+ # device = model_output.device
565
+ # # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
566
+ # # prev_sample = prev_sample + noise * sigma_up
567
+
568
+ # if sigma_up > 0:
569
+ # self.noise_list[self.step_index] = (expected_next_sample - prev_sample) / sigma_up
570
+
571
+ # prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
572
+
573
+ # # Cast sample back to model compatible dtype
574
+ # prev_sample = prev_sample.to(model_output.dtype)
575
+
576
+ # # upon completion increase step index by one
577
+ # self._step_index += 1
578
+
579
+ # if not return_dict:
580
+ # return (prev_sample,)
581
+
582
+ # return EulerAncestralDiscreteSchedulerOutput(
583
+ # prev_sample=prev_sample, pred_original_sample=pred_original_sample
584
+ # )
src/sdxl_inversion_pipeline.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Code is based on ReNoise https://github.com/garibida/ReNoise-Inversion
2
+
3
+ import torch
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ from diffusers import (
7
+ StableDiffusionXLImg2ImgPipeline,
8
+ )
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+
11
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
12
+ StableDiffusionXLPipelineOutput,
13
+ retrieve_timesteps,
14
+ PipelineImageInput
15
+ )
16
+
17
+ from src.eunms import Epsilon_Update_Type
18
+
19
+
20
+ def _backward_ddim(x_tm1, alpha_t, alpha_tm1, eps_xt):
21
+ """
22
+ let a = alpha_t, b = alpha_{t - 1}
23
+ We have a > b,
24
+ x_{t} - x_{t - 1} = sqrt(a) ((sqrt(1/b) - sqrt(1/a)) * x_{t-1} + (sqrt(1/a - 1) - sqrt(1/b - 1)) * eps_{t-1})
25
+ From https://arxiv.org/pdf/2105.05233.pdf, section F.
26
+ """
27
+
28
+ a, b = alpha_t, alpha_tm1
29
+ sa = a ** 0.5
30
+ sb = b ** 0.5
31
+
32
+ return sa * ((1 / sb) * x_tm1 + ((1 / a - 1) ** 0.5 - (1 / b - 1) ** 0.5) * eps_xt)
33
+
34
+
35
+ class SDXLDDIMPipeline(StableDiffusionXLImg2ImgPipeline):
36
+ # @torch.no_grad()
37
+ def __call__(
38
+ self,
39
+ prompt: Union[str, List[str]] = None,
40
+ prompt_2: Optional[Union[str, List[str]]] = None,
41
+ image: PipelineImageInput = None,
42
+ strength: float = 0.3,
43
+ num_inversion_steps: int = 50,
44
+ timesteps: List[int] = None,
45
+ denoising_start: Optional[float] = None,
46
+ denoising_end: Optional[float] = None,
47
+ guidance_scale: float = 1.0,
48
+ negative_prompt: Optional[Union[str, List[str]]] = None,
49
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
50
+ num_images_per_prompt: Optional[int] = 1,
51
+ eta: float = 0.0,
52
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
53
+ latents: Optional[torch.FloatTensor] = None,
54
+ prompt_embeds: Optional[torch.FloatTensor] = None,
55
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
56
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
57
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
58
+ ip_adapter_image: Optional[PipelineImageInput] = None,
59
+ output_type: Optional[str] = "pil",
60
+ return_dict: bool = True,
61
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
62
+ guidance_rescale: float = 0.0,
63
+ original_size: Tuple[int, int] = None,
64
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
65
+ target_size: Tuple[int, int] = None,
66
+ negative_original_size: Optional[Tuple[int, int]] = None,
67
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
68
+ negative_target_size: Optional[Tuple[int, int]] = None,
69
+ aesthetic_score: float = 6.0,
70
+ negative_aesthetic_score: float = 2.5,
71
+ clip_skip: Optional[int] = None,
72
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
73
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
74
+ num_inference_steps: int = 50,
75
+ inv_hp=None,
76
+ **kwargs,
77
+ ):
78
+ callback = kwargs.pop("callback", None)
79
+ callback_steps = kwargs.pop("callback_steps", None)
80
+
81
+ if callback is not None:
82
+ deprecate(
83
+ "callback",
84
+ "1.0.0",
85
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
86
+ )
87
+ if callback_steps is not None:
88
+ deprecate(
89
+ "callback_steps",
90
+ "1.0.0",
91
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
92
+ )
93
+
94
+ # 1. Check inputs. Raise error if not correct
95
+ self.check_inputs(
96
+ prompt,
97
+ prompt_2,
98
+ strength,
99
+ num_inversion_steps,
100
+ callback_steps,
101
+ negative_prompt,
102
+ negative_prompt_2,
103
+ prompt_embeds,
104
+ negative_prompt_embeds,
105
+ callback_on_step_end_tensor_inputs,
106
+ )
107
+
108
+ denoising_start_fr = 1.0 - denoising_start
109
+ denoising_start = denoising_start
110
+
111
+ self._guidance_scale = guidance_scale
112
+ self._guidance_rescale = guidance_rescale
113
+ self._clip_skip = clip_skip
114
+ self._cross_attention_kwargs = cross_attention_kwargs
115
+ self._denoising_end = denoising_end
116
+ self._denoising_start = denoising_start
117
+
118
+ # 2. Define call parameters
119
+ if prompt is not None and isinstance(prompt, str):
120
+ batch_size = 1
121
+ elif prompt is not None and isinstance(prompt, list):
122
+ batch_size = len(prompt)
123
+ else:
124
+ batch_size = prompt_embeds.shape[0]
125
+
126
+ device = self._execution_device
127
+
128
+ # 3. Encode input prompt
129
+ text_encoder_lora_scale = (
130
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
131
+ )
132
+ (
133
+ prompt_embeds,
134
+ negative_prompt_embeds,
135
+ pooled_prompt_embeds,
136
+ negative_pooled_prompt_embeds,
137
+ ) = self.encode_prompt(
138
+ prompt=prompt,
139
+ prompt_2=prompt_2,
140
+ device=device,
141
+ num_images_per_prompt=num_images_per_prompt,
142
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
143
+ negative_prompt=negative_prompt,
144
+ negative_prompt_2=negative_prompt_2,
145
+ prompt_embeds=prompt_embeds,
146
+ negative_prompt_embeds=negative_prompt_embeds,
147
+ pooled_prompt_embeds=pooled_prompt_embeds,
148
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
149
+ lora_scale=text_encoder_lora_scale,
150
+ clip_skip=self.clip_skip,
151
+ )
152
+
153
+ # 4. Preprocess image
154
+ image = self.image_processor.preprocess(image)
155
+
156
+ # 5. Prepare timesteps
157
+ def denoising_value_valid(dnv):
158
+ return isinstance(self.denoising_end, float) and 0 < dnv < 1
159
+
160
+ timesteps, num_inversion_steps = retrieve_timesteps(self.scheduler, num_inversion_steps, device, timesteps)
161
+ timesteps_num_inference_steps, num_inference_steps = retrieve_timesteps(self.scheduler_inference,
162
+ num_inference_steps, device, None)
163
+
164
+ timesteps, num_inversion_steps = self.get_timesteps(
165
+ num_inversion_steps,
166
+ strength,
167
+ device,
168
+ denoising_start=self.denoising_start if denoising_value_valid else None,
169
+ )
170
+ # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
171
+
172
+ # add_noise = True if self.denoising_start is None else False
173
+ # 6. Prepare latent variables
174
+ with torch.no_grad():
175
+ latents = self.prepare_latents(
176
+ image,
177
+ None,
178
+ batch_size,
179
+ num_images_per_prompt,
180
+ prompt_embeds.dtype,
181
+ device,
182
+ generator,
183
+ False,
184
+ )
185
+ # 7. Prepare extra step kwargs.
186
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
187
+
188
+ height, width = latents.shape[-2:]
189
+ height = height * self.vae_scale_factor
190
+ width = width * self.vae_scale_factor
191
+
192
+ original_size = original_size or (height, width)
193
+ target_size = target_size or (height, width)
194
+
195
+ # 8. Prepare added time ids & embeddings
196
+ if negative_original_size is None:
197
+ negative_original_size = original_size
198
+ if negative_target_size is None:
199
+ negative_target_size = target_size
200
+
201
+ add_text_embeds = pooled_prompt_embeds
202
+ if self.text_encoder_2 is None:
203
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
204
+ else:
205
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
206
+
207
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
208
+ original_size,
209
+ crops_coords_top_left,
210
+ target_size,
211
+ aesthetic_score,
212
+ negative_aesthetic_score,
213
+ negative_original_size,
214
+ negative_crops_coords_top_left,
215
+ negative_target_size,
216
+ dtype=prompt_embeds.dtype,
217
+ text_encoder_projection_dim=text_encoder_projection_dim,
218
+ )
219
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
220
+
221
+ if self.do_classifier_free_guidance:
222
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
223
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
224
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
225
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
226
+
227
+ prompt_embeds = prompt_embeds.to(device)
228
+ add_text_embeds = add_text_embeds.to(device)
229
+ add_time_ids = add_time_ids.to(device)
230
+
231
+ if ip_adapter_image is not None:
232
+ image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
233
+ if self.do_classifier_free_guidance:
234
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
235
+ image_embeds = image_embeds.to(device)
236
+
237
+ # 9. Denoising loop
238
+ num_warmup_steps = max(len(timesteps) - num_inversion_steps * self.scheduler.order, 0)
239
+ prev_timestep = None
240
+
241
+ self._num_timesteps = len(timesteps)
242
+ self.prev_z = torch.clone(latents)
243
+ self.prev_z4 = torch.clone(latents)
244
+ self.z_0 = torch.clone(latents)
245
+ g_cpu = torch.Generator().manual_seed(7865)
246
+ self.noise = randn_tensor(self.z_0.shape, generator=g_cpu, device=self.z_0.device, dtype=self.z_0.dtype)
247
+
248
+ # Friendly inversion params
249
+ timesteps_for = reversed(timesteps)
250
+ noise = randn_tensor(latents.shape, generator=g_cpu, device=latents.device, dtype=latents.dtype)
251
+ #latents = latents
252
+ z_T = latents.clone()
253
+
254
+ all_latents = [latents.clone()]
255
+ with self.progress_bar(total=num_inversion_steps) as progress_bar:
256
+ for i, t in enumerate(timesteps_for):
257
+
258
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
259
+ if ip_adapter_image is not None:
260
+ added_cond_kwargs["image_embeds"] = image_embeds
261
+
262
+ z_tp1 = self.inversion_step(latents,
263
+ t,
264
+ prompt_embeds,
265
+ added_cond_kwargs,
266
+ prev_timestep=prev_timestep,
267
+ inv_hp=inv_hp,
268
+ z_0=self.z_0)
269
+
270
+ prev_timestep = t
271
+ latents = z_tp1
272
+
273
+ all_latents.append(latents.clone())
274
+
275
+ if callback_on_step_end is not None:
276
+ callback_kwargs = {}
277
+ for k in callback_on_step_end_tensor_inputs:
278
+ callback_kwargs[k] = locals()[k]
279
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
280
+
281
+ latents = callback_outputs.pop("latents", latents)
282
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
283
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
284
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
285
+ negative_pooled_prompt_embeds = callback_outputs.pop(
286
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
287
+ )
288
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
289
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
290
+
291
+ # call the callback, if provided
292
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
293
+ progress_bar.update()
294
+ if callback is not None and i % callback_steps == 0:
295
+ step_idx = i // getattr(self.scheduler, "order", 1)
296
+ callback(step_idx, t, latents)
297
+
298
+ image = latents
299
+
300
+ # Offload all models
301
+ self.maybe_free_model_hooks()
302
+
303
+ return StableDiffusionXLPipelineOutput(images=image), all_latents
304
+
305
def get_timestamp_dist(self, z_0, timesteps):
    """Build a Gaussian log-density centred on ``z_0`` for one timestep.

    The scheduler sigma belonging to ``timesteps`` is looked up by comparing
    ``timesteps`` against ``self.scheduler.timesteps``. The returned closure
    evaluates, element-wise, ``-0.5 * ((x - z_0) / sigma) ** 2`` -- i.e. the
    exponent of a Gaussian (an unnormalised log-density), not a normalised
    probability density, despite the closure's name.

    Args:
        z_0: Reference latent; it is flattened to a column for the
            element-wise comparison with ``x``.
        timesteps: The timestep whose sigma should be used.
            # NOTE(review): assumed to match exactly one entry of
            # ``self.scheduler.timesteps``; otherwise the final reshape
            # inside the closure fails -- confirm with callers.

    Returns:
        A callable ``f(x)`` returning a tensor with the same shape as ``x``.
    """
    # Keep every operand on z_0's device instead of a hard-coded ``.cuda()``
    # so the method also works on CPU or on a non-default accelerator.
    device = z_0.device
    timesteps = timesteps.to(device)
    step_mask = self.scheduler.timesteps.to(device) == timesteps
    # The last sigma is dropped so the sigma table lines up with the
    # timestep mask.
    # NOTE(review): presumably the trailing entry is the terminal sigma
    # appended by the scheduler -- confirm.
    sigma = self.scheduler.sigmas.to(device)[:-1][step_mask]
    z_0 = z_0.reshape(-1, 1)

    def gaussian_pdf(x):
        shape = x.shape
        x = x.reshape(-1, 1)
        all_probs = -0.5 * torch.pow((x - z_0) / sigma, 2)
        return all_probs.reshape(shape)

    return gaussian_pdf
317
+
318
+ # @torch.no_grad()
319
def inversion_step(
    self,
    z_t: torch.Tensor,
    t: torch.Tensor,
    prompt_embeds,
    added_cond_kwargs,
    prev_timestep: Optional[torch.Tensor] = None,
    inv_hp=None,
    z_0=None,
) -> torch.Tensor:
    """Run one inversion step, refining the latent estimate iteratively.

    For ``n_iters`` iterations the method predicts noise at the current
    estimate, applies the scheduler's inverse step to ``z_t`` and scores
    the candidate with
    ``|next_latent - latent| - alpha * log_density(next_latent)``,
    where the log-density comes from ``self.get_timestamp_dist(z_0, t)``.
    The estimate is then moved by ``lr * f_x / grad`` (a Newton-like update)
    and the candidate with the lowest mean score is returned.

    This method must not run under ``torch.no_grad()``: it needs autograd
    to differentiate the score with respect to the current estimate.

    Args:
        z_t: Latent at the current timestep. It is not modified.
        t: Current timestep.
        prompt_embeds: Text conditioning forwarded to ``self.unet_pass``.
        added_cond_kwargs: Extra conditioning forwarded to ``self.unet_pass``.
        prev_timestep: Forwarded to ``self.backward_step``.
        inv_hp: Triple ``(n_iters, alpha, lr)``; required despite the
            ``None`` default (unpacking ``None`` raises ``TypeError``).
        z_0: Reference latent used to build the prior term.

    Returns:
        The detached candidate latent with the lowest mean score, or
        ``None`` when ``n_iters`` is 0 or no score compares below infinity.
    """
    n_iters, alpha, lr = inv_hp
    # Optimise a detached alias so the caller's ``z_t`` never has its
    # ``requires_grad`` flag flipped or a stale ``.grad`` left behind.
    latent = z_t.detach()
    best_latent = None
    best_score = torch.inf
    curr_dist = self.get_timestamp_dist(z_0, t)
    for _ in range(n_iters):
        latent.requires_grad_(True)
        noise_pred = self.unet_pass(latent, t, prompt_embeds, added_cond_kwargs)

        next_latent = self.backward_step(noise_pred, t, z_t, prev_timestep)
        f_x = (next_latent - latent).abs() - alpha * curr_dist(next_latent)
        score = f_x.detach().mean()

        if score < best_score:
            best_score = score
            best_latent = next_latent.detach()

        # Take the gradient explicitly instead of backward() + ``.grad`` so
        # no tensor's ``.grad`` is populated as a side effect.
        # NOTE(review): the gradient is zero wherever next_latent == latent,
        # which makes the quotient below inf/nan -- confirm this is acceptable.
        (grad,) = torch.autograd.grad(f_x.sum(), latent)
        # Detach so the next iteration starts from a fresh leaf tensor.
        latent = (latent - lr * (f_x / grad)).detach()

    return best_latent
357
+
358
@torch.no_grad()
def unet_pass(self, z_t, t, prompt_embeds, added_cond_kwargs):
    """Predict noise for ``z_t`` at timestep ``t`` with gradients disabled.

    When classifier-free guidance is enabled the latent is duplicated along
    the batch dimension before being scaled by the scheduler and passed to
    the UNet. The first element of the UNet's tuple output is returned
    as-is; no guidance combination is performed here.
    """
    if self.do_classifier_free_guidance:
        # Two copies of the latent, concatenated along dim 0.
        model_input = torch.cat([z_t, z_t])
    else:
        model_input = z_t

    scaled_input = self.scheduler.scale_model_input(model_input, t)
    unet_outputs = self.unet(
        scaled_input,
        t,
        encoder_hidden_states=prompt_embeds,
        timestep_cond=None,
        cross_attention_kwargs=self.cross_attention_kwargs,
        added_cond_kwargs=added_cond_kwargs,
        return_dict=False,
    )
    return unet_outputs[0]
371
+
372
@torch.no_grad()
def backward_step(self, nosie_pred, t, z_t, prev_timestep):
    """Apply the scheduler's inverse step and return the detached result.

    Calls ``self.scheduler.inv_step`` with the predicted noise, the timestep
    and the latent, and returns the first element of its tuple output.
    ``prev_timestep`` is accepted for interface compatibility but is not
    used by this method.
    """
    step_outputs = self.scheduler.inv_step(nosie_pred, t, z_t, return_dict=False)
    stepped_latent = step_outputs[0]
    return stepped_latent.detach()