Update inference/flovd_demo.py
Browse files- inference/flovd_demo.py +2 -0
inference/flovd_demo.py
CHANGED
@@ -300,6 +300,8 @@ def patch_prepare_latents_safe():
|
|
300 |
|
301 |
if latents is None:
|
302 |
# noise = torch.randn_like(image_latents, generator=generator)
|
|
|
|
|
303 |
noise = torch.randn(
|
304 |
image_latents.shape,
|
305 |
dtype=image_latents.dtype,
|
|
|
300 |
|
301 |
if latents is None:
|
302 |
# noise = torch.randn_like(image_latents, generator=generator)
|
303 |
+
if generator.device != image_latents.device:
|
304 |
+
generator = generator.to(image_latents.device)
|
305 |
noise = torch.randn(
|
306 |
image_latents.shape,
|
307 |
dtype=image_latents.dtype,
|