Update inference/flovd_demo.py
inference/flovd_demo.py (+22 -13)
@@ -264,6 +264,27 @@ def save_flow_warped_video(image, flow, filename, fps=16):
         frame_list.append(Image.fromarray(frame))
 
     export_to_video(frame_list, filename, fps=fps)
+
+from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
+
+def patch_prepare_latents_safe():
+    def new_prepare_latents(self, image, num_frames, height, width, batch_size, dtype, generator, do_classifier_free_guidance=False):
+        # Encode the conditioning image and scale it into the VAE latent space.
+        image_latents = self.vae.encode(image.to(self.device, dtype=dtype)).latent_dist.sample()
+        image_latents = image_latents * self.vae.config.scaling_factor
+
+        # Zero-pad the temporal dimension (dim 2) up to num_frames.
+        if image_latents.shape[2] != num_frames:
+            latent_padding = torch.zeros(
+                (image_latents.shape[0], image_latents.shape[1], num_frames - image_latents.shape[2], image_latents.shape[3], image_latents.shape[4]),
+                device=image_latents.device, dtype=image_latents.dtype
+            )
+            image_latents = torch.cat([image_latents, latent_padding], dim=2)
+
+        # torch.randn_like() takes no generator argument, so sample the noise explicitly.
+        noise = torch.randn(image_latents.shape, generator=generator, device=image_latents.device, dtype=image_latents.dtype)
+        latents = noise.to(self.device, dtype=dtype)
+
+        return latents, image_latents.to(self.device, dtype=dtype)
+
+    CogVideoXImageToVideoPipeline.prepare_latents = new_prepare_latents
 
 def generate_video(
     prompt: str,
@@ -308,19 +329,7 @@ def generate_video(
-        fps (int): The frames per second for the generated video.
     """
 
-
-    original_prepare_latents = CogVideoXImageToVideoPipeline.prepare_latents
-
-    def prepare_latents_with_device(self, *args, **kwargs):
-        result = original_prepare_latents(self, *args, **kwargs)
-        # Ensure returned tensors are moved to the correct device.
-        if isinstance(result, tuple):
-            result = tuple(t.to(self.device) if isinstance(t, torch.Tensor) else t for t in result)
-        elif isinstance(result, torch.Tensor):
-            result = result.to(self.device)
-        return result
-
-    CogVideoXImageToVideoPipeline.prepare_latents = types.MethodType(prepare_latents_with_device, CogVideoXImageToVideoPipeline)
-
+    patch_prepare_latents_safe()
 
     print("at generate video", flush=True)
     local_rank = 'cuda'
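A likely reason for dropping the old wrapper: types.MethodType(fn, SomeClass) binds self to the class object itself, so the wrapped prepare_latents would have received the class rather than the pipeline instance. Assigning a plain function to the class attribute, as patch_prepare_latents_safe() now does, yields ordinary bound methods on instances. A minimal sketch with a stand-in class (not the real pipeline):

import types

class Pipe:
    def prepare_latents(self):
        return "original"

# Old pattern: MethodType binds self to the class, not to an instance.
bound_to_class = types.MethodType(lambda self: self, Pipe)
assert bound_to_class() is Pipe

# New pattern: a plain function assigned on the class becomes a normal method.
def new_prepare_latents(self):
    return "patched"

Pipe.prepare_latents = new_prepare_latents
assert Pipe().prepare_latents() == "patched"

With this change, generate_video() simply calls patch_prepare_latents_safe() once before building the pipeline.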