Barak1 commited on
Commit
fc20098
·
1 Parent(s): 159f668

comment out casting to float32

Browse files
Files changed (1) hide show
  1. src/euler_scheduler.py +4 -4
src/euler_scheduler.py CHANGED
@@ -120,7 +120,7 @@ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
120
  sigma = self.sigmas[self.step_index]
121
 
122
  # Upcast to avoid precision issues when computing prev_sample
123
- sample = sample.to(torch.float32)
124
 
125
  # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
126
  if self.config.prediction_type == "epsilon":
@@ -226,7 +226,7 @@ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
226
  sigma = self.sigmas[self.step_index]
227
 
228
  # Upcast to avoid precision issues when computing prev_sample
229
- sample = sample.to(torch.float32)
230
 
231
  # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
232
  if self.config.prediction_type == "epsilon":
@@ -394,7 +394,7 @@ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
394
 
395
  def get_all_sigmas(self) -> torch.FloatTensor:
396
  sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
397
- sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
398
  return torch.from_numpy(sigmas)
399
 
400
  def add_noise_off_schedule(
@@ -408,7 +408,7 @@ class MyEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
408
  sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
409
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
410
  # mps does not support float64
411
- timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
412
  else:
413
  timesteps = timesteps.to(original_samples.device)
414
 
 
120
  sigma = self.sigmas[self.step_index]
121
 
122
  # Upcast to avoid precision issues when computing prev_sample
123
+ # sample = sample.to(torch.float32)
124
 
125
  # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
126
  if self.config.prediction_type == "epsilon":
 
226
  sigma = self.sigmas[self.step_index]
227
 
228
  # Upcast to avoid precision issues when computing prev_sample
229
+ # sample = sample.to(torch.float32)
230
 
231
  # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
232
  if self.config.prediction_type == "epsilon":
 
394
 
395
  def get_all_sigmas(self) -> torch.FloatTensor:
396
  sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
397
+ sigmas = np.concatenate([sigmas[::-1], [0.0]])#.astype(np.float32)
398
  return torch.from_numpy(sigmas)
399
 
400
  def add_noise_off_schedule(
 
408
  sigmas = sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
409
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
410
  # mps does not support float64
411
+ timesteps = timesteps.to(original_samples)#.device, dtype=torch.float32)
412
  else:
413
  timesteps = timesteps.to(original_samples.device)
414