{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "82a43bc9-5339-4fa2-a7b4-e3ad0bc7a54a", "metadata": {}, "source": [ "# Image-to-Video synthesis with AnimateAnyone and OpenVINO\n", "\n", "\n", "[AnimateAnyone](https://arxiv.org/pdf/2311.17117.pdf) tackles the task of generating animation sequences from a single character image. It builds upon diffusion models pre-trained on vast character image datasets.\n", "\n", "The core of AnimateAnyone is a diffusion model pre-trained on a massive dataset of character images. This model learns the underlying character representation and distribution, allowing for realistic and diverse character animation.\n", "To capture the specific details and characteristics of the input character image, AnimateAnyone incorporates a ReferenceNet module. This module acts like an attention mechanism, focusing on the input image and guiding the animation process to stay consistent with the original character's appearance. AnimateAnyone enables control over the character's pose during animation. This might involve using techniques like parametric pose embedding or direct pose vector input, allowing for the creation of various character actions and movements. To ensure smooth transitions and temporal coherence throughout the animation sequence, AnimateAnyone incorporates temporal modeling techniques. This may involve recurrent architectures like LSTMs or transformers that capture the temporal dependencies between video frames.\n", "\n", "Overall, AnimateAnyone combines a powerful pre-trained diffusion model with a character-specific attention mechanism (ReferenceNet), pose guidance, and temporal modeling to achieve controllable, high-fidelity character animation from a single image.\n", "\n", "Learn more in [GitHub repo](https://github.com/MooreThreads/Moore-AnimateAnyone) and [paper](https://arxiv.org/pdf/2311.17117.pdf).\n", "\n", "
<div class=\"alert alert-block alert-danger\"><b>! WARNING !</b>\n",
"\n",
"This tutorial requires at least 96 GB of RAM for model conversion and 40 GB for inference. Changing the values of the <code>HEIGHT</code>, <code>WIDTH</code> and <code>VIDEO_LENGTH</code> variables will change the memory consumption but will also affect accuracy.\n",
"</div>
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "if not VAE_ENCODER_PATH.exists():\n", " class VaeEncoder(torch.nn.Module):\n", " def __init__(self, vae):\n", " super().__init__()\n", " self.vae = vae\n", " \n", " def forward(self, x):\n", " return self.vae.encode(x).latent_dist.mean\n", " vae.eval()\n", " with torch.no_grad():\n", " vae_encoder = ov.convert_model(VaeEncoder(vae), example_input=torch.zeros(1,3,512,448))\n", " vae_encoder = nncf.compress_weights(vae_encoder)\n", " ov.save_model(vae_encoder, VAE_ENCODER_PATH)\n", " del vae_encoder\n", " cleanup_torchscript_cache()" ] }, { "cell_type": "code", "execution_count": 14, "id": "c8f87c8d-2538-478a-8f8f-7fa1f498dcb2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:nncf:Statistics of the bitwidth distribution:\n", "+--------------+---------------------------+-----------------------------------+\n", "| Num bits (N) | % all parameters (layers) | % ratio-defining parameters |\n", "| | | (layers) |\n", "+==============+===========================+===================================+\n", "| 8 | 100% (40 / 40) | 100% (40 / 40) |\n", "+--------------+---------------------------+-----------------------------------+\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ce3d8670ebdc41ecaeae193539a53b8a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "if not VAE_DECODER_PATH.exists():\n", " class VaeDecoder(torch.nn.Module):\n", " def __init__(self, vae):\n", " super().__init__()\n", " self.vae = vae\n", " \n", " def forward(self, z):\n", " return self.vae.decode(z).sample\n", " vae.eval()\n", " with torch.no_grad():\n", " vae_decoder = ov.convert_model(VaeDecoder(vae), example_input=torch.zeros(1,4,HEIGHT//8,WIDTH//8))\n", " vae_decoder = nncf.compress_weights(vae_decoder)\n", " ov.save_model(vae_decoder, VAE_DECODER_PATH)\n", " del vae_decoder\n", " cleanup_torchscript_cache()\n", "del vae\n", "gc.collect()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "be82f4e8-0e1b-406a-93a2-d8c8c23c7797", "metadata": {}, "source": [ "### Reference UNet\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Pipeline extracts reference attention features from all transformer blocks inside Reference UNet model. We call the original forward pass to obtain shapes of the outputs as they will be used in the next pipeline step." ] }, { "cell_type": "code", "execution_count": 15, "id": "9a31c54b-b6e9-41b1-80f6-c06a1bae52ae", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:nncf:Statistics of the bitwidth distribution:\n", "+--------------+---------------------------+-----------------------------------+\n", "| Num bits (N) | % all parameters (layers) | % ratio-defining parameters |\n", "| | | (layers) |\n", "+==============+===========================+===================================+\n", "| 8 | 100% (270 / 270) | 100% (270 / 270) |\n", "+--------------+---------------------------+-----------------------------------+\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b31aab4251de44e8bb59a533bb1cb9e4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "if not REFERENCE_UNET_PATH.exists():\n", " class ReferenceUNetWrapper(torch.nn.Module):\n", " def __init__(self, reference_unet):\n", " super().__init__()\n", " self.reference_unet = reference_unet\n", " \n", " def forward(self, sample, timestep, encoder_hidden_states):\n", " return self.reference_unet(sample, timestep, encoder_hidden_states, return_dict=False)[1]\n", " \n", " sample = torch.zeros(2, 4, HEIGHT // 8, WIDTH // 8)\n", " timestep = torch.tensor(0)\n", " encoder_hidden_states = torch.zeros(2, 1, 768)\n", " reference_unet.eval()\n", " with torch.no_grad():\n", " wrapper = ReferenceUNetWrapper(reference_unet)\n", " example_input = (sample, timestep, encoder_hidden_states)\n", " ref_features_shapes = {k: v.shape for k, v in wrapper(*example_input).items()}\n", " ov_reference_unet = ov.convert_model(\n", " wrapper,\n", " example_input=example_input,\n", " )\n", " ov_reference_unet = nncf.compress_weights(ov_reference_unet)\n", " ov.save_model(ov_reference_unet, REFERENCE_UNET_PATH)\n", " del ov_reference_unet\n", " del wrapper\n", " cleanup_torchscript_cache()\n", "del reference_unet\n", "gc.collect()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7ddd784b-c54b-45f4-9797-7b4634e90ec0", "metadata": {}, "source": [ "### Denoising UNet\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Denoising UNet is the main part of all diffusion pipelines. This model consumes the majority of memory, so we need to reduce its size as much as possible.\n", "\n", "Here we make all shapes static meaning that the size of the video will be constant.\n", "\n", "Also, we use the `ref_features` input with the same tensor shapes as output of [Reference UNet](#Reference-UNet) model on the previous step." ] }, { "cell_type": "code", "execution_count": 16, "id": "e95a7dbd-6235-44f0-81d7-1898b51839c9", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:nncf:Statistics of the bitwidth distribution:\n", "+--------------+---------------------------+-----------------------------------+\n", "| Num bits (N) | % all parameters (layers) | % ratio-defining parameters |\n", "| | | (layers) |\n", "+==============+===========================+===================================+\n", "| 8 | 100% (534 / 534) | 100% (534 / 534) |\n", "+--------------+---------------------------+-----------------------------------+\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "776d8e8cb44446428e50db32618df935", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "if not DENOISING_UNET_PATH.exists():\n", " class DenoisingUNetWrapper(torch.nn.Module):\n", " def __init__(self, denoising_unet):\n", " super().__init__()\n", " self.denoising_unet = denoising_unet\n", " \n", " def forward(\n", " self,\n", " sample,\n", " timestep,\n", " encoder_hidden_states,\n", " pose_cond_fea,\n", " ref_features\n", " ):\n", " return self.denoising_unet(\n", " sample,\n", " timestep,\n", " encoder_hidden_states,\n", " ref_features,\n", " pose_cond_fea=pose_cond_fea,\n", " return_dict=False)\n", "\n", " example_input = {\n", " \"sample\": torch.zeros(2, 4, VIDEO_LENGTH, HEIGHT // 8, WIDTH // 8),\n", " \"timestep\": torch.tensor(999),\n", " \"encoder_hidden_states\": torch.zeros(2,1,768),\n", " \"pose_cond_fea\": torch.zeros(2, 320, VIDEO_LENGTH, HEIGHT // 8, WIDTH // 8),\n", " \"ref_features\": {k: torch.zeros(shape) for k, shape in ref_features_shapes.items()}\n", " }\n", " \n", " denoising_unet.eval()\n", " with torch.no_grad():\n", " ov_denoising_unet = ov.convert_model(\n", " DenoisingUNetWrapper(denoising_unet),\n", " example_input=tuple(example_input.values())\n", " )\n", " ov_denoising_unet.inputs[0].get_node().set_partial_shape(ov.PartialShape((2, 4, VIDEO_LENGTH, HEIGHT // 8, WIDTH // 8)))\n", " ov_denoising_unet.inputs[2].get_node().set_partial_shape(ov.PartialShape((2, 1, 768)))\n", " ov_denoising_unet.inputs[3].get_node().set_partial_shape(ov.PartialShape((2, 320, VIDEO_LENGTH, HEIGHT // 8, WIDTH // 8)))\n", " for ov_input, shape in zip(ov_denoising_unet.inputs[4:], ref_features_shapes.values()):\n", " ov_input.get_node().set_partial_shape(ov.PartialShape(shape))\n", " ov_input.get_node().set_element_type(ov.Type.f32)\n", " ov_denoising_unet.validate_nodes_and_infer_types()\n", " ov_denoising_unet = nncf.compress_weights(ov_denoising_unet)\n", " ov.save_model(ov_denoising_unet, DENOISING_UNET_PATH)\n", " del ov_denoising_unet\n", " cleanup_torchscript_cache()\n", "del denoising_unet\n", "gc.collect()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "06054cfc-690c-4aaf-a05c-7df9d7ad08d2", "metadata": {}, "source": [ "### Pose Guider\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "To ensure pose controllability, a lightweight pose guider is devised to efficiently integrate pose control signals into the denoising process." ] }, { "cell_type": "code", "execution_count": 17, "id": "d4cf6c05-5326-49d0-9e4e-fa9b2081445f", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:nncf:Statistics of the bitwidth distribution:\n", "+--------------+---------------------------+-----------------------------------+\n", "| Num bits (N) | % all parameters (layers) | % ratio-defining parameters |\n", "| | | (layers) |\n", "+==============+===========================+===================================+\n", "| 8 | 100% (8 / 8) | 100% (8 / 8) |\n", "+--------------+---------------------------+-----------------------------------+\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "36ac12f0ac654a90b866dd07b483eeef", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "if not POSE_GUIDER_PATH.exists():\n", " pose_guider.eval()\n", " with torch.no_grad():\n", " ov_pose_guider = ov.convert_model(pose_guider, example_input=torch.zeros(1, 3, VIDEO_LENGTH, HEIGHT, WIDTH))\n", " ov_pose_guider = nncf.compress_weights(ov_pose_guider)\n", " ov.save_model(ov_pose_guider, POSE_GUIDER_PATH)\n", " del ov_pose_guider\n", " cleanup_torchscript_cache()\n", "del pose_guider\n", "gc.collect()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c4db1630-95d4-4cea-927f-fe1d1c259597", "metadata": {}, "source": [ "### Image Encoder\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Pipeline uses CLIP image encoder to generate encoder hidden states required for both reference and denoising UNets." ] }, { "cell_type": "code", "execution_count": 18, "id": "b7331636-b124-4038-9553-2292de72f13e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/itrushkin/.virtualenvs/test/lib/python3.10/site-packages/transformers/modeling_utils.py:4225: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:nncf:Statistics of the bitwidth distribution:\n", "+--------------+---------------------------+-----------------------------------+\n", "| Num bits (N) | % all parameters (layers) | % ratio-defining parameters |\n", "| | | (layers) |\n", "+==============+===========================+===================================+\n", "| 8 | 100% (146 / 146) | 100% (146 / 146) |\n", "+--------------+---------------------------+-----------------------------------+\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4bfabe55e887446fa5ef8c8fea37b33e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "if not IMAGE_ENCODER_PATH.exists():\n", " image_enc.eval()\n", " with torch.no_grad():\n", " ov_image_encoder = ov.convert_model(image_enc, example_input=torch.zeros(1, 3, 224, 224), input=(1, 3, 224, 224))\n", " ov_image_encoder = nncf.compress_weights(ov_image_encoder)\n", " ov.save_model(ov_image_encoder, IMAGE_ENCODER_PATH)\n", " del ov_image_encoder\n", " cleanup_torchscript_cache()\n", "del image_enc\n", "gc.collect()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4e059780-29fc-4fe0-9376-c83573695fbe", "metadata": {}, "source": [ "## Inference\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "We inherit from the original pipeline modifying the calls to our models to match OpenVINO format." ] }, { "cell_type": "code", "execution_count": 19, "id": "35176d47-9a79-49dd-a61f-75892dba8d3d", "metadata": {}, "outputs": [], "source": [ "core = ov.Core()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7a28383f-8249-4c24-bc4c-8dcb0f75461c", "metadata": {}, "source": [ "### Select inference device\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "For starting work, please select inference device from dropdown list." ] }, { "cell_type": "code", "execution_count": 20, "id": "6f43558f-f244-43b7-9c82-c47b5b0bce23", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3e91c2a792224a4983cb2758f093b775", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Dropdown(description='Device:', index=5, options=('CPU', 'GPU.0', 'GPU.1', 'GPU.2', 'GPU.3', 'AUTO'), value='A…" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = widgets.Dropdown(\n", " options=core.available_devices + [\"AUTO\"],\n", " value=\"AUTO\",\n", " description=\"Device:\",\n", " disabled=False,\n", ")\n", "\n", "device" ] }, { "cell_type": "code", "execution_count": 21, "id": "2c3cf752-cd75-4063-b56a-cade1894f926", "metadata": {}, "outputs": [], "source": [ "class OVPose2VideoPipeline(Pose2VideoPipeline):\n", " def __init__(\n", " self,\n", " vae_encoder_path=VAE_ENCODER_PATH,\n", " vae_decoder_path=VAE_DECODER_PATH,\n", " image_encoder_path=IMAGE_ENCODER_PATH,\n", " reference_unet_path=REFERENCE_UNET_PATH,\n", " denoising_unet_path=DENOISING_UNET_PATH,\n", " pose_guider_path=POSE_GUIDER_PATH,\n", " device=device.value,\n", " ):\n", " self.vae_encoder = core.compile_model(vae_encoder_path, device)\n", " self.vae_decoder = core.compile_model(vae_decoder_path, device)\n", " self.image_encoder = core.compile_model(image_encoder_path, device)\n", " self.reference_unet = core.compile_model(reference_unet_path, device)\n", " self.denoising_unet = core.compile_model(denoising_unet_path, device)\n", " self.pose_guider = core.compile_model(pose_guider_path, device)\n", " self.scheduler = DDIMScheduler(**OmegaConf.to_container(infer_config.noise_scheduler_kwargs))\n", "\n", " self.vae_scale_factor = 8\n", " self.clip_image_processor = CLIPImageProcessor()\n", " self.ref_image_processor = VaeImageProcessor(do_convert_rgb=True)\n", " self.cond_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False)\n", "\n", " def decode_latents(self, latents):\n", " video_length = latents.shape[2]\n", " latents = 1 / 0.18215 * latents\n", " latents = rearrange(latents, \"b c f h w -> (b f) c h w\")\n", " # video = self.vae.decode(latents).sample\n", " video = []\n", " for 
frame_idx in tqdm(range(latents.shape[0])):\n", " video.append(torch.from_numpy(self.vae_decoder(latents[frame_idx : frame_idx + 1])[0]))\n", " video = torch.cat(video)\n", " video = rearrange(video, \"(b f) c h w -> b c f h w\", f=video_length)\n", " video = (video / 2 + 0.5).clamp(0, 1)\n", " # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16\n", " video = video.cpu().float().numpy()\n", " return video\n", "\n", " def __call__(\n", " self,\n", " ref_image,\n", " pose_images,\n", " width,\n", " height,\n", " video_length,\n", " num_inference_steps=30,\n", " guidance_scale=3.5,\n", " num_images_per_prompt=1,\n", " eta: float = 0.0,\n", " generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n", " output_type: Optional[str] = \"tensor\",\n", " callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,\n", " callback_steps: Optional[int] = 1,\n", " context_schedule=\"uniform\",\n", " context_frames=24,\n", " context_stride=1,\n", " context_overlap=4,\n", " context_batch_size=1,\n", " interpolation_factor=1,\n", " **kwargs,\n", " ):\n", " do_classifier_free_guidance = guidance_scale > 1.0\n", "\n", " # Prepare timesteps\n", " self.scheduler.set_timesteps(num_inference_steps)\n", " timesteps = self.scheduler.timesteps\n", "\n", " batch_size = 1\n", "\n", " # Prepare clip image embeds\n", " clip_image = self.clip_image_processor.preprocess(ref_image.resize((224, 224)), return_tensors=\"pt\").pixel_values\n", " clip_image_embeds = self.image_encoder(clip_image)[\"image_embeds\"]\n", " clip_image_embeds = torch.from_numpy(clip_image_embeds)\n", " encoder_hidden_states = clip_image_embeds.unsqueeze(1)\n", " uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)\n", "\n", " if do_classifier_free_guidance:\n", " encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0)\n", "\n", " latents = self.prepare_latents(\n", " batch_size * num_images_per_prompt,\n", " 4,\n", " width,\n", " height,\n", " video_length,\n", " clip_image_embeds.dtype,\n", " torch.device(\"cpu\"),\n", " generator,\n", " )\n", "\n", " # Prepare extra step kwargs.\n", " extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n", "\n", " # Prepare ref image latents\n", " ref_image_tensor = self.ref_image_processor.preprocess(ref_image, height=height, width=width) # (bs, c, width, height)\n", " ref_image_latents = self.vae_encoder(ref_image_tensor)[0]\n", " ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)\n", " ref_image_latents = torch.from_numpy(ref_image_latents)\n", "\n", " # Prepare a list of pose condition images\n", " pose_cond_tensor_list = []\n", " for pose_image in pose_images:\n", " pose_cond_tensor = self.cond_image_processor.preprocess(pose_image, height=height, width=width)\n", " pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)\n", " pose_cond_tensor_list.append(pose_cond_tensor)\n", " pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=2) # (bs, c, t, h, w)\n", " pose_fea = self.pose_guider(pose_cond_tensor)[0]\n", " pose_fea = torch.from_numpy(pose_fea)\n", "\n", " context_scheduler = get_context_scheduler(context_schedule)\n", "\n", " # denoising loop\n", " num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n", " with self.progress_bar(total=num_inference_steps) as progress_bar:\n", " for i, t in enumerate(timesteps):\n", " noise_pred = torch.zeros(\n", " (\n", " latents.shape[0] * (2 if 
do_classifier_free_guidance else 1),\n", " *latents.shape[1:],\n", " ),\n", " device=latents.device,\n", " dtype=latents.dtype,\n", " )\n", " counter = torch.zeros(\n", " (1, 1, latents.shape[2], 1, 1),\n", " device=latents.device,\n", " dtype=latents.dtype,\n", " )\n", "\n", " # 1. Forward reference image\n", " if i == 0:\n", " ref_features = self.reference_unet(\n", " (\n", " ref_image_latents.repeat((2 if do_classifier_free_guidance else 1), 1, 1, 1),\n", " torch.zeros_like(t),\n", " # t,\n", " encoder_hidden_states,\n", " )\n", " ).values()\n", "\n", " context_queue = list(\n", " context_scheduler(\n", " 0,\n", " num_inference_steps,\n", " latents.shape[2],\n", " context_frames,\n", " context_stride,\n", " 0,\n", " )\n", " )\n", " num_context_batches = math.ceil(len(context_queue) / context_batch_size)\n", "\n", " context_queue = list(\n", " context_scheduler(\n", " 0,\n", " num_inference_steps,\n", " latents.shape[2],\n", " context_frames,\n", " context_stride,\n", " context_overlap,\n", " )\n", " )\n", "\n", " num_context_batches = math.ceil(len(context_queue) / context_batch_size)\n", " global_context = []\n", " for i in range(num_context_batches):\n", " global_context.append(context_queue[i * context_batch_size : (i + 1) * context_batch_size])\n", "\n", " for context in global_context:\n", " # 3.1 expand the latents if we are doing classifier free guidance\n", " latent_model_input = torch.cat([latents[:, :, c] for c in context]).repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)\n", " latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n", " b, c, f, h, w = latent_model_input.shape\n", " latent_pose_input = torch.cat([pose_fea[:, :, c] for c in context]).repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)\n", "\n", " pred = self.denoising_unet(\n", " (\n", " latent_model_input,\n", " t,\n", " encoder_hidden_states[:b],\n", " latent_pose_input,\n", " *ref_features,\n", " )\n", " )[0]\n", "\n", " for j, c in enumerate(context):\n", " noise_pred[:, :, c] = noise_pred[:, :, c] + pred\n", " counter[:, :, c] = counter[:, :, c] + 1\n", "\n", " # perform guidance\n", " if do_classifier_free_guidance:\n", " noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)\n", " noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n", "\n", " latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n", "\n", " if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n", " progress_bar.update()\n", " if callback is not None and i % callback_steps == 0:\n", " step_idx = i // getattr(self.scheduler, \"order\", 1)\n", " callback(step_idx, t, latents)\n", "\n", " if interpolation_factor > 0:\n", " latents = self.interpolate_latents(latents, interpolation_factor, latents.device)\n", " # Post-processing\n", " images = self.decode_latents(latents) # (b, c, f, h, w)\n", "\n", " # Convert to tensor\n", " if output_type == \"tensor\":\n", " images = torch.from_numpy(images)\n", "\n", " return images" ] }, { "cell_type": "code", "execution_count": 22, "id": "ca762050-692d-48ed-b31b-73130e712784", "metadata": {}, "outputs": [], "source": [ "pipe = OVPose2VideoPipeline()" ] }, { "cell_type": "code", "execution_count": 23, "id": "9a34c91c-20f5-4e5f-aed2-f3dc3c226e79", "metadata": {}, "outputs": [], "source": [ "pose_images = read_frames(\"Moore-AnimateAnyone/configs/inference/pose_videos/anyone-video-2_kps.mp4\")\n", "src_fps = 
get_fps(\"Moore-AnimateAnyone/configs/inference/pose_videos/anyone-video-2_kps.mp4\")\n", "ref_image = Image.open(\"Moore-AnimateAnyone/configs/inference/ref_images/anyone-5.png\").convert(\"RGB\")\n", "pose_list = []\n", "for pose_image_pil in pose_images[:VIDEO_LENGTH]:\n", " pose_list.append(pose_image_pil)" ] }, { "cell_type": "code", "execution_count": 24, "id": "a6636887-eb52-494f-86ee-d3da25746c14", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "de88200bdd334dafa57361beb0bcd027", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c5f6b29521e24285822c4d072fb3f554", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/24 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "video = pipe(\n", " ref_image,\n", " pose_list,\n", " width=WIDTH,\n", " height=HEIGHT,\n", " video_length=VIDEO_LENGTH,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "b35f75ae-818e-483b-8e83-7e580d728eaf", "metadata": {}, "source": [ "## Video post-processing\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": 25, "id": "86cff8b7-4723-4520-ad2e-bbca6f0c76d7", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "new_h, new_w = video.shape[-2:]\n", "pose_transform = transforms.Compose([transforms.Resize((new_h, new_w)), transforms.ToTensor()])\n", "pose_tensor_list = []\n", "for pose_image_pil in pose_images[:VIDEO_LENGTH]:\n", " pose_tensor_list.append(pose_transform(pose_image_pil))\n", "\n", "ref_image_tensor = pose_transform(ref_image) # (c, h, w)\n", "ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)\n", "ref_image_tensor = repeat(ref_image_tensor, \"b c f h w -> b c (repeat f) h w\", repeat=VIDEO_LENGTH)\n", "pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)\n", "pose_tensor = pose_tensor.transpose(0, 1)\n", "pose_tensor = pose_tensor.unsqueeze(0)\n", "video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)\n", "\n", "save_dir = Path(\"./output\")\n", "save_dir.mkdir(parents=True, exist_ok=True)\n", "date_str = datetime.now().strftime(\"%Y%m%d\")\n", "time_str = datetime.now().strftime(\"%H%M\")\n", "out_path = save_dir / f\"{date_str}T{time_str}.mp4\"\n", "save_videos_grid(\n", " video,\n", " str(out_path),\n", " n_rows=3,\n", " fps=src_fps,\n", ")" ] }, { "cell_type": "code", "execution_count": 26, "id": "08c82ebc-55b7-4899-a34d-dca208567125", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "