JiminHeo commited on
Commit
25e9acb
·
1 Parent(s): 1e39b03
ldm/guided_diffusion/h_posterior.py CHANGED
@@ -60,10 +60,11 @@ class HPosterior(object):
60
  mean = (var) * ( (alpha_s_t_star/scale_s_t_star**2) * z_t_star + (alpha_t_s/scale_t_s**2) * z_t )
61
  return mean, torch.sqrt(var)
62
 
 
63
  def register_buffer(self, name, attr):
64
- if type(attr) == torch.Tensor:
65
- if attr.device != torch.device("cuda"):
66
- attr = attr.to(torch.device("cuda"))
67
  setattr(self, name, attr)
68
 
69
  def get_error(self,x,t,c, unconditional_conditioning, unconditional_guidance_scale):
 
60
  mean = (var) * ( (alpha_s_t_star/scale_s_t_star**2) * z_t_star + (alpha_t_s/scale_t_s**2) * z_t )
61
  return mean, torch.sqrt(var)
62
 
63
+
64
  def register_buffer(self, name, attr):
65
+ if isinstance(attr, torch.Tensor):
66
+ if not attr.is_cuda:
67
+ attr = attr.cuda()
68
  setattr(self, name, attr)
69
 
70
  def get_error(self,x,t,c, unconditional_conditioning, unconditional_guidance_scale):
ldm/models/diffusion/ddim.py CHANGED
@@ -16,9 +16,8 @@ class DDIMSampler(object):
16
  self.schedule = schedule
17
 
18
  def register_buffer(self, name, attr):
19
- if type(attr) == torch.Tensor:
20
- if attr.device != torch.device("cuda"):
21
- attr = attr.to(torch.device("cuda"))
22
  setattr(self, name, attr)
23
 
24
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
 
16
  self.schedule = schedule
17
 
18
  def register_buffer(self, name, attr):
19
+ if isinstance(attr, torch.Tensor) and not attr.is_cuda:
20
+ attr = attr.cuda()
 
21
  setattr(self, name, attr)
22
 
23
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
ldm/models/diffusion/plms.py CHANGED
@@ -16,9 +16,8 @@ class PLMSSampler(object):
16
  self.schedule = schedule
17
 
18
  def register_buffer(self, name, attr):
19
- if type(attr) == torch.Tensor:
20
- if attr.device != torch.device("cuda"):
21
- attr = attr.to(torch.device("cuda"))
22
  setattr(self, name, attr)
23
 
24
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
 
16
  self.schedule = schedule
17
 
18
  def register_buffer(self, name, attr):
19
+ if isinstance(attr, torch.Tensor) and not attr.is_cuda:
20
+ attr = attr.cuda()
 
21
  setattr(self, name, attr)
22
 
23
  def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):