Update inference/flovd_demo.py

inference/flovd_demo.py  CHANGED  (+7, -3)
@@ -281,17 +281,20 @@ def patch_prepare_latents_safe():
         generator,
         latents=None,
     ):
-        image_latents = self.vae.encode(image.to(device, dtype=dtype)).latent_dist.sample()
+        # Ensure 5D input: [B, C, F=1, H, W]
+        image_5d = image.unsqueeze(2) if image.ndim == 4 else image
+
+        image_latents = self.vae.encode(image_5d.to(device, dtype=dtype)).latent_dist.sample()
         image_latents = image_latents * self.vae.config.scaling_factor
 
         # Pad temporal dimension if needed
         if image_latents.shape[2] != num_frames:
             latent_padding = torch.zeros(
-                (image_latents.shape[0], num_frames - image_latents.shape[2],
+                (image_latents.shape[0], latent_channels, num_frames - image_latents.shape[2], height, width),
                 device=image_latents.device,
                 dtype=image_latents.dtype
             )
-            image_latents = torch.cat([image_latents, latent_padding], dim=
+            image_latents = torch.cat([image_latents, latent_padding], dim=2)
 
         if latents is None:
             noise = torch.randn_like(image_latents, generator=generator)
@@ -303,6 +306,7 @@ def patch_prepare_latents_safe():
     CogVideoXImageToVideoPipeline.prepare_latents = new_prepare_latents
 
 
+
 def generate_video(
     prompt: str,
     fvsm_path: str,
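For reviewers: the first hunk makes the zero-padding consistent with a [B, C, F, H, W] latent layout (channels at dim 1, frames at dim 2) and pads along the frame axis. Below is a minimal, self-contained sketch of that shape handling; it uses dummy tensors instead of the real VAE output, and the concrete sizes (batch 1, 16 latent channels, 13 latent frames, 60x90 spatial) are illustrative assumptions, not values taken from the repo.

```python
import torch

# Illustrative sizes; the real values come from the VAE config and the
# requested clip length in flovd_demo.py.
batch, latent_channels, height, width = 1, 16, 60, 90
num_frames = 13  # latent frames the denoiser expects

# Stand-in for the VAE encoding of a single conditioning image: one latent
# frame in [B, C, F=1, H, W] layout (the patch unsqueezes dim 2 if the
# input is still 4D [B, C, H, W]).
image_latents = torch.randn(batch, latent_channels, height, width)
if image_latents.ndim == 4:
    image_latents = image_latents.unsqueeze(2)  # -> [B, C, 1, H, W]

# Zero-pad the frame axis (dim=2) up to num_frames, as the new hunk does.
if image_latents.shape[2] != num_frames:
    latent_padding = torch.zeros(
        (image_latents.shape[0], latent_channels, num_frames - image_latents.shape[2], height, width),
        device=image_latents.device,
        dtype=image_latents.dtype,
    )
    image_latents = torch.cat([image_latents, latent_padding], dim=2)

print(image_latents.shape)  # torch.Size([1, 16, 13, 60, 90])
```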
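The second hunk is whitespace only; the line it surrounds is where the patch is installed. For context, `patch_prepare_latents_safe()` works by reassigning the method on the pipeline class (monkey-patching). A rough sketch of that mechanism, with the real body stubbed out, assuming the surrounding function looks roughly like this:

```python
from diffusers import CogVideoXImageToVideoPipeline

def patch_prepare_latents_safe():
    # The real implementation is the padded-latent logic shown in the first
    # hunk; this stub only illustrates how the override is installed.
    def new_prepare_latents(self, *args, **kwargs):
        ...

    # Monkey-patch: from here on, every CogVideoXImageToVideoPipeline
    # instance resolves prepare_latents to the patched function.
    CogVideoXImageToVideoPipeline.prepare_latents = new_prepare_latents
```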