barakmeiri commited on
Commit
424f073
·
verified ·
1 Parent(s): 3149a3c

Update src/euler_scheduler.py

Browse files

added function set_noise_list_device(device)

Files changed (1) hide show
  1. src/euler_scheduler.py +589 -583
src/euler_scheduler.py CHANGED
@@ -1,584 +1,590 @@
1
- # Code is based on ReNoise https://github.com/garibida/ReNoise-Inversion
2
-
3
- from diffusers import EulerAncestralDiscreteScheduler
4
- from diffusers.utils import BaseOutput
5
- import torch
6
- from typing import List, Optional, Tuple, Union
7
- import numpy as np
8
-
9
- from src.eunms import Epsilon_Update_Type
10
-
11
- class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
12
- """
13
- Output class for the scheduler's `step` function output.
14
-
15
- Args:
16
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
17
- Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
18
- denoising loop.
19
- pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
20
- The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
21
- `pred_original_sample` can be used to preview progress or for guidance.
22
- """
23
-
24
- prev_sample: torch.FloatTensor
25
- pred_original_sample: Optional[torch.FloatTensor] = None
26
-
27
- class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
28
- def set_noise_list(self, noise_list):
29
- self.noise_list = noise_list
30
-
31
- def get_noise_to_remove(self):
32
- sigma_from = self.sigmas[self.step_index]
33
- sigma_to = self.sigmas[self.step_index + 1]
34
- sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
35
-
36
- return self.noise_list[self.step_index] * sigma_up\
37
-
38
- def scale_model_input(
39
- self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
40
- ) -> torch.FloatTensor:
41
- """
42
- Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
43
- current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
44
-
45
- Args:
46
- sample (`torch.FloatTensor`):
47
- The input sample.
48
- timestep (`int`, *optional*):
49
- The current timestep in the diffusion chain.
50
-
51
- Returns:
52
- `torch.FloatTensor`:
53
- A scaled input sample.
54
- """
55
-
56
- self._init_step_index(timestep.view((1)))
57
- return EulerAncestralDiscreteScheduler.scale_model_input(self, sample, timestep)
58
-
59
-
60
- def step(
61
- self,
62
- model_output: torch.FloatTensor,
63
- timestep: Union[float, torch.FloatTensor],
64
- sample: torch.FloatTensor,
65
- generator: Optional[torch.Generator] = None,
66
- return_dict: bool = True,
67
- ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
68
- """
69
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
70
- process from the learned model outputs (most often the predicted noise).
71
-
72
- Args:
73
- model_output (`torch.FloatTensor`):
74
- The direct output from learned diffusion model.
75
- timestep (`float`):
76
- The current discrete timestep in the diffusion chain.
77
- sample (`torch.FloatTensor`):
78
- A current instance of a sample created by the diffusion process.
79
- generator (`torch.Generator`, *optional*):
80
- A random number generator.
81
- return_dict (`bool`):
82
- Whether or not to return a
83
- [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
84
-
85
- Returns:
86
- [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
87
- If return_dict is `True`,
88
- [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
89
- otherwise a tuple is returned where the first element is the sample tensor.
90
-
91
- """
92
-
93
- if (
94
- isinstance(timestep, int)
95
- or isinstance(timestep, torch.IntTensor)
96
- or isinstance(timestep, torch.LongTensor)
97
- ):
98
- raise ValueError(
99
- (
100
- "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
101
- " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
102
- " one of the `scheduler.timesteps` as a timestep."
103
- ),
104
- )
105
-
106
- if not self.is_scale_input_called:
107
- logger.warning(
108
- "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
109
- "See `StableDiffusionPipeline` for a usage example."
110
- )
111
-
112
- self._init_step_index(timestep.view((1)))
113
-
114
- sigma = self.sigmas[self.step_index]
115
-
116
- # Upcast to avoid precision issues when computing prev_sample
117
- sample = sample.to(torch.float32)
118
-
119
- # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
120
- if self.config.prediction_type == "epsilon":
121
- pred_original_sample = sample - sigma * model_output
122
- elif self.config.prediction_type == "v_prediction":
123
- # * c_out + input * c_skip
124
- pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
125
- elif self.config.prediction_type == "sample":
126
- raise NotImplementedError("prediction_type not implemented yet: sample")
127
- else:
128
- raise ValueError(
129
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
130
- )
131
-
132
- sigma_from = self.sigmas[self.step_index]
133
- sigma_to = self.sigmas[self.step_index + 1]
134
- sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
135
- sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
136
-
137
- # 2. Convert to an ODE derivative
138
- # derivative = (sample - pred_original_sample) / sigma
139
- derivative = model_output
140
-
141
- dt = sigma_down - sigma
142
-
143
- prev_sample = sample + derivative * dt
144
-
145
- device = model_output.device
146
- # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
147
- # prev_sample = prev_sample + noise * sigma_up
148
-
149
- prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
150
-
151
- # Cast sample back to model compatible dtype
152
- prev_sample = prev_sample.to(model_output.dtype)
153
-
154
- # upon completion increase step index by one
155
- self._step_index += 1
156
-
157
- if not return_dict:
158
- return (prev_sample,)
159
-
160
- return EulerAncestralDiscreteSchedulerOutput(
161
- prev_sample=prev_sample, pred_original_sample=pred_original_sample
162
- )
163
-
164
- def step_and_update_noise(
165
- self,
166
- model_output: torch.FloatTensor,
167
- timestep: Union[float, torch.FloatTensor],
168
- sample: torch.FloatTensor,
169
- expected_prev_sample: torch.FloatTensor,
170
- update_epsilon_type=Epsilon_Update_Type.OVERRIDE,
171
- generator: Optional[torch.Generator] = None,
172
- return_dict: bool = True,
173
- ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
174
- """
175
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
176
- process from the learned model outputs (most often the predicted noise).
177
-
178
- Args:
179
- model_output (`torch.FloatTensor`):
180
- The direct output from learned diffusion model.
181
- timestep (`float`):
182
- The current discrete timestep in the diffusion chain.
183
- sample (`torch.FloatTensor`):
184
- A current instance of a sample created by the diffusion process.
185
- generator (`torch.Generator`, *optional*):
186
- A random number generator.
187
- return_dict (`bool`):
188
- Whether or not to return a
189
- [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
190
-
191
- Returns:
192
- [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
193
- If return_dict is `True`,
194
- [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
195
- otherwise a tuple is returned where the first element is the sample tensor.
196
-
197
- """
198
-
199
- if (
200
- isinstance(timestep, int)
201
- or isinstance(timestep, torch.IntTensor)
202
- or isinstance(timestep, torch.LongTensor)
203
- ):
204
- raise ValueError(
205
- (
206
- "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
207
- " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
208
- " one of the `scheduler.timesteps` as a timestep."
209
- ),
210
- )
211
-
212
- if not self.is_scale_input_called:
213
- logger.warning(
214
- "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
215
- "See `StableDiffusionPipeline` for a usage example."
216
- )
217
-
218
- self._init_step_index(timestep.view((1)))
219
-
220
- sigma = self.sigmas[self.step_index]
221
-
222
- # Upcast to avoid precision issues when computing prev_sample
223
- sample = sample.to(torch.float32)
224
-
225
- # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
226
- if self.config.prediction_type == "epsilon":
227
- pred_original_sample = sample - sigma * model_output
228
- elif self.config.prediction_type == "v_prediction":
229
- # * c_out + input * c_skip
230
- pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
231
- elif self.config.prediction_type == "sample":
232
- raise NotImplementedError("prediction_type not implemented yet: sample")
233
- else:
234
- raise ValueError(
235
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
236
- )
237
-
238
- sigma_from = self.sigmas[self.step_index]
239
- sigma_to = self.sigmas[self.step_index + 1]
240
- sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
241
- sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
242
-
243
- # 2. Convert to an ODE derivative
244
- # derivative = (sample - pred_original_sample) / sigma
245
- derivative = model_output
246
-
247
- dt = sigma_down - sigma
248
-
249
- prev_sample = sample + derivative * dt
250
-
251
- device = model_output.device
252
- # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
253
- # prev_sample = prev_sample + noise * sigma_up
254
-
255
- if sigma_up > 0:
256
- req_noise = (expected_prev_sample - prev_sample) / sigma_up
257
- if update_epsilon_type == Epsilon_Update_Type.OVERRIDE:
258
- self.noise_list[self.step_index] = req_noise
259
- else:
260
- for i in range(10):
261
- n = torch.autograd.Variable(self.noise_list[self.step_index].detach().clone(), requires_grad=True)
262
- loss = torch.norm(n - req_noise.detach())
263
- loss.backward()
264
- self.noise_list[self.step_index] -= n.grad.detach() * 1.8
265
-
266
-
267
- prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
268
-
269
- # Cast sample back to model compatible dtype
270
- prev_sample = prev_sample.to(model_output.dtype)
271
-
272
- # upon completion increase step index by one
273
- self._step_index += 1
274
-
275
- if not return_dict:
276
- return (prev_sample,)
277
-
278
- return EulerAncestralDiscreteSchedulerOutput(
279
- prev_sample=prev_sample, pred_original_sample=pred_original_sample
280
- )
281
-
282
- def inv_step(
283
- self,
284
- model_output: torch.FloatTensor,
285
- timestep: Union[float, torch.FloatTensor],
286
- sample: torch.FloatTensor,
287
- generator: Optional[torch.Generator] = None,
288
- return_dict: bool = True,
289
- ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
290
- """
291
- Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
292
- process from the learned model outputs (most often the predicted noise).
293
-
294
- Args:
295
- model_output (`torch.FloatTensor`):
296
- The direct output from learned diffusion model.
297
- timestep (`float`):
298
- The current discrete timestep in the diffusion chain.
299
- sample (`torch.FloatTensor`):
300
- A current instance of a sample created by the diffusion process.
301
- generator (`torch.Generator`, *optional*):
302
- A random number generator.
303
- return_dict (`bool`):
304
- Whether or not to return a
305
- [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
306
-
307
- Returns:
308
- [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
309
- If return_dict is `True`,
310
- [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
311
- otherwise a tuple is returned where the first element is the sample tensor.
312
-
313
- """
314
-
315
- if (
316
- isinstance(timestep, int)
317
- or isinstance(timestep, torch.IntTensor)
318
- or isinstance(timestep, torch.LongTensor)
319
- ):
320
- raise ValueError(
321
- (
322
- "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
323
- " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
324
- " one of the `scheduler.timesteps` as a timestep."
325
- ),
326
- )
327
-
328
- if not self.is_scale_input_called:
329
- logger.warning(
330
- "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
331
- "See `StableDiffusionPipeline` for a usage example."
332
- )
333
-
334
- self._init_step_index(timestep.view((1)))
335
-
336
- sigma = self.sigmas[self.step_index]
337
-
338
- # Upcast to avoid precision issues when computing prev_sample
339
- sample = sample.to(torch.float32)
340
-
341
- # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
342
- if self.config.prediction_type == "epsilon":
343
- pred_original_sample = sample - sigma * model_output
344
- elif self.config.prediction_type == "v_prediction":
345
- # * c_out + input * c_skip
346
- pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
347
- elif self.config.prediction_type == "sample":
348
- raise NotImplementedError("prediction_type not implemented yet: sample")
349
- else:
350
- raise ValueError(
351
- f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
352
- )
353
-
354
- sigma_from = self.sigmas[self.step_index]
355
- sigma_to = self.sigmas[self.step_index+1]
356
- # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
357
- sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
358
- # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
359
- sigma_down = sigma_to**2 / sigma_from
360
-
361
- # 2. Convert to an ODE derivative
362
- # derivative = (sample - pred_original_sample) / sigma
363
- derivative = model_output
364
-
365
- dt = sigma_down - sigma
366
- # dt = sigma_down - sigma_from
367
-
368
- prev_sample = sample - derivative * dt
369
-
370
- device = model_output.device
371
- # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
372
- # prev_sample = prev_sample + noise * sigma_up
373
-
374
- prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up
375
-
376
- # Cast sample back to model compatible dtype
377
- prev_sample = prev_sample.to(model_output.dtype)
378
-
379
- # upon completion increase step index by one
380
- self._step_index += 1
381
-
382
- if not return_dict:
383
- return (prev_sample,)
384
-
385
- return EulerAncestralDiscreteSchedulerOutput(
386
- prev_sample=prev_sample, pred_original_sample=pred_original_sample
387
- )
388
-
389
- def get_all_sigmas(self) -> torch.FloatTensor:
390
- sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
391
- sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
392
- return torch.from_numpy(sigmas)
393
-
394
- def add_noise_off_schedule(
395
- self,
396
- original_samples: torch.FloatTensor,
397
- noise: torch.FloatTensor,
398
- timesteps: torch.FloatTensor,
399
- ) -> torch.FloatTensor:
400
- # Make sure sigmas and timesteps have the same device and dtype as original_samples
401
- sigmas = self.get_all_sigmas()
402
- sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
403
- if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
404
- # mps does not support float64
405
- timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
406
- else:
407
- timesteps = timesteps.to(original_samples.device)
408
-
409
- step_indices = 1000 - int(timesteps.item())
410
-
411
- sigma = sigmas[step_indices].flatten()
412
- while len(sigma.shape) < len(original_samples.shape):
413
- sigma = sigma.unsqueeze(-1)
414
-
415
- noisy_samples = original_samples + noise * sigma
416
- return noisy_samples
417
-
418
- # def update_noise_for_friendly_inversion(
419
- # self,
420
- # model_output: torch.FloatTensor,
421
- # timestep: Union[float, torch.FloatTensor],
422
- # z_t: torch.FloatTensor,
423
- # z_tp1: torch.FloatTensor,
424
- # return_dict: bool = True,
425
- # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
426
- # if (
427
- # isinstance(timestep, int)
428
- # or isinstance(timestep, torch.IntTensor)
429
- # or isinstance(timestep, torch.LongTensor)
430
- # ):
431
- # raise ValueError(
432
- # (
433
- # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
434
- # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
435
- # " one of the `scheduler.timesteps` as a timestep."
436
- # ),
437
- # )
438
-
439
- # if not self.is_scale_input_called:
440
- # logger.warning(
441
- # "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
442
- # "See `StableDiffusionPipeline` for a usage example."
443
- # )
444
-
445
- # self._init_step_index(timestep.view((1)))
446
-
447
- # sigma = self.sigmas[self.step_index]
448
-
449
- # sigma_from = self.sigmas[self.step_index]
450
- # sigma_to = self.sigmas[self.step_index+1]
451
- # # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
452
- # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
453
- # # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
454
- # sigma_down = sigma_to**2 / sigma_from
455
-
456
- # # 2. Conv = (sample - pred_original_sample) / sigma
457
- # derivative = model_output
458
-
459
- # dt = sigma_down - sigma
460
- # # dt = sigma_down - sigma_from
461
-
462
- # prev_sample = z_t - derivative * dt
463
-
464
- # if sigma_up > 0:
465
- # self.noise_list[self.step_index] = (prev_sample - z_tp1) / sigma_up
466
-
467
- # prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up
468
-
469
-
470
- # if not return_dict:
471
- # return (prev_sample,)
472
-
473
- # return EulerAncestralDiscreteSchedulerOutput(
474
- # prev_sample=prev_sample, pred_original_sample=None
475
- # )
476
-
477
-
478
- # def step_friendly_inversion(
479
- # self,
480
- # model_output: torch.FloatTensor,
481
- # timestep: Union[float, torch.FloatTensor],
482
- # sample: torch.FloatTensor,
483
- # generator: Optional[torch.Generator] = None,
484
- # return_dict: bool = True,
485
- # expected_next_sample: torch.FloatTensor = None,
486
- # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
487
- # """
488
- # Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
489
- # process from the learned model outputs (most often the predicted noise).
490
-
491
- # Args:
492
- # model_output (`torch.FloatTensor`):
493
- # The direct output from learned diffusion model.
494
- # timestep (`float`):
495
- # The current discrete timestep in the diffusion chain.
496
- # sample (`torch.FloatTensor`):
497
- # A current instance of a sample created by the diffusion process.
498
- # generator (`torch.Generator`, *optional*):
499
- # A random number generator.
500
- # return_dict (`bool`):
501
- # Whether or not to return a
502
- # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
503
-
504
- # Returns:
505
- # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
506
- # If return_dict is `True`,
507
- # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
508
- # otherwise a tuple is returned where the first element is the sample tensor.
509
-
510
- # """
511
-
512
- # if (
513
- # isinstance(timestep, int)
514
- # or isinstance(timestep, torch.IntTensor)
515
- # or isinstance(timestep, torch.LongTensor)
516
- # ):
517
- # raise ValueError(
518
- # (
519
- # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
520
- # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
521
- # " one of the `scheduler.timesteps` as a timestep."
522
- # ),
523
- # )
524
-
525
- # if not self.is_scale_input_called:
526
- # logger.warning(
527
- # "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
528
- # "See `StableDiffusionPipeline` for a usage example."
529
- # )
530
-
531
- # self._init_step_index(timestep.view((1)))
532
-
533
- # sigma = self.sigmas[self.step_index]
534
-
535
- # # Upcast to avoid precision issues when computing prev_sample
536
- # sample = sample.to(torch.float32)
537
-
538
- # # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
539
- # if self.config.prediction_type == "epsilon":
540
- # pred_original_sample = sample - sigma * model_output
541
- # elif self.config.prediction_type == "v_prediction":
542
- # # * c_out + input * c_skip
543
- # pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
544
- # elif self.config.prediction_type == "sample":
545
- # raise NotImplementedError("prediction_type not implemented yet: sample")
546
- # else:
547
- # raise ValueError(
548
- # f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
549
- # )
550
-
551
- # sigma_from = self.sigmas[self.step_index]
552
- # sigma_to = self.sigmas[self.step_index + 1]
553
- # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
554
- # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
555
-
556
- # # 2. Convert to an ODE derivative
557
- # # derivative = (sample - pred_original_sample) / sigma
558
- # derivative = model_output
559
-
560
- # dt = sigma_down - sigma
561
-
562
- # prev_sample = sample + derivative * dt
563
-
564
- # device = model_output.device
565
- # # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
566
- # # prev_sample = prev_sample + noise * sigma_up
567
-
568
- # if sigma_up > 0:
569
- # self.noise_list[self.step_index] = (expected_next_sample - prev_sample) / sigma_up
570
-
571
- # prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
572
-
573
- # # Cast sample back to model compatible dtype
574
- # prev_sample = prev_sample.to(model_output.dtype)
575
-
576
- # # upon completion increase step index by one
577
- # self._step_index += 1
578
-
579
- # if not return_dict:
580
- # return (prev_sample,)
581
-
582
- # return EulerAncestralDiscreteSchedulerOutput(
583
- # prev_sample=prev_sample, pred_original_sample=pred_original_sample
 
 
 
 
 
 
584
  # )
 
1
+ # Code is based on ReNoise https://github.com/garibida/ReNoise-Inversion
2
+
3
+ from diffusers import EulerAncestralDiscreteScheduler
4
+ from diffusers.utils import BaseOutput
5
+ import torch
6
+ from typing import List, Optional, Tuple, Union
7
+ import numpy as np
8
+
9
+ from src.eunms import Epsilon_Update_Type
10
+
11
+ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
12
+ """
13
+ Output class for the scheduler's `step` function output.
14
+
15
+ Args:
16
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
17
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
18
+ denoising loop.
19
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
20
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
21
+ `pred_original_sample` can be used to preview progress or for guidance.
22
+ """
23
+
24
+ prev_sample: torch.FloatTensor
25
+ pred_original_sample: Optional[torch.FloatTensor] = None
26
+
27
+ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
28
+ def set_noise_list(self, noise_list):
29
+ self.noise_list = noise_list
30
+
31
+ def set_noise_list_device(self, device):
32
+ if self.noise_list[0].device == device:
33
+ return
34
+ for i in range(len(self.noise_list)):
35
+ self.noise_list[i] = self.noise_list[i].to(device)
36
+
37
+ def get_noise_to_remove(self):
38
+ sigma_from = self.sigmas[self.step_index]
39
+ sigma_to = self.sigmas[self.step_index + 1]
40
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
41
+
42
+ return self.noise_list[self.step_index] * sigma_up\
43
+
44
+ def scale_model_input(
45
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
46
+ ) -> torch.FloatTensor:
47
+ """
48
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
49
+ current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
50
+
51
+ Args:
52
+ sample (`torch.FloatTensor`):
53
+ The input sample.
54
+ timestep (`int`, *optional*):
55
+ The current timestep in the diffusion chain.
56
+
57
+ Returns:
58
+ `torch.FloatTensor`:
59
+ A scaled input sample.
60
+ """
61
+
62
+ self._init_step_index(timestep.view((1)))
63
+ return EulerAncestralDiscreteScheduler.scale_model_input(self, sample, timestep)
64
+
65
+
66
+ def step(
67
+ self,
68
+ model_output: torch.FloatTensor,
69
+ timestep: Union[float, torch.FloatTensor],
70
+ sample: torch.FloatTensor,
71
+ generator: Optional[torch.Generator] = None,
72
+ return_dict: bool = True,
73
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
74
+ """
75
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
76
+ process from the learned model outputs (most often the predicted noise).
77
+
78
+ Args:
79
+ model_output (`torch.FloatTensor`):
80
+ The direct output from learned diffusion model.
81
+ timestep (`float`):
82
+ The current discrete timestep in the diffusion chain.
83
+ sample (`torch.FloatTensor`):
84
+ A current instance of a sample created by the diffusion process.
85
+ generator (`torch.Generator`, *optional*):
86
+ A random number generator.
87
+ return_dict (`bool`):
88
+ Whether or not to return a
89
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
90
+
91
+ Returns:
92
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
93
+ If return_dict is `True`,
94
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
95
+ otherwise a tuple is returned where the first element is the sample tensor.
96
+
97
+ """
98
+
99
+ if (
100
+ isinstance(timestep, int)
101
+ or isinstance(timestep, torch.IntTensor)
102
+ or isinstance(timestep, torch.LongTensor)
103
+ ):
104
+ raise ValueError(
105
+ (
106
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
107
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
108
+ " one of the `scheduler.timesteps` as a timestep."
109
+ ),
110
+ )
111
+
112
+ if not self.is_scale_input_called:
113
+ logger.warning(
114
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
115
+ "See `StableDiffusionPipeline` for a usage example."
116
+ )
117
+
118
+ self._init_step_index(timestep.view((1)))
119
+
120
+ sigma = self.sigmas[self.step_index]
121
+
122
+ # Upcast to avoid precision issues when computing prev_sample
123
+ sample = sample.to(torch.float32)
124
+
125
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
126
+ if self.config.prediction_type == "epsilon":
127
+ pred_original_sample = sample - sigma * model_output
128
+ elif self.config.prediction_type == "v_prediction":
129
+ # * c_out + input * c_skip
130
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
131
+ elif self.config.prediction_type == "sample":
132
+ raise NotImplementedError("prediction_type not implemented yet: sample")
133
+ else:
134
+ raise ValueError(
135
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
136
+ )
137
+
138
+ sigma_from = self.sigmas[self.step_index]
139
+ sigma_to = self.sigmas[self.step_index + 1]
140
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
141
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
142
+
143
+ # 2. Convert to an ODE derivative
144
+ # derivative = (sample - pred_original_sample) / sigma
145
+ derivative = model_output
146
+
147
+ dt = sigma_down - sigma
148
+
149
+ prev_sample = sample + derivative * dt
150
+
151
+ device = model_output.device
152
+ # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
153
+ # prev_sample = prev_sample + noise * sigma_up
154
+
155
+ prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
156
+
157
+ # Cast sample back to model compatible dtype
158
+ prev_sample = prev_sample.to(model_output.dtype)
159
+
160
+ # upon completion increase step index by one
161
+ self._step_index += 1
162
+
163
+ if not return_dict:
164
+ return (prev_sample,)
165
+
166
+ return EulerAncestralDiscreteSchedulerOutput(
167
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
168
+ )
169
+
170
+ def step_and_update_noise(
171
+ self,
172
+ model_output: torch.FloatTensor,
173
+ timestep: Union[float, torch.FloatTensor],
174
+ sample: torch.FloatTensor,
175
+ expected_prev_sample: torch.FloatTensor,
176
+ update_epsilon_type=Epsilon_Update_Type.OVERRIDE,
177
+ generator: Optional[torch.Generator] = None,
178
+ return_dict: bool = True,
179
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
180
+ """
181
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
182
+ process from the learned model outputs (most often the predicted noise).
183
+
184
+ Args:
185
+ model_output (`torch.FloatTensor`):
186
+ The direct output from learned diffusion model.
187
+ timestep (`float`):
188
+ The current discrete timestep in the diffusion chain.
189
+ sample (`torch.FloatTensor`):
190
+ A current instance of a sample created by the diffusion process.
191
+ generator (`torch.Generator`, *optional*):
192
+ A random number generator.
193
+ return_dict (`bool`):
194
+ Whether or not to return a
195
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
196
+
197
+ Returns:
198
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
199
+ If return_dict is `True`,
200
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
201
+ otherwise a tuple is returned where the first element is the sample tensor.
202
+
203
+ """
204
+
205
+ if (
206
+ isinstance(timestep, int)
207
+ or isinstance(timestep, torch.IntTensor)
208
+ or isinstance(timestep, torch.LongTensor)
209
+ ):
210
+ raise ValueError(
211
+ (
212
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
213
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
214
+ " one of the `scheduler.timesteps` as a timestep."
215
+ ),
216
+ )
217
+
218
+ if not self.is_scale_input_called:
219
+ logger.warning(
220
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
221
+ "See `StableDiffusionPipeline` for a usage example."
222
+ )
223
+
224
+ self._init_step_index(timestep.view((1)))
225
+
226
+ sigma = self.sigmas[self.step_index]
227
+
228
+ # Upcast to avoid precision issues when computing prev_sample
229
+ sample = sample.to(torch.float32)
230
+
231
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
232
+ if self.config.prediction_type == "epsilon":
233
+ pred_original_sample = sample - sigma * model_output
234
+ elif self.config.prediction_type == "v_prediction":
235
+ # * c_out + input * c_skip
236
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
237
+ elif self.config.prediction_type == "sample":
238
+ raise NotImplementedError("prediction_type not implemented yet: sample")
239
+ else:
240
+ raise ValueError(
241
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
242
+ )
243
+
244
+ sigma_from = self.sigmas[self.step_index]
245
+ sigma_to = self.sigmas[self.step_index + 1]
246
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
247
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
248
+
249
+ # 2. Convert to an ODE derivative
250
+ # derivative = (sample - pred_original_sample) / sigma
251
+ derivative = model_output
252
+
253
+ dt = sigma_down - sigma
254
+
255
+ prev_sample = sample + derivative * dt
256
+
257
+ device = model_output.device
258
+ # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
259
+ # prev_sample = prev_sample + noise * sigma_up
260
+
261
+ if sigma_up > 0:
262
+ req_noise = (expected_prev_sample - prev_sample) / sigma_up
263
+ if update_epsilon_type == Epsilon_Update_Type.OVERRIDE:
264
+ self.noise_list[self.step_index] = req_noise
265
+ else:
266
+ for i in range(10):
267
+ n = torch.autograd.Variable(self.noise_list[self.step_index].detach().clone(), requires_grad=True)
268
+ loss = torch.norm(n - req_noise.detach())
269
+ loss.backward()
270
+ self.noise_list[self.step_index] -= n.grad.detach() * 1.8
271
+
272
+
273
+ prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
274
+
275
+ # Cast sample back to model compatible dtype
276
+ prev_sample = prev_sample.to(model_output.dtype)
277
+
278
+ # upon completion increase step index by one
279
+ self._step_index += 1
280
+
281
+ if not return_dict:
282
+ return (prev_sample,)
283
+
284
+ return EulerAncestralDiscreteSchedulerOutput(
285
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
286
+ )
287
+
288
+ def inv_step(
289
+ self,
290
+ model_output: torch.FloatTensor,
291
+ timestep: Union[float, torch.FloatTensor],
292
+ sample: torch.FloatTensor,
293
+ generator: Optional[torch.Generator] = None,
294
+ return_dict: bool = True,
295
+ ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
296
+ """
297
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
298
+ process from the learned model outputs (most often the predicted noise).
299
+
300
+ Args:
301
+ model_output (`torch.FloatTensor`):
302
+ The direct output from learned diffusion model.
303
+ timestep (`float`):
304
+ The current discrete timestep in the diffusion chain.
305
+ sample (`torch.FloatTensor`):
306
+ A current instance of a sample created by the diffusion process.
307
+ generator (`torch.Generator`, *optional*):
308
+ A random number generator.
309
+ return_dict (`bool`):
310
+ Whether or not to return a
311
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
312
+
313
+ Returns:
314
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
315
+ If return_dict is `True`,
316
+ [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
317
+ otherwise a tuple is returned where the first element is the sample tensor.
318
+
319
+ """
320
+
321
+ if (
322
+ isinstance(timestep, int)
323
+ or isinstance(timestep, torch.IntTensor)
324
+ or isinstance(timestep, torch.LongTensor)
325
+ ):
326
+ raise ValueError(
327
+ (
328
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
329
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
330
+ " one of the `scheduler.timesteps` as a timestep."
331
+ ),
332
+ )
333
+
334
+ if not self.is_scale_input_called:
335
+ logger.warning(
336
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
337
+ "See `StableDiffusionPipeline` for a usage example."
338
+ )
339
+
340
+ self._init_step_index(timestep.view((1)))
341
+
342
+ sigma = self.sigmas[self.step_index]
343
+
344
+ # Upcast to avoid precision issues when computing prev_sample
345
+ sample = sample.to(torch.float32)
346
+
347
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
348
+ if self.config.prediction_type == "epsilon":
349
+ pred_original_sample = sample - sigma * model_output
350
+ elif self.config.prediction_type == "v_prediction":
351
+ # * c_out + input * c_skip
352
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
353
+ elif self.config.prediction_type == "sample":
354
+ raise NotImplementedError("prediction_type not implemented yet: sample")
355
+ else:
356
+ raise ValueError(
357
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
358
+ )
359
+
360
+ sigma_from = self.sigmas[self.step_index]
361
+ sigma_to = self.sigmas[self.step_index+1]
362
+ # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
363
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
364
+ # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
365
+ sigma_down = sigma_to**2 / sigma_from
366
+
367
+ # 2. Convert to an ODE derivative
368
+ # derivative = (sample - pred_original_sample) / sigma
369
+ derivative = model_output
370
+
371
+ dt = sigma_down - sigma
372
+ # dt = sigma_down - sigma_from
373
+
374
+ prev_sample = sample - derivative * dt
375
+
376
+ device = model_output.device
377
+ # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
378
+ # prev_sample = prev_sample + noise * sigma_up
379
+
380
+ prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up
381
+
382
+ # Cast sample back to model compatible dtype
383
+ prev_sample = prev_sample.to(model_output.dtype)
384
+
385
+ # upon completion increase step index by one
386
+ self._step_index += 1
387
+
388
+ if not return_dict:
389
+ return (prev_sample,)
390
+
391
+ return EulerAncestralDiscreteSchedulerOutput(
392
+ prev_sample=prev_sample, pred_original_sample=pred_original_sample
393
+ )
394
+
395
+ def get_all_sigmas(self) -> torch.FloatTensor:
396
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
397
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
398
+ return torch.from_numpy(sigmas)
399
+
400
+ def add_noise_off_schedule(
401
+ self,
402
+ original_samples: torch.FloatTensor,
403
+ noise: torch.FloatTensor,
404
+ timesteps: torch.FloatTensor,
405
+ ) -> torch.FloatTensor:
406
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
407
+ sigmas = self.get_all_sigmas()
408
+ sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
409
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
410
+ # mps does not support float64
411
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
412
+ else:
413
+ timesteps = timesteps.to(original_samples.device)
414
+
415
+ step_indices = 1000 - int(timesteps.item())
416
+
417
+ sigma = sigmas[step_indices].flatten()
418
+ while len(sigma.shape) < len(original_samples.shape):
419
+ sigma = sigma.unsqueeze(-1)
420
+
421
+ noisy_samples = original_samples + noise * sigma
422
+ return noisy_samples
423
+
424
+ # def update_noise_for_friendly_inversion(
425
+ # self,
426
+ # model_output: torch.FloatTensor,
427
+ # timestep: Union[float, torch.FloatTensor],
428
+ # z_t: torch.FloatTensor,
429
+ # z_tp1: torch.FloatTensor,
430
+ # return_dict: bool = True,
431
+ # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
432
+ # if (
433
+ # isinstance(timestep, int)
434
+ # or isinstance(timestep, torch.IntTensor)
435
+ # or isinstance(timestep, torch.LongTensor)
436
+ # ):
437
+ # raise ValueError(
438
+ # (
439
+ # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
440
+ # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
441
+ # " one of the `scheduler.timesteps` as a timestep."
442
+ # ),
443
+ # )
444
+
445
+ # if not self.is_scale_input_called:
446
+ # logger.warning(
447
+ # "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
448
+ # "See `StableDiffusionPipeline` for a usage example."
449
+ # )
450
+
451
+ # self._init_step_index(timestep.view((1)))
452
+
453
+ # sigma = self.sigmas[self.step_index]
454
+
455
+ # sigma_from = self.sigmas[self.step_index]
456
+ # sigma_to = self.sigmas[self.step_index+1]
457
+ # # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
458
+ # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2).abs() / sigma_from**2) ** 0.5
459
+ # # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
460
+ # sigma_down = sigma_to**2 / sigma_from
461
+
462
+ # # 2. Conv = (sample - pred_original_sample) / sigma
463
+ # derivative = model_output
464
+
465
+ # dt = sigma_down - sigma
466
+ # # dt = sigma_down - sigma_from
467
+
468
+ # prev_sample = z_t - derivative * dt
469
+
470
+ # if sigma_up > 0:
471
+ # self.noise_list[self.step_index] = (prev_sample - z_tp1) / sigma_up
472
+
473
+ # prev_sample = prev_sample - self.noise_list[self.step_index] * sigma_up
474
+
475
+
476
+ # if not return_dict:
477
+ # return (prev_sample,)
478
+
479
+ # return EulerAncestralDiscreteSchedulerOutput(
480
+ # prev_sample=prev_sample, pred_original_sample=None
481
+ # )
482
+
483
+
484
+ # def step_friendly_inversion(
485
+ # self,
486
+ # model_output: torch.FloatTensor,
487
+ # timestep: Union[float, torch.FloatTensor],
488
+ # sample: torch.FloatTensor,
489
+ # generator: Optional[torch.Generator] = None,
490
+ # return_dict: bool = True,
491
+ # expected_next_sample: torch.FloatTensor = None,
492
+ # ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
493
+ # """
494
+ # Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
495
+ # process from the learned model outputs (most often the predicted noise).
496
+
497
+ # Args:
498
+ # model_output (`torch.FloatTensor`):
499
+ # The direct output from learned diffusion model.
500
+ # timestep (`float`):
501
+ # The current discrete timestep in the diffusion chain.
502
+ # sample (`torch.FloatTensor`):
503
+ # A current instance of a sample created by the diffusion process.
504
+ # generator (`torch.Generator`, *optional*):
505
+ # A random number generator.
506
+ # return_dict (`bool`):
507
+ # Whether or not to return a
508
+ # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
509
+
510
+ # Returns:
511
+ # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
512
+ # If return_dict is `True`,
513
+ # [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
514
+ # otherwise a tuple is returned where the first element is the sample tensor.
515
+
516
+ # """
517
+
518
+ # if (
519
+ # isinstance(timestep, int)
520
+ # or isinstance(timestep, torch.IntTensor)
521
+ # or isinstance(timestep, torch.LongTensor)
522
+ # ):
523
+ # raise ValueError(
524
+ # (
525
+ # "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
526
+ # " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
527
+ # " one of the `scheduler.timesteps` as a timestep."
528
+ # ),
529
+ # )
530
+
531
+ # if not self.is_scale_input_called:
532
+ # logger.warning(
533
+ # "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
534
+ # "See `StableDiffusionPipeline` for a usage example."
535
+ # )
536
+
537
+ # self._init_step_index(timestep.view((1)))
538
+
539
+ # sigma = self.sigmas[self.step_index]
540
+
541
+ # # Upcast to avoid precision issues when computing prev_sample
542
+ # sample = sample.to(torch.float32)
543
+
544
+ # # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
545
+ # if self.config.prediction_type == "epsilon":
546
+ # pred_original_sample = sample - sigma * model_output
547
+ # elif self.config.prediction_type == "v_prediction":
548
+ # # * c_out + input * c_skip
549
+ # pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
550
+ # elif self.config.prediction_type == "sample":
551
+ # raise NotImplementedError("prediction_type not implemented yet: sample")
552
+ # else:
553
+ # raise ValueError(
554
+ # f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
555
+ # )
556
+
557
+ # sigma_from = self.sigmas[self.step_index]
558
+ # sigma_to = self.sigmas[self.step_index + 1]
559
+ # sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
560
+ # sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
561
+
562
+ # # 2. Convert to an ODE derivative
563
+ # # derivative = (sample - pred_original_sample) / sigma
564
+ # derivative = model_output
565
+
566
+ # dt = sigma_down - sigma
567
+
568
+ # prev_sample = sample + derivative * dt
569
+
570
+ # device = model_output.device
571
+ # # noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)
572
+ # # prev_sample = prev_sample + noise * sigma_up
573
+
574
+ # if sigma_up > 0:
575
+ # self.noise_list[self.step_index] = (expected_next_sample - prev_sample) / sigma_up
576
+
577
+ # prev_sample = prev_sample + self.noise_list[self.step_index] * sigma_up
578
+
579
+ # # Cast sample back to model compatible dtype
580
+ # prev_sample = prev_sample.to(model_output.dtype)
581
+
582
+ # # upon completion increase step index by one
583
+ # self._step_index += 1
584
+
585
+ # if not return_dict:
586
+ # return (prev_sample,)
587
+
588
+ # return EulerAncestralDiscreteSchedulerOutput(
589
+ # prev_sample=prev_sample, pred_original_sample=pred_original_sample
590
  # )