Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -156,13 +156,13 @@ def infer(
|
|
| 156 |
print('-- filtered prompt --')
|
| 157 |
print(enhanced_prompt)
|
| 158 |
if latent_file: # Check if a latent file is provided
|
| 159 |
-
initial_latents =
|
| 160 |
batch_size=1,
|
| 161 |
-
num_channels_latents=
|
| 162 |
-
height=
|
| 163 |
-
width=
|
| 164 |
-
dtype=
|
| 165 |
-
device=
|
| 166 |
generator=generator,
|
| 167 |
)
|
| 168 |
sd_image_a = torch.load(latent_file.name) # Load the latent
|
|
@@ -203,13 +203,13 @@ def infer(
|
|
| 203 |
# Encode the generated image into latents
|
| 204 |
with torch.no_grad():
|
| 205 |
generated_latents = vae.encode(generated_image_tensor.to(torch.bfloat16)).latent_dist.sample().mul_(0.18215)
|
| 206 |
-
initial_latents =
|
| 207 |
batch_size=1,
|
| 208 |
-
num_channels_latents=
|
| 209 |
-
height=
|
| 210 |
-
width=
|
| 211 |
-
dtype=
|
| 212 |
-
device=
|
| 213 |
generator=generator,
|
| 214 |
)
|
| 215 |
initial_latents += generated_latents
|
|
|
|
| 156 |
print('-- filtered prompt --')
|
| 157 |
print(enhanced_prompt)
|
| 158 |
if latent_file: # Check if a latent file is provided
|
| 159 |
+
initial_latents = pipe.prepare_latents(
|
| 160 |
batch_size=1,
|
| 161 |
+
num_channels_latents=pipe.unet.in_channels,
|
| 162 |
+
height=pipe.unet.sample_size[0],
|
| 163 |
+
width=pipe.unet.sample_size[1],
|
| 164 |
+
dtype=pipe.unet.dtype,
|
| 165 |
+
device=pipe.device,
|
| 166 |
generator=generator,
|
| 167 |
)
|
| 168 |
sd_image_a = torch.load(latent_file.name) # Load the latent
|
|
|
|
| 203 |
# Encode the generated image into latents
|
| 204 |
with torch.no_grad():
|
| 205 |
generated_latents = vae.encode(generated_image_tensor.to(torch.bfloat16)).latent_dist.sample().mul_(0.18215)
|
| 206 |
+
initial_latents = pipe.prepare_latents(
|
| 207 |
batch_size=1,
|
| 208 |
+
num_channels_latents=pipe.unet.in_channels,
|
| 209 |
+
height=pipe.unet.sample_size[0],
|
| 210 |
+
width=pipe.unet.sample_size[1],
|
| 211 |
+
dtype=pipe.unet.dtype,
|
| 212 |
+
device=pipe.device,
|
| 213 |
generator=generator,
|
| 214 |
)
|
| 215 |
initial_latents += generated_latents
|