Update inference/flovd_demo.py
inference/flovd_demo.py (+24 -7)
@@ -264,28 +264,45 @@ def save_flow_warped_video(image, flow, filename, fps=16):
         frame_list.append(Image.fromarray(frame))
 
     export_to_video(frame_list, filename, fps=fps)
+
 
 from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
-…
 def patch_prepare_latents_safe():
-    def new_prepare_latents(
-…
+    def new_prepare_latents(
+        self,
+        image,
+        batch_size,
+        latent_channels,
+        num_frames,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        image_latents = self.vae.encode(image.to(device, dtype=dtype)).latent_dist.sample()
         image_latents = image_latents * self.vae.config.scaling_factor
 
+        # Pad temporal dimension if needed
         if image_latents.shape[2] != num_frames:
             latent_padding = torch.zeros(
                 (image_latents.shape[0], num_frames - image_latents.shape[2], image_latents.shape[3], image_latents.shape[4]),
-                device=image_latents.device,
+                device=image_latents.device,
+                dtype=image_latents.dtype
             )
             image_latents = torch.cat([image_latents, latent_padding], dim=1)
 
-…
-…
+        if latents is None:
+            noise = torch.randn_like(image_latents, generator=generator)
+            latents = noise.to(device=device, dtype=dtype)
 
-        return latents, image_latents.to(
+        return latents, image_latents.to(device, dtype=dtype)
 
+    from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
     CogVideoXImageToVideoPipeline.prepare_latents = new_prepare_latents
 
+
 def generate_video(
     prompt: str,
     fvsm_path: str,
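As committed, the patched prepare_latents has two likely pitfalls: torch.randn_like does not accept a generator keyword argument, and the zero padding is built as a 4-D tensor keyed off image_latents.shape[2] while the concatenation runs along dim=1 of a 5-D latent. The following is a minimal corrected sketch, not part of this commit: it assumes the [B, F, C, H, W] latent layout of the stock CogVideoX image-to-video pipeline, uses diffusers' randn_tensor helper (which does accept a generator), and borrows the init_noise_sigma scaling from the stock prepare_latents.

import torch
from diffusers.pipelines.cogvideo.pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
from diffusers.utils.torch_utils import randn_tensor


def patch_prepare_latents_sketch():
    # Hypothetical variant of patch_prepare_latents_safe(); the fixes below
    # are assumptions based on the stock pipeline, not taken from this commit.
    def new_prepare_latents(
        self,
        image,
        batch_size,
        latent_channels,
        num_frames,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        # Encode the conditioning image and scale into latent space.
        image_latents = self.vae.encode(image.to(device, dtype=dtype)).latent_dist.sample()
        image_latents = image_latents * self.vae.config.scaling_factor

        # Assumption: permute [B, C, F, H, W] -> [B, F, C, H, W] so the frame
        # axis is dim=1, matching the stock CogVideoX latent layout.
        image_latents = image_latents.permute(0, 2, 1, 3, 4)

        # Pad the frame axis with zeros if the encoded clip is shorter than
        # num_frames; the padding must be 5-D to concatenate with a 5-D latent.
        if image_latents.shape[1] != num_frames:
            padding_shape = (
                image_latents.shape[0],
                num_frames - image_latents.shape[1],
                image_latents.shape[2],
                image_latents.shape[3],
                image_latents.shape[4],
            )
            latent_padding = torch.zeros(
                padding_shape, device=image_latents.device, dtype=image_latents.dtype
            )
            image_latents = torch.cat([image_latents, latent_padding], dim=1)

        if latents is None:
            # randn_tensor accepts a generator (torch.randn_like does not) and
            # reconciles generator/device placement internally.
            latents = randn_tensor(
                image_latents.shape, generator=generator, device=device, dtype=dtype
            )
            # The stock prepare_latents scales initial noise by the scheduler sigma.
            latents = latents * self.scheduler.init_noise_sigma

        return latents, image_latents.to(device, dtype=dtype)

    CogVideoXImageToVideoPipeline.prepare_latents = new_prepare_latents

The patch would be applied before the pipeline is built, for example:

# Illustrative usage; the checkpoint name is an assumption.
patch_prepare_latents_sketch()
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
)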