Spaces:
Runtime error
Runtime error
comment out casting to float32
Browse files- src/euler_scheduler.py +4 -4
src/euler_scheduler.py
CHANGED
@@ -120,7 +120,7 @@ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
|
120 |
sigma = self.sigmas[self.step_index]
|
121 |
|
122 |
# Upcast to avoid precision issues when computing prev_sample
|
123 |
-
sample = sample.to(torch.float32)
|
124 |
|
125 |
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
126 |
if self.config.prediction_type == "epsilon":
|
@@ -226,7 +226,7 @@ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
|
226 |
sigma = self.sigmas[self.step_index]
|
227 |
|
228 |
# Upcast to avoid precision issues when computing prev_sample
|
229 |
-
sample = sample.to(torch.float32)
|
230 |
|
231 |
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
232 |
if self.config.prediction_type == "epsilon":
|
@@ -394,7 +394,7 @@ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
|
394 |
|
395 |
def get_all_sigmas(self) -> torch.FloatTensor:
|
396 |
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
397 |
-
sigmas = np.concatenate([sigmas[::-1], [0.0]])
|
398 |
return torch.from_numpy(sigmas)
|
399 |
|
400 |
def add_noise_off_schedule(
|
@@ -408,7 +408,7 @@ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
|
408 |
sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
409 |
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
410 |
# mps does not support float64
|
411 |
-
timesteps = timesteps.to(original_samples
|
412 |
else:
|
413 |
timesteps = timesteps.to(original_samples.device)
|
414 |
|
|
|
120 |
sigma = self.sigmas[self.step_index]
|
121 |
|
122 |
# Upcast to avoid precision issues when computing prev_sample
|
123 |
+
# sample = sample.to(torch.float32)
|
124 |
|
125 |
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
126 |
if self.config.prediction_type == "epsilon":
|
|
|
226 |
sigma = self.sigmas[self.step_index]
|
227 |
|
228 |
# Upcast to avoid precision issues when computing prev_sample
|
229 |
+
# sample = sample.to(torch.float32)
|
230 |
|
231 |
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
232 |
if self.config.prediction_type == "epsilon":
|
|
|
394 |
|
395 |
def get_all_sigmas(self) -> torch.FloatTensor:
|
396 |
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
397 |
+
sigmas = np.concatenate([sigmas[::-1], [0.0]])#.astype(np.float32)
|
398 |
return torch.from_numpy(sigmas)
|
399 |
|
400 |
def add_noise_off_schedule(
|
|
|
408 |
sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
409 |
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
410 |
# mps does not support float64
|
411 |
+
timesteps = timesteps.to(original_samples)#.device, dtype=torch.float32)
|
412 |
else:
|
413 |
timesteps = timesteps.to(original_samples.device)
|
414 |
|