vilarin committed
Commit 5764a43 · verified · 1 parent: 4d88a56

Update app.py

Files changed (1):
  1. app.py +2 -5
app.py CHANGED
@@ -91,9 +91,6 @@ class ModelWrapper:
 
         DTYPE = prompt_embed.dtype
         print(DTYPE)
-        print(type(noise))
-        print(type(current_timesteps))
-        print(type(unet_added_conditions))
 
         for constant in all_timesteps:
             current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
@@ -124,7 +121,7 @@ class ModelWrapper:
 
         add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)
 
-        noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator).to(device="cuda", dtype=self.DTYPE)
+        noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator).to(device="cuda", dtype=torch.float16)
 
         prompt_inputs = self._encode_prompt(prompt)
 
@@ -161,7 +158,7 @@ def get_x0_from_noise(sample, model_output, alphas_cumprod, timestep):
         return pred_original_sample
 
 class SDXLTextEncoder(torch.nn.Module):
-    def __init__(self, model_id, revision, accelerator, dtype=torch.float32):
+    def __init__(self, model_id, revision, accelerator, dtype=torch.float16):
         super().__init__()
 
         self.text_encoder_one = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision=revision).to(0).to(dtype=dtype)
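For context, both substantive changes push tensors to half precision. Below is a minimal, self-contained sketch of what the updated noise line does; num_images, height, width, and the VAE downsample factor of 8 are illustrative assumptions rather than values from this commit, and a CUDA device is assumed to be available.

import torch

# Minimal sketch of the updated latent-noise line. The sizes below are
# illustrative assumptions, not values taken from this commit.
num_images = 2
height, width = 1024, 1024
vae_downsample_ratio = 8  # assumed SDXL VAE spatial downsampling factor

generator = torch.Generator().manual_seed(42)  # seeded CPU generator for reproducible latents
noise = torch.randn(
    num_images,
    4,  # latent channels
    height // vae_downsample_ratio,
    width // vae_downsample_ratio,
    generator=generator,
).to(device="cuda", dtype=torch.float16)

print(noise.shape, noise.dtype)  # torch.Size([2, 4, 128, 128]) torch.float16

Sampling with a CPU generator and converting afterwards keeps the random sequence independent of the target dtype; note that the dtype must be the qualified torch.float16, since a bare float16 is undefined unless explicitly imported from torch.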
 
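The companion change makes torch.float16 the default dtype for SDXLTextEncoder, matching the noise tensor above. Here is a minimal sketch of what that default implies when loading one of the encoders; the model_id is an illustrative assumption (the hunk does not name a checkpoint), and GPU 0 is assumed:

import torch
from transformers import CLIPTextModel

# Minimal sketch: one SDXL text encoder loaded on GPU 0 in half precision.
# model_id is an illustrative assumption, not taken from this commit.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
text_encoder_one = (
    CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    .to(0)                    # move to cuda:0, as in the diff
    .to(dtype=torch.float16)  # half precision halves the memory footprint
)

print(next(text_encoder_one.parameters()).dtype)  # torch.float16

Keeping the encoder in fp16 also keeps it aligned with the fp16 latents, since DTYPE in ModelWrapper is derived from prompt_embed.dtype.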