roll-ai commited on
Commit
0811cd2
·
verified ·
1 Parent(s): b7df3f6

Update inference/flovd_demo.py

Browse files
Files changed (1) hide show
  1. inference/flovd_demo.py +7 -3
inference/flovd_demo.py CHANGED
@@ -281,17 +281,20 @@ def patch_prepare_latents_safe():
281
  generator,
282
  latents=None,
283
  ):
284
- image_latents = self.vae.encode(image.to(device, dtype=dtype)).latent_dist.sample()
 
 
 
285
  image_latents = image_latents * self.vae.config.scaling_factor
286
 
287
  # Pad temporal dimension if needed
288
  if image_latents.shape[2] != num_frames:
289
  latent_padding = torch.zeros(
290
- (image_latents.shape[0], num_frames - image_latents.shape[2], image_latents.shape[3], image_latents.shape[4]),
291
  device=image_latents.device,
292
  dtype=image_latents.dtype
293
  )
294
- image_latents = torch.cat([image_latents, latent_padding], dim=1)
295
 
296
  if latents is None:
297
  noise = torch.randn_like(image_latents, generator=generator)
@@ -303,6 +306,7 @@ def patch_prepare_latents_safe():
303
  CogVideoXImageToVideoPipeline.prepare_latents = new_prepare_latents
304
 
305
 
 
306
  def generate_video(
307
  prompt: str,
308
  fvsm_path: str,
 
281
  generator,
282
  latents=None,
283
  ):
284
+ # Ensure 5D input: [B, C, F=1, H, W]
285
+ image_5d = image.unsqueeze(2) if image.ndim == 4 else image
286
+
287
+ image_latents = self.vae.encode(image_5d.to(device, dtype=dtype)).latent_dist.sample()
288
  image_latents = image_latents * self.vae.config.scaling_factor
289
 
290
  # Pad temporal dimension if needed
291
  if image_latents.shape[2] != num_frames:
292
  latent_padding = torch.zeros(
293
+ (image_latents.shape[0], latent_channels, num_frames - image_latents.shape[2], height, width),
294
  device=image_latents.device,
295
  dtype=image_latents.dtype
296
  )
297
+ image_latents = torch.cat([image_latents, latent_padding], dim=2)
298
 
299
  if latents is None:
300
  noise = torch.randn_like(image_latents, generator=generator)
 
306
  CogVideoXImageToVideoPipeline.prepare_latents = new_prepare_latents
307
 
308
 
309
+
310
  def generate_video(
311
  prompt: str,
312
  fvsm_path: str,