roll-ai committed on
Commit
2bf9932
·
verified ·
1 Parent(s): e426cb8

Update inference/flovd_demo.py

Browse files
Files changed (1) hide show
  1. inference/flovd_demo.py +22 -13
inference/flovd_demo.py CHANGED
@@ -264,6 +264,27 @@ def save_flow_warped_video(image, flow, filename, fps=16):
264
  frame_list.append(Image.fromarray(frame))
265
 
266
  export_to_video(frame_list, filename, fps=fps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
  def generate_video(
269
  prompt: str,
@@ -308,19 +329,7 @@ def generate_video(
308
  - fps (int): The frames per second for the generated video.
309
  """
310
 
311
- def patch_prepare_latents_with_device():
312
- original_prepare_latents = CogVideoXImageToVideoPipeline.prepare_latents
313
-
314
- def prepare_latents_with_device(self, *args, **kwargs):
315
- result = original_prepare_latents(self, *args, **kwargs)
316
- # Ensure returned tensors are moved to the correct device
317
- if isinstance(result, tuple):
318
- result = tuple(t.to(self.device) if isinstance(t, torch.Tensor) else t for t in result)
319
- elif isinstance(result, torch.Tensor):
320
- result = result.to(self.device)
321
- return result
322
-
323
- CogVideoXImageToVideoPipeline.prepare_latents = types.MethodType(prepare_latents_with_device, CogVideoXImageToVideoPipeline)
324
 
325
  print("at generate video", flush=True)
326
  local_rank = 'cuda'
 
264
  frame_list.append(Image.fromarray(frame))
265
 
266
  export_to_video(frame_list, filename, fps=fps)
267
+
268
from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline


def patch_prepare_latents_safe():
    """Monkey-patch ``CogVideoXImageToVideoPipeline.prepare_latents``.

    Replaces the pipeline's latent-preparation step with a version that
    encodes the conditioning image through the VAE, zero-pads the image
    latents to ``num_frames`` along the frame axis, and draws the initial
    noise from ``generator`` so seeded runs are reproducible.

    NOTE(review): the replacement signature differs from the upstream
    diffusers ``prepare_latents`` signature — confirm it matches how this
    repo's pipeline actually invokes ``prepare_latents``.
    """

    def new_prepare_latents(self, image, num_frames, height, width, batch_size, dtype, generator, do_classifier_free_guidance=False):
        # Encode the conditioning image into VAE latent space on the
        # pipeline's device, then apply the VAE's latent scaling factor.
        image_latents = self.vae.encode(image.to(self.device, dtype=dtype)).latent_dist.sample()
        image_latents = image_latents * self.vae.config.scaling_factor

        # Pad with zeros along the frame axis so the image latents span all
        # num_frames. The concatenation is on dim=1, so the length check and
        # the padding shape must use dim 1 as well (the original checked
        # shape[2] and built a rank-4 pad, which torch.cat would reject).
        if image_latents.shape[1] != num_frames:
            latent_padding = torch.zeros(
                (image_latents.shape[0], num_frames - image_latents.shape[1], *image_latents.shape[2:]),
                device=image_latents.device,
                dtype=image_latents.dtype,
            )
            image_latents = torch.cat([image_latents, latent_padding], dim=1)

        # torch.randn_like() does not accept a `generator` argument (it
        # raises TypeError); use torch.randn with an explicit shape so the
        # seeded generator is actually honored.
        noise = torch.randn(
            image_latents.shape,
            generator=generator,
            device=image_latents.device,
            dtype=image_latents.dtype,
        )
        latents = noise.to(self.device, dtype=dtype)

        # Returns (initial noise latents, padded image latents), both on the
        # pipeline device in the requested dtype.
        return latents, image_latents.to(self.device, dtype=dtype)

    CogVideoXImageToVideoPipeline.prepare_latents = new_prepare_latents
288
 
289
  def generate_video(
290
  prompt: str,
 
329
  - fps (int): The frames per second for the generated video.
330
  """
331
 
332
+ patch_prepare_latents_safe()
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
  print("at generate video", flush=True)
335
  local_rank = 'cuda'