Spaces:
Running
on
A10G
Running
on
A10G
comment out casting to float32
Browse files- src/euler_scheduler.py +4 -4
src/euler_scheduler.py
CHANGED
|
@@ -120,7 +120,7 @@ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
|
| 120 |
sigma = self.sigmas[self.step_index]
|
| 121 |
|
| 122 |
# Upcast to avoid precision issues when computing prev_sample
|
| 123 |
-
sample = sample.to(torch.float32)
|
| 124 |
|
| 125 |
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
| 126 |
if self.config.prediction_type == "epsilon":
|
|
@@ -226,7 +226,7 @@ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
|
| 226 |
sigma = self.sigmas[self.step_index]
|
| 227 |
|
| 228 |
# Upcast to avoid precision issues when computing prev_sample
|
| 229 |
-
sample = sample.to(torch.float32)
|
| 230 |
|
| 231 |
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
| 232 |
if self.config.prediction_type == "epsilon":
|
|
@@ -394,7 +394,7 @@ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
|
| 394 |
|
| 395 |
def get_all_sigmas(self) -> torch.FloatTensor:
|
| 396 |
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
| 397 |
-
sigmas = np.concatenate([sigmas[::-1], [0.0]])
|
| 398 |
return torch.from_numpy(sigmas)
|
| 399 |
|
| 400 |
def add_noise_off_schedule(
|
|
@@ -408,7 +408,7 @@ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
|
| 408 |
sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
| 409 |
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
| 410 |
# mps does not support float64
|
| 411 |
-
timesteps = timesteps.to(original_samples
|
| 412 |
else:
|
| 413 |
timesteps = timesteps.to(original_samples.device)
|
| 414 |
|
|
|
|
| 120 |
sigma = self.sigmas[self.step_index]
|
| 121 |
|
| 122 |
# Upcast to avoid precision issues when computing prev_sample
|
| 123 |
+
# sample = sample.to(torch.float32)
|
| 124 |
|
| 125 |
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
| 126 |
if self.config.prediction_type == "epsilon":
|
|
|
|
| 226 |
sigma = self.sigmas[self.step_index]
|
| 227 |
|
| 228 |
# Upcast to avoid precision issues when computing prev_sample
|
| 229 |
+
# sample = sample.to(torch.float32)
|
| 230 |
|
| 231 |
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
| 232 |
if self.config.prediction_type == "epsilon":
|
|
|
|
| 394 |
|
| 395 |
def get_all_sigmas(self) -> torch.FloatTensor:
|
| 396 |
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
| 397 |
+
sigmas = np.concatenate([sigmas[::-1], [0.0]])#.astype(np.float32)
|
| 398 |
return torch.from_numpy(sigmas)
|
| 399 |
|
| 400 |
def add_noise_off_schedule(
|
|
|
|
| 408 |
sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
| 409 |
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
| 410 |
# mps does not support float64
|
| 411 |
+
timesteps = timesteps.to(original_samples)#.device, dtype=torch.float32)
|
| 412 |
else:
|
| 413 |
timesteps = timesteps.to(original_samples.device)
|
| 414 |
|