roll-ai committed
Commit e426cb8 · verified · 1 Parent(s): 173f5b1

Update inference/flovd_demo.py

Files changed (1):
1. inference/flovd_demo.py +17 -0
inference/flovd_demo.py CHANGED
@@ -37,6 +37,8 @@ import numpy as np
 from PIL import Image
 
 import torch
+import types
+from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
 
 from diffusers import (
     CogVideoXDPMScheduler,
@@ -305,6 +307,21 @@ def generate_video(
     - seed (int): The seed for reproducibility.
     - fps (int): The frames per second for the generated video.
     """
+
+    def patch_prepare_latents_with_device():
+        original_prepare_latents = CogVideoXImageToVideoPipeline.prepare_latents
+
+        def prepare_latents_with_device(self, *args, **kwargs):
+            result = original_prepare_latents(self, *args, **kwargs)
+            # Ensure returned tensors are moved to the correct device
+            if isinstance(result, tuple):
+                result = tuple(t.to(self.device) if isinstance(t, torch.Tensor) else t for t in result)
+            elif isinstance(result, torch.Tensor):
+                result = result.to(self.device)
+            return result
+
+        CogVideoXImageToVideoPipeline.prepare_latents = types.MethodType(prepare_latents_with_device, CogVideoXImageToVideoPipeline)
+
     print("at generate video", flush=True)
     local_rank = 'cuda'
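
For context, here is a minimal, self-contained sketch of the same device-forcing monkeypatch, written to run at module level. It assumes CogVideoXImageToVideoPipeline can be imported from the top-level diffusers package, and it swaps the committed types.MethodType(...) assignment for a plain-function assignment so that Python binds self to each pipeline instance at call time; treat it as an illustration of the pattern rather than the repository's exact code.

import torch
from diffusers import CogVideoXImageToVideoPipeline


def patch_prepare_latents_with_device():
    # Keep a reference to the library's original implementation.
    original_prepare_latents = CogVideoXImageToVideoPipeline.prepare_latents

    def prepare_latents_with_device(self, *args, **kwargs):
        # Delegate to the original prepare_latents, then move every returned
        # tensor onto the pipeline's execution device.
        result = original_prepare_latents(self, *args, **kwargs)
        if isinstance(result, tuple):
            return tuple(t.to(self.device) if isinstance(t, torch.Tensor) else t for t in result)
        if isinstance(result, torch.Tensor):
            return result.to(self.device)
        return result

    # Assigning the plain function (not a method pre-bound to the class) lets
    # Python bind self to whichever pipeline instance invokes prepare_latents.
    CogVideoXImageToVideoPipeline.prepare_latents = prepare_latents_with_device


patch_prepare_latents_with_device()

Calling patch_prepare_latents_with_device() once, before the pipeline is used, is enough: every subsequent prepare_latents call goes through the wrapper and returns tensors already placed on the pipeline's device.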