Update inference/flovd_demo.py
Browse files- inference/flovd_demo.py +8 -1
inference/flovd_demo.py
CHANGED
@@ -299,7 +299,14 @@ def patch_prepare_latents_safe():
|
|
299 |
image_latents = torch.cat([image_latents, latent_padding], dim=2)
|
300 |
|
301 |
if latents is None:
|
302 |
-
noise = torch.randn_like(image_latents, generator=generator)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
latents = noise.to(device=device, dtype=dtype)
|
304 |
|
305 |
return latents, image_latents.to(device, dtype=dtype)
|
|
|
299 |
image_latents = torch.cat([image_latents, latent_padding], dim=2)
|
300 |
|
301 |
if latents is None:
|
302 |
+
# noise = torch.randn_like(image_latents, generator=generator)
|
303 |
+
noise = torch.randn(
|
304 |
+
image_latents.shape,
|
305 |
+
dtype=image_latents.dtype,
|
306 |
+
device=image_latents.device,
|
307 |
+
generator=generator
|
308 |
+
)
|
309 |
+
|
310 |
latents = noise.to(device=device, dtype=dtype)
|
311 |
|
312 |
return latents, image_latents.to(device, dtype=dtype)
|