Update inference/flovd_demo.py
inference/flovd_demo.py  CHANGED  +17 -0
@@ -37,6 +37,8 @@ import numpy as np
 from PIL import Image
 
 import torch
+import types
+from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
 
 from diffusers import (
     CogVideoXDPMScheduler,
@@ -305,6 +307,21 @@ def generate_video(
     - seed (int): The seed for reproducibility.
     - fps (int): The frames per second for the generated video.
     """
+
+    def patch_prepare_latents_with_device():
+        original_prepare_latents = CogVideoXImageToVideoPipeline.prepare_latents
+
+        def prepare_latents_with_device(self, *args, **kwargs):
+            result = original_prepare_latents(self, *args, **kwargs)
+            # Ensure returned tensors are moved to the correct device
+            if isinstance(result, tuple):
+                result = tuple(t.to(self.device) if isinstance(t, torch.Tensor) else t for t in result)
+            elif isinstance(result, torch.Tensor):
+                result = result.to(self.device)
+            return result
+
+        CogVideoXImageToVideoPipeline.prepare_latents = types.MethodType(prepare_latents_with_device, CogVideoXImageToVideoPipeline)
+
     print("at generate video", flush=True)
     local_rank = 'cuda'
 
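Two notes on the added hunk, with a hedged sketch. First, the diff only defines `patch_prepare_latents_with_device()` inside `generate_video`; the call site is not part of this commit, so any invocation shown below is an assumption. Second, `types.MethodType(prepare_latents_with_device, CogVideoXImageToVideoPipeline)` binds the wrapper to the class object itself, so `self` inside the wrapper would be the class rather than a pipeline instance, and `self.device` would not resolve to an instance's device. The following is a minimal sketch of the same device-moving patch, assuming a plain-function assignment so Python binds `self` to the instance on each call; it is an illustrative variant, not the committed code:

import torch
from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline


def patch_prepare_latents_with_device():
    # Keep a reference to the unpatched method so the wrapper can delegate to it.
    original_prepare_latents = CogVideoXImageToVideoPipeline.prepare_latents

    def prepare_latents_with_device(self, *args, **kwargs):
        result = original_prepare_latents(self, *args, **kwargs)
        # Move whatever the original returned (a tensor or a tuple of tensors) onto the pipeline's device.
        if isinstance(result, tuple):
            result = tuple(t.to(self.device) if isinstance(t, torch.Tensor) else t for t in result)
        elif isinstance(result, torch.Tensor):
            result = result.to(self.device)
        return result

    # Assigning the plain function lets normal attribute lookup bind `self` to the
    # pipeline instance, unlike types.MethodType(..., CogVideoXImageToVideoPipeline),
    # which would bind `self` to the class object.
    CogVideoXImageToVideoPipeline.prepare_latents = prepare_latents_with_device


# Hypothetical call site (not shown in this commit): apply the patch once,
# before the pipeline runs prepare_latents.
# patch_prepare_latents_with_device()

Applying the patch once at module import time (rather than on every `generate_video` call) would also avoid re-wrapping the method repeatedly, but where it is invoked is not visible in this diff.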