dbaranchuk committed on
Commit
2527be9
·
verified ·
1 Parent(s): ec87bf7

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +3 -46
inference.py CHANGED
@@ -21,43 +21,6 @@ else:
21
  XLA_AVAILABLE = False
22
 
23
 
24
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler``'s timestep schedule and return it.

    Exactly one of three inputs drives the schedule: an explicit
    ``timesteps`` list, an explicit ``sigmas`` list, or a plain
    ``num_inference_steps`` count. Custom lists are only forwarded when
    the scheduler's ``set_timesteps`` signature actually accepts them.

    Args:
        scheduler: Scheduler exposing ``set_timesteps`` and ``timesteps``.
        num_inference_steps: Number of steps when no custom schedule is given.
        device: Device passed through to ``scheduler.set_timesteps``.
        timesteps: Optional custom timestep schedule (mutually exclusive
            with ``sigmas``).
        sigmas: Optional custom sigma schedule (mutually exclusive with
            ``timesteps``).
        **kwargs: Extra keyword arguments forwarded to ``set_timesteps``.

    Returns:
        A ``(timesteps, num_inference_steps)`` tuple; when a custom
        schedule was supplied, ``num_inference_steps`` is its length.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are given, or if
            the scheduler does not support the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _scheduler_accepts(param_name: str) -> bool:
        # Not every scheduler's `set_timesteps` supports custom schedules;
        # probe its signature before forwarding one.
        return param_name in inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if not _scheduler_accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        if not _scheduler_accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        # Plain step count: the caller-provided count is returned unchanged.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule path: the effective step count is the schedule length.
    schedule = scheduler.timesteps
    return schedule, len(schedule)
59
-
60
-
61
  @torch.no_grad()
62
  def run(
63
  self,
@@ -68,6 +31,7 @@ def run(
68
  width: Optional[int] = None,
69
  num_inference_steps: int = 28,
70
  sigmas: Optional[List[float]] = None,
 
71
  scales: List[float] = None,
72
  guidance_scale: float = 7.0,
73
  negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -196,13 +160,6 @@ def run(
196
  scheduler_kwargs["mu"] = mu
197
  elif mu is not None:
198
  scheduler_kwargs["mu"] = mu
199
- timesteps, num_inference_steps = retrieve_timesteps(
200
- self.scheduler,
201
- num_inference_steps,
202
- device,
203
- sigmas=sigmas,
204
- **scheduler_kwargs,
205
- )
206
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
207
  self._num_timesteps = len(timesteps)
208
 
@@ -269,8 +226,8 @@ def run(
269
 
270
  # compute the previous noisy sample x_t -> x_t-1
271
  latents_dtype = latents.dtype
272
- sigma = self.scheduler.sigmas[i]
273
- sigma_next = self.scheduler.sigmas[i + 1]
274
  x0_pred = (latents - sigma * noise_pred)
275
  try:
276
  x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1])
 
21
  XLA_AVAILABLE = False
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  @torch.no_grad()
25
  def run(
26
  self,
 
31
  width: Optional[int] = None,
32
  num_inference_steps: int = 28,
33
  sigmas: Optional[List[float]] = None,
34
+ timesteps: Optional[List[float]] = None,
35
  scales: List[float] = None,
36
  guidance_scale: float = 7.0,
37
  negative_prompt: Optional[Union[str, List[str]]] = None,
 
160
  scheduler_kwargs["mu"] = mu
161
  elif mu is not None:
162
  scheduler_kwargs["mu"] = mu
 
 
 
 
 
 
 
163
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
164
  self._num_timesteps = len(timesteps)
165
 
 
226
 
227
  # compute the previous noisy sample x_t -> x_t-1
228
  latents_dtype = latents.dtype
229
+ sigma = sigmas[i]
230
+ sigma_next = sigmas[i + 1]
231
  x0_pred = (latents - sigma * noise_pred)
232
  try:
233
  x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1])