roll-ai committed (verified)
Commit 2dd8067 · 1 Parent(s): 2bf9932

Update inference/flovd_demo.py

Files changed (1): inference/flovd_demo.py (+24 -7)
inference/flovd_demo.py CHANGED
@@ -264,28 +264,45 @@ def save_flow_warped_video(image, flow, filename, fps=16):
         frame_list.append(Image.fromarray(frame))
 
     export_to_video(frame_list, filename, fps=fps)
+
 
 from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
-
 def patch_prepare_latents_safe():
-    def new_prepare_latents(self, image, num_frames, height, width, batch_size, dtype, generator, do_classifier_free_guidance=False):
-        image_latents = self.vae.encode(image.to(self.device, dtype=dtype)).latent_dist.sample()
+    def new_prepare_latents(
+        self,
+        image,
+        batch_size,
+        latent_channels,
+        num_frames,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        image_latents = self.vae.encode(image.to(device, dtype=dtype)).latent_dist.sample()
         image_latents = image_latents * self.vae.config.scaling_factor
 
+        # Pad temporal dimension if needed
         if image_latents.shape[2] != num_frames:
             latent_padding = torch.zeros(
                 (image_latents.shape[0], num_frames - image_latents.shape[2], image_latents.shape[3], image_latents.shape[4]),
-                device=image_latents.device, dtype=image_latents.dtype
+                device=image_latents.device,
+                dtype=image_latents.dtype
             )
             image_latents = torch.cat([image_latents, latent_padding], dim=1)
 
-        noise = torch.randn_like(image_latents, generator=generator)
-        latents = noise.to(self.device, dtype=dtype)
+        if latents is None:
+            noise = torch.randn_like(image_latents, generator=generator)
+            latents = noise.to(device=device, dtype=dtype)
 
-        return latents, image_latents.to(self.device, dtype=dtype)
+        return latents, image_latents.to(device, dtype=dtype)
 
+    from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
     CogVideoXImageToVideoPipeline.prepare_latents = new_prepare_latents
 
+
 def generate_video(
     prompt: str,
     fvsm_path: str,
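
For quick local verification, the sketch below installs the patch and illustrates the temporal-padding idea it implements. This is an assumption-laden illustration, not part of the commit: the import path and every tensor shape are made up.

# Minimal sketch; the import path and tensor shapes are assumptions, not from this commit.
import torch
from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
from inference.flovd_demo import patch_prepare_latents_safe  # hypothetical import path

# Installing the monkey-patch replaces the class method with the closure from the diff.
patch_prepare_latents_safe()
assert CogVideoXImageToVideoPipeline.prepare_latents.__name__ == "new_prepare_latents"

# The padding step in general form: zero-fill latents along the frame axis until
# they span num_frames (illustrated with a (B, F, C, H, W) layout and made-up sizes).
num_frames = 13
image_latents = torch.randn(1, 1, 16, 60, 90)  # a single encoded frame
if image_latents.shape[1] != num_frames:
    pad = torch.zeros(
        (image_latents.shape[0], num_frames - image_latents.shape[1], *image_latents.shape[2:]),
        device=image_latents.device, dtype=image_latents.dtype,
    )
    image_latents = torch.cat([image_latents, pad], dim=1)
print(image_latents.shape)  # torch.Size([1, 13, 16, 60, 90])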