Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -77,7 +77,7 @@ class ModelWrapper:
|
|
| 77 |
def _get_time():
|
| 78 |
return time.time()
|
| 79 |
|
| 80 |
-
@spaces.GPU()
|
| 81 |
def sample(self, noise, unet_added_conditions, prompt_embed, fast_vae_decode):
|
| 82 |
alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
|
| 83 |
|
|
@@ -89,13 +89,15 @@ class ModelWrapper:
|
|
| 89 |
step_interval = 250
|
| 90 |
else:
|
| 91 |
raise NotImplementedError()
|
| 92 |
-
|
|
|
|
|
|
|
| 93 |
DTYPE = prompt_embed.dtype
|
| 94 |
print(f'prompt_embed: {DTYPE}')
|
| 95 |
|
| 96 |
for constant in all_timesteps:
|
| 97 |
current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
|
| 98 |
-
current_timesteps = current_timesteps.to(torch.float16)
|
| 99 |
print(f'current_timestpes: {current_timesteps.dtype}')
|
| 100 |
eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
|
| 101 |
print(type(eval_images))
|
|
@@ -124,7 +126,7 @@ class ModelWrapper:
|
|
| 124 |
|
| 125 |
add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)
|
| 126 |
|
| 127 |
-
noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator)
|
| 128 |
|
| 129 |
prompt_inputs = self._encode_prompt(prompt)
|
| 130 |
|
|
@@ -143,7 +145,7 @@ class ModelWrapper:
|
|
| 143 |
}
|
| 144 |
|
| 145 |
|
| 146 |
-
|
| 147 |
print(f'prompt: {batch_prompt_embeds.dtype}')
|
| 148 |
print(unet_added_conditions['time_ids'].dtype)
|
| 149 |
print(unet_added_conditions['text_embeds'].dtype)
|
|
|
|
| 77 |
def _get_time():
|
| 78 |
return time.time()
|
| 79 |
|
| 80 |
+
@spaces.GPU(duration=100)
|
| 81 |
def sample(self, noise, unet_added_conditions, prompt_embed, fast_vae_decode):
|
| 82 |
alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
|
| 83 |
|
|
|
|
| 89 |
step_interval = 250
|
| 90 |
else:
|
| 91 |
raise NotImplementedError()
|
| 92 |
+
|
| 93 |
+
noise = noise.to(device="cuda", dtype=torch.float16)
|
| 94 |
+
print(f'noise: {noise.dtype}')
|
| 95 |
DTYPE = prompt_embed.dtype
|
| 96 |
print(f'prompt_embed: {DTYPE}')
|
| 97 |
|
| 98 |
for constant in all_timesteps:
|
| 99 |
current_timesteps = torch.ones(len(prompt_embed), device="cuda", dtype=torch.long) * constant
|
| 100 |
+
#current_timesteps = current_timesteps.to(torch.float16)
|
| 101 |
print(f'current_timestpes: {current_timesteps.dtype}')
|
| 102 |
eval_images = self.model(noise, current_timesteps, prompt_embed, added_cond_kwargs=unet_added_conditions).sample
|
| 103 |
print(type(eval_images))
|
|
|
|
| 126 |
|
| 127 |
add_time_ids = self.build_condition_input(height, width).repeat(num_images, 1)
|
| 128 |
|
| 129 |
+
noise = torch.randn(num_images, 4, height // self.vae_downsample_ratio, width // self.vae_downsample_ratio, generator=generator)
|
| 130 |
|
| 131 |
prompt_inputs = self._encode_prompt(prompt)
|
| 132 |
|
|
|
|
| 145 |
}
|
| 146 |
|
| 147 |
|
| 148 |
+
|
| 149 |
print(f'prompt: {batch_prompt_embeds.dtype}')
|
| 150 |
print(unet_added_conditions['time_ids'].dtype)
|
| 151 |
print(unet_added_conditions['text_embeds'].dtype)
|