davidvgilmore committed
Commit ac146c5 · verified · 1 Parent(s): 1b41729

Upload hy3dgen/shapegen/schedulers.py with huggingface_hub

Files changed (1)
  1. hy3dgen/shapegen/schedulers.py +307 -0
hy3dgen/shapegen/schedulers.py ADDED
@@ -0,0 +1,307 @@
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput, logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of the previous timestep. `prev_sample` should be used as the next model
            input in the denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler for flow matching.

    NOTE: this is very similar to `diffusers.FlowMatchEulerDiscreteScheduler`, except that our timesteps are
    reversed (the sigmas ascend rather than descend).

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
    generic methods the library implements for all schedulers, such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps used to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
        use_dynamic_shifting (`bool`, defaults to `False`):
            Whether to apply the timestep shifting on the fly (via `mu` in `set_timesteps`) instead of baking the
            static `shift` into the sigmas at construction time.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
    ):
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32).copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, the shifting is instead applied on the fly in `set_timesteps`,
            # based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
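            # e.g. with shift=3.0, a mid-schedule sigma of 0.5 maps to 3.0 * 0.5 / (1 + 2.0 * 0.5) = 0.75,
            # so shifts > 1 push the whole schedule toward the high-noise end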

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        # N.B. the sigmas ascend in this reversed schedule, so `sigma_min` holds the last (largest) entry
        # and `sigma_max` the first (smallest); `set_timesteps` uses them in this order to build an
        # ascending timestep grid
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow matching.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep(s) in the diffusion chain.
            noise (`torch.FloatTensor`, *optional*):
                The noise to blend into the sample.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

        if sample.device.type == "mps" and torch.is_floating_point(timestep):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
            timestep = timestep.to(sample.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timestep = timestep.to(sample.device)

        # self.begin_index is None when the scheduler is used for training, or when the pipeline does not
        # implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
        elif self.step_index is not None:
            # add_noise is called after the first denoising step (for inpainting)
            step_indices = [self.step_index] * timestep.shape[0]
        else:
            # add_noise is called before the first denoising step to create the initial latent (img2img)
            step_indices = [self.begin_index] * timestep.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)

        sample = sigma * noise + (1.0 - sigma) * sample
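        # Linear flow-matching interpolation: x_sigma = sigma * noise + (1 - sigma) * x0,
        # so sigma=0 returns the clean sample and sigma=1 returns pure noise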

        return sample

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
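
    # With exponent sigma=1.0 (how `set_timesteps` calls it), time_shift is the identity at mu=0:
    # exp(0) / (exp(0) + (1/t - 1)) = t. Larger mu pushes every t toward 1 (more noise).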

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
            sigmas (`List[float]`, *optional*):
                Custom sigmas for the schedule. If `None`, they are computed from `num_inference_steps`.
            mu (`float`, *optional*):
                The shift parameter used when `use_dynamic_shifting` is enabled.
        """

        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError("you have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`")

        if sigmas is None:
            self.num_inference_steps = num_inference_steps
            timesteps = np.linspace(
                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
            )

            sigmas = timesteps / self.config.num_train_timesteps

        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        timesteps = sigmas * self.config.num_train_timesteps

        self.timesteps = timesteps.to(device=device)
        # append a terminal sigma of 1.0 so `step` can index sigmas[step_index + 1] at the last step
        self.sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1).
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image).
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
            s_tmin (`float`):
            s_tmax (`float`):
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
                If `return_dict` is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        prev_sample = sample + (sigma_next - sigma) * model_output
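        # Euler step along sigma: for the linear path used in `scale_noise`, the model output is a
        # velocity v ≈ d(x_sigma)/d(sigma) = noise - x0, so one step advances the sample by
        # (sigma_next - sigma) * v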

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
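
For orientation, a minimal sketch of how the uploaded scheduler is driven in a denoising loop. The zero-velocity `model` stub, the 50-step count, and the latent shape are illustrative assumptions, not part of this commit; a real pipeline would call its trained flow model instead.

import torch

from hy3dgen.shapegen.schedulers import FlowMatchEulerDiscreteScheduler

def model(x, t):
    # stand-in for the trained flow model, which would predict a velocity field
    return torch.zeros_like(x)

scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1.0)
scheduler.set_timesteps(num_inference_steps=50, device="cpu")

latents = torch.randn(1, 64, 1024)  # illustrative latent shape
for t in scheduler.timesteps:
    velocity = model(latents, t)
    latents = scheduler.step(velocity, t, latents).prev_sample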