vilarin committed · verified · Commit 17931cc · 1 Parent(s): 038406c

Update app.py

Files changed (1): app.py (+5 -4)
app.py CHANGED
@@ -80,7 +80,7 @@ class ModelWrapper:
     @spaces.GPU()
     def sample(self, noise, unet_added_conditions, prompt_embed, fast_vae_decode):
         alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
-
+        print(alphas_cumprod)
         if self.num_step == 1:
             all_timesteps = [self.conditioning_timestep]
             step_interval = 0
@@ -90,9 +90,9 @@ class ModelWrapper:
         else:
             raise NotImplementedError()
 
-        noise = noise.to(device="cuda", dtype=torch.float16)
+        noise = noise.to(torch.float16)
         print(f'noise: {noise.dtype}')
-        prompt_embed = prompt_embed.to(device="cuda", dtype=torch.float16)
+        prompt_embed = prompt_embed.to(torch.float16)
         DTYPE = prompt_embed.dtype
         print(f'prompt_embed: {DTYPE}')
 
@@ -145,10 +145,11 @@ class ModelWrapper:
         }
 
 
-
+        print(f'noise: {noise.dtype}')
         print(f'prompt: {batch_prompt_embeds.dtype}')
         print(unet_added_conditions['time_ids'].dtype)
         print(unet_added_conditions['text_embeds'].dtype)
+        print("________")
 
 
         eval_images = self.sample(noise=noise, unet_added_conditions=unet_added_conditions, prompt_embed=batch_prompt_embeds, fast_vae_decode=fast_vae_decode)
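A note on the two dtype casts above (an observation about the diff, not a rationale stated in the commit): noise.to(torch.float16) changes only the dtype and leaves the tensor on whatever device it already occupies, whereas the previous noise.to(device="cuda", dtype=torch.float16) also forced a copy to the GPU. A minimal sketch of the difference, using an illustrative tensor rather than the app's actual inputs:

import torch

# Minimal sketch (assumed shapes, not app.py's actual noise input).
noise = torch.randn(1, 4, 64, 64)           # created on CPU here for illustration

cast_only = noise.to(torch.float16)         # new form: dtype changes, device stays put
print(cast_only.dtype, cast_only.device)    # -> torch.float16 cpu

if torch.cuda.is_available():
    # old form from the diff: cast and move to the GPU in one call
    cast_and_move = noise.to(device="cuda", dtype=torch.float16)
    print(cast_and_move.dtype, cast_and_move.device)  # -> torch.float16 cuda:0

Under the @spaces.GPU() decorator (Hugging Face ZeroGPU), leaving device placement to the surrounding code rather than hard-coding "cuda" is one plausible reading of the change; the added print() calls simply trace alphas_cumprod and the tensor dtypes for debugging.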