Spaces:
Running
on
Zero
Running
on
Zero
Update inference.py
Browse files- inference.py +3 -46
inference.py
CHANGED
@@ -21,43 +21,6 @@ else:
|
|
21 |
XLA_AVAILABLE = False
|
22 |
|
23 |
|
24 |
-
def retrieve_timesteps(
|
25 |
-
scheduler,
|
26 |
-
num_inference_steps: Optional[int] = None,
|
27 |
-
device: Optional[Union[str, torch.device]] = None,
|
28 |
-
timesteps: Optional[List[int]] = None,
|
29 |
-
sigmas: Optional[List[float]] = None,
|
30 |
-
**kwargs,
|
31 |
-
):
|
32 |
-
if timesteps is not None and sigmas is not None:
|
33 |
-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
34 |
-
if timesteps is not None:
|
35 |
-
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
36 |
-
if not accepts_timesteps:
|
37 |
-
raise ValueError(
|
38 |
-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
39 |
-
f" timestep schedules. Please check whether you are using the correct scheduler."
|
40 |
-
)
|
41 |
-
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
42 |
-
timesteps = scheduler.timesteps
|
43 |
-
num_inference_steps = len(timesteps)
|
44 |
-
elif sigmas is not None:
|
45 |
-
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
46 |
-
if not accept_sigmas:
|
47 |
-
raise ValueError(
|
48 |
-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
49 |
-
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
50 |
-
)
|
51 |
-
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
52 |
-
timesteps = scheduler.timesteps
|
53 |
-
num_inference_steps = len(timesteps)
|
54 |
-
else:
|
55 |
-
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
56 |
-
timesteps = scheduler.timesteps
|
57 |
-
|
58 |
-
return timesteps, num_inference_steps
|
59 |
-
|
60 |
-
|
61 |
@torch.no_grad()
|
62 |
def run(
|
63 |
self,
|
@@ -68,6 +31,7 @@ def run(
|
|
68 |
width: Optional[int] = None,
|
69 |
num_inference_steps: int = 28,
|
70 |
sigmas: Optional[List[float]] = None,
|
|
|
71 |
scales: List[float] = None,
|
72 |
guidance_scale: float = 7.0,
|
73 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
@@ -196,13 +160,6 @@ def run(
|
|
196 |
scheduler_kwargs["mu"] = mu
|
197 |
elif mu is not None:
|
198 |
scheduler_kwargs["mu"] = mu
|
199 |
-
timesteps, num_inference_steps = retrieve_timesteps(
|
200 |
-
self.scheduler,
|
201 |
-
num_inference_steps,
|
202 |
-
device,
|
203 |
-
sigmas=sigmas,
|
204 |
-
**scheduler_kwargs,
|
205 |
-
)
|
206 |
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
207 |
self._num_timesteps = len(timesteps)
|
208 |
|
@@ -269,8 +226,8 @@ def run(
|
|
269 |
|
270 |
# compute the previous noisy sample x_t -> x_t-1
|
271 |
latents_dtype = latents.dtype
|
272 |
-
sigma =
|
273 |
-
sigma_next =
|
274 |
x0_pred = (latents - sigma * noise_pred)
|
275 |
try:
|
276 |
x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1])
|
|
|
21 |
XLA_AVAILABLE = False
|
22 |
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
@torch.no_grad()
|
25 |
def run(
|
26 |
self,
|
|
|
31 |
width: Optional[int] = None,
|
32 |
num_inference_steps: int = 28,
|
33 |
sigmas: Optional[List[float]] = None,
|
34 |
+
timesteps: Optional[List[float]] = None,
|
35 |
scales: List[float] = None,
|
36 |
guidance_scale: float = 7.0,
|
37 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
|
|
160 |
scheduler_kwargs["mu"] = mu
|
161 |
elif mu is not None:
|
162 |
scheduler_kwargs["mu"] = mu
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
164 |
self._num_timesteps = len(timesteps)
|
165 |
|
|
|
226 |
|
227 |
# compute the previous noisy sample x_t -> x_t-1
|
228 |
latents_dtype = latents.dtype
|
229 |
+
sigma = sigmas[i]
|
230 |
+
sigma_next = sigmas[i + 1]
|
231 |
x0_pred = (latents - sigma * noise_pred)
|
232 |
try:
|
233 |
x0_pred = torch.nn.functional.interpolate(x0_pred, size=scales[i + 1])
|