{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "82a43bc9-5339-4fa2-a7b4-e3ad0bc7a54a", "metadata": {}, "source": [ "# Image-to-Video synthesis with AnimateAnyone and OpenVINO\n", "![](./animate-anyone.gif)\n", "\n", "[AnimateAnyone](https://arxiv.org/pdf/2311.17117.pdf) tackles the task of generating animation sequences from a single character image. It builds upon diffusion models pre-trained on vast character image datasets.\n", "\n", "The core of AnimateAnyone is a diffusion model pre-trained on a massive dataset of character images. This model learns the underlying character representation and distribution, allowing for realistic and diverse character animation.\n", "To capture the specific details and characteristics of the input character image, AnimateAnyone incorporates a ReferenceNet module. This module acts like an attention mechanism, focusing on the input image and guiding the animation process to stay consistent with the original character's appearance. AnimateAnyone enables control over the character's pose during animation. This might involve using techniques like parametric pose embedding or direct pose vector input, allowing for the creation of various character actions and movements. To ensure smooth transitions and temporal coherence throughout the animation sequence, AnimateAnyone incorporates temporal modeling techniques. This may involve recurrent architectures like LSTMs or transformers that capture the temporal dependencies between video frames.\n", "\n", "Overall, AnimateAnyone combines a powerful pre-trained diffusion model with a character-specific attention mechanism (ReferenceNet), pose guidance, and temporal modeling to achieve controllable, high-fidelity character animation from a single image.\n", "\n", "Learn more in [GitHub repo](https://github.com/MooreThreads/Moore-AnimateAnyone) and [paper](https://arxiv.org/pdf/2311.17117.pdf).\n", "\n", "
\n", "

! WARNING !

\n", "

\n", " This tutorial requires at least 96 GB of RAM for model conversion and 40 GB for inference. Changing the values of HEIGHT, WIDTH and VIDEO_LENGTH variables will change the memory consumption but will also affect accuracy.\n", "

\n", "
\n", "\n", "#### Table of contents:\n", "\n", "- [Prerequisites](#Prerequisites)\n", "- [Prepare base model](#Prepare-base-model)\n", "- [Prepare image encoder](#Prepare-image-encoder)\n", "- [Download weights](#Download-weights)\n", "- [Initialize models](#Initialize-models)\n", "- [Load pretrained weights](#Load-pretrained-weights)\n", "- [Convert model to OpenVINO IR](#Convert-model-to-OpenVINO-IR)\n", " - [VAE](#VAE)\n", " - [Reference UNet](#Reference-UNet)\n", " - [Denoising UNet](#Denoising-UNet)\n", " - [Pose Guider](#Pose-Guider)\n", " - [Image Encoder](#Image-Encoder)\n", "- [Inference](#Inference)\n", "- [Video post-processing](#Video-post-processing)\n", "- [Interactive inference](#Interactive-inference)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "33e0d739-01c9-4a16-873c-2fccc046d3a9", "metadata": {}, "source": [ "## Prerequisites\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": 1, "id": "eb8ce0dc-7e5d-4661-a7e7-378bb9e67994", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[33mWARNING: typer 0.12.3 does not provide the extra 'all'\u001b[0m\u001b[33m\n", "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "from pathlib import Path\n", "import requests\n", "\n", "\n", "REPO_PATH = Path(\"Moore-AnimateAnyone\")\n", "if not REPO_PATH.exists():\n", " !git clone -q \"https://github.com/itrushkin/Moore-AnimateAnyone.git\"\n", "%pip install -q \"torch>=2.1\" torchvision einops omegaconf \"diffusers<=0.24\" transformers av accelerate \"openvino>=2024.0\" \"nncf>=2.9.0\" \"gradio>=4.19\" --extra-index-url \"https://download.pytorch.org/whl/cpu\"\n", "import sys\n", "\n", "sys.path.insert(0, str(REPO_PATH.resolve()))\n", "r = requests.get(\n", " url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/skip_kernel_extension.py\",\n", ")\n", "open(\"skip_kernel_extension.py\", \"w\").write(r.text)\n", "%load_ext skip_kernel_extension" ] }, { "attachments": {}, "cell_type": "markdown", "id": "9f595c7b-1fdb-4a1e-970d-ad5334e47d5c", "metadata": {}, "source": [ "Note that we clone a fork of original repo with tweaked forward methods." ] }, { "cell_type": "code", "execution_count": 2, "id": "a2e1b26e-9142-47fa-a15e-5930ef8f306e", "metadata": {}, "outputs": [], "source": [ "MODEL_DIR = Path(\"models\")\n", "VAE_ENCODER_PATH = MODEL_DIR / \"vae_encoder.xml\"\n", "VAE_DECODER_PATH = MODEL_DIR / \"vae_decoder.xml\"\n", "REFERENCE_UNET_PATH = MODEL_DIR / \"reference_unet.xml\"\n", "DENOISING_UNET_PATH = MODEL_DIR / \"denoising_unet.xml\"\n", "POSE_GUIDER_PATH = MODEL_DIR / \"pose_guider.xml\"\n", "IMAGE_ENCODER_PATH = MODEL_DIR / \"image_encoder.xml\"\n", "\n", "WIDTH = 448\n", "HEIGHT = 512\n", "VIDEO_LENGTH = 24\n", "\n", "SHOULD_CONVERT = not all(\n", " p.exists()\n", " for p in [\n", " VAE_ENCODER_PATH,\n", " VAE_DECODER_PATH,\n", " REFERENCE_UNET_PATH,\n", " DENOISING_UNET_PATH,\n", " POSE_GUIDER_PATH,\n", " IMAGE_ENCODER_PATH,\n", " ]\n", ")" ] }, { "cell_type": "code", "execution_count": 3, "id": "917e01b7-43bc-4358-8752-cc862bd74758", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/itrushkin/.virtualenvs/test/lib/python3.10/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. 
Please use torch.utils._pytree.register_pytree_node instead.\n", " torch.utils._pytree._register_pytree_node(\n", "/home/itrushkin/.virtualenvs/test/lib/python3.10/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", " torch.utils._pytree._register_pytree_node(\n", "/home/itrushkin/.virtualenvs/test/lib/python3.10/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", " torch.utils._pytree._register_pytree_node(\n" ] } ], "source": [ "from datetime import datetime\n", "from typing import Optional, Union, List, Callable\n", "import math\n", "\n", "from PIL import Image\n", "import openvino as ov\n", "from torchvision import transforms\n", "from einops import repeat\n", "from tqdm.auto import tqdm\n", "from einops import rearrange\n", "from omegaconf import OmegaConf\n", "from diffusers import DDIMScheduler\n", "from diffusers.image_processor import VaeImageProcessor\n", "from transformers import CLIPImageProcessor\n", "import torch\n", "import gradio as gr\n", "import ipywidgets as widgets\n", "import numpy as np\n", "\n", "from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline\n", "from src.utils.util import get_fps, read_frames\n", "from src.utils.util import save_videos_grid\n", "from src.pipelines.context import get_context_scheduler" ] }, { "cell_type": "code", "execution_count": 4, "id": "4bcfe74c-1caf-404b-89c9-466c60e19aa7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, openvino\n" ] } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "from pathlib import PurePosixPath\n", "import gc\n", "import warnings\n", "\n", "from typing import Dict, Any\n", "from diffusers import AutoencoderKL\n", "from huggingface_hub import hf_hub_download, snapshot_download\n", "from transformers import CLIPVisionModelWithProjection\n", "import nncf\n", "\n", "from src.models.unet_2d_condition import UNet2DConditionModel\n", "from src.models.unet_3d import UNet3DConditionModel\n", "from src.models.pose_guider import PoseGuider" ] }, { "attachments": {}, "cell_type": "markdown", "id": "dfc4d86e-fe72-48b1-9af7-6b428935fe3a", "metadata": {}, "source": [ "## Prepare base model\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": 5, "id": "ff7caead-ad5c-421f-8177-6afcbd4cde54", "metadata": {}, "outputs": [], "source": [ "%%skip not $SHOULD_CONVERT\n", "local_dir = Path(\"./pretrained_weights/stable-diffusion-v1-5\")\n", "local_dir.mkdir(parents=True, exist_ok=True)\n", "for hub_file in [\"unet/config.json\", \"unet/diffusion_pytorch_model.bin\"]:\n", " saved_path = local_dir / hub_file\n", " if saved_path.exists():\n", " continue\n", " hf_hub_download(\n", " repo_id=\"runwayml/stable-diffusion-v1-5\",\n", " subfolder=PurePosixPath(saved_path.parent.name),\n", " filename=PurePosixPath(saved_path.name),\n", " local_dir=local_dir,\n", " )" ] }, { "attachments": {}, "cell_type": "markdown", "id": "30022dee-d76a-4eba-8652-620ecde4a2f1", "metadata": {}, "source": [ "## Prepare image encoder\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": 6, "id": "a50a8216-4edd-4237-bee4-e07b33758644", "metadata": {}, "outputs": [], "source": [ "%%skip not $SHOULD_CONVERT\n", 
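"# Download the CLIP image encoder (config + checkpoint) from the\n", "# lambdalabs/sd-image-variations-diffusers repository; files already present\n", "# under ./pretrained_weights are skipped on rerun.\n",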
"local_dir = Path(\"./pretrained_weights\")\n", "local_dir.mkdir(parents=True, exist_ok=True)\n", "for hub_file in [\"image_encoder/config.json\", \"image_encoder/pytorch_model.bin\"]:\n", " saved_path = local_dir / hub_file\n", " if saved_path.exists():\n", " continue\n", " hf_hub_download(\n", " repo_id=\"lambdalabs/sd-image-variations-diffusers\",\n", " subfolder=PurePosixPath(saved_path.parent.name),\n", " filename=PurePosixPath(saved_path.name),\n", " local_dir=local_dir,\n", " )" ] }, { "attachments": {}, "cell_type": "markdown", "id": "daed8698-c76a-4b59-aa13-165749c6a0db", "metadata": {}, "source": [ "## Download weights\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": 7, "id": "c5013395-295a-4c7c-8302-2459c343de65", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4045ef358f1e4bbea93919ce15cff43a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 5 files: 0%| | 0/5 [00:00\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "if not VAE_ENCODER_PATH.exists():\n", " class VaeEncoder(torch.nn.Module):\n", " def __init__(self, vae):\n", " super().__init__()\n", " self.vae = vae\n", " \n", " def forward(self, x):\n", " return self.vae.encode(x).latent_dist.mean\n", " vae.eval()\n", " with torch.no_grad():\n", " vae_encoder = ov.convert_model(VaeEncoder(vae), example_input=torch.zeros(1,3,512,448))\n", " vae_encoder = nncf.compress_weights(vae_encoder)\n", " ov.save_model(vae_encoder, VAE_ENCODER_PATH)\n", " del vae_encoder\n", " cleanup_torchscript_cache()" ] }, { "cell_type": "code", "execution_count": 14, "id": "c8f87c8d-2538-478a-8f8f-7fa1f498dcb2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:nncf:Statistics of the bitwidth distribution:\n", "+--------------+---------------------------+-----------------------------------+\n", "| Num bits (N) | % all parameters (layers) | % ratio-defining parameters |\n", "| | | (layers) |\n", "+==============+===========================+===================================+\n", "| 8 | 100% (40 / 40) | 100% (40 / 40) |\n", "+--------------+---------------------------+-----------------------------------+\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ce3d8670ebdc41ecaeae193539a53b8a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "if not VAE_DECODER_PATH.exists():\n", " class VaeDecoder(torch.nn.Module):\n", " def __init__(self, vae):\n", " super().__init__()\n", " self.vae = vae\n", " \n", " def forward(self, z):\n", " return self.vae.decode(z).sample\n", " vae.eval()\n", " with torch.no_grad():\n", " vae_decoder = ov.convert_model(VaeDecoder(vae), example_input=torch.zeros(1,4,HEIGHT//8,WIDTH//8))\n", " vae_decoder = nncf.compress_weights(vae_decoder)\n", " ov.save_model(vae_decoder, VAE_DECODER_PATH)\n", " del vae_decoder\n", " cleanup_torchscript_cache()\n", "del vae\n", "gc.collect()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "be82f4e8-0e1b-406a-93a2-d8c8c23c7797", "metadata": {}, "source": [ "### Reference UNet\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Pipeline extracts reference attention features from all transformer blocks inside Reference UNet model. We call the original forward pass to obtain shapes of the outputs as they will be used in the next pipeline step." ] }, { "cell_type": "code", "execution_count": 15, "id": "9a31c54b-b6e9-41b1-80f6-c06a1bae52ae", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:nncf:Statistics of the bitwidth distribution:\n", "+--------------+---------------------------+-----------------------------------+\n", "| Num bits (N) | % all parameters (layers) | % ratio-defining parameters |\n", "| | | (layers) |\n", "+==============+===========================+===================================+\n", "| 8 | 100% (270 / 270) | 100% (270 / 270) |\n", "+--------------+---------------------------+-----------------------------------+\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b31aab4251de44e8bb59a533bb1cb9e4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "if not REFERENCE_UNET_PATH.exists():\n", " class ReferenceUNetWrapper(torch.nn.Module):\n", " def __init__(self, reference_unet):\n", " super().__init__()\n", " self.reference_unet = reference_unet\n", " \n", " def forward(self, sample, timestep, encoder_hidden_states):\n", " return self.reference_unet(sample, timestep, encoder_hidden_states, return_dict=False)[1]\n", " \n", " sample = torch.zeros(2, 4, HEIGHT // 8, WIDTH // 8)\n", " timestep = torch.tensor(0)\n", " encoder_hidden_states = torch.zeros(2, 1, 768)\n", " reference_unet.eval()\n", " with torch.no_grad():\n", " wrapper = ReferenceUNetWrapper(reference_unet)\n", " example_input = (sample, timestep, encoder_hidden_states)\n", " ref_features_shapes = {k: v.shape for k, v in wrapper(*example_input).items()}\n", " ov_reference_unet = ov.convert_model(\n", " wrapper,\n", " example_input=example_input,\n", " )\n", " ov_reference_unet = nncf.compress_weights(ov_reference_unet)\n", " ov.save_model(ov_reference_unet, REFERENCE_UNET_PATH)\n", " del ov_reference_unet\n", " del wrapper\n", " cleanup_torchscript_cache()\n", "del reference_unet\n", "gc.collect()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7ddd784b-c54b-45f4-9797-7b4634e90ec0", "metadata": {}, "source": [ "### Denoising UNet\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Denoising UNet is the main part of all diffusion pipelines. This model consumes the majority of memory, so we need to reduce its size as much as possible.\n", "\n", "Here we make all shapes static meaning that the size of the video will be constant.\n", "\n", "Also, we use the `ref_features` input with the same tensor shapes as output of [Reference UNet](#Reference-UNet) model on the previous step." ] }, { "cell_type": "code", "execution_count": 16, "id": "e95a7dbd-6235-44f0-81d7-1898b51839c9", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:nncf:Statistics of the bitwidth distribution:\n", "+--------------+---------------------------+-----------------------------------+\n", "| Num bits (N) | % all parameters (layers) | % ratio-defining parameters |\n", "| | | (layers) |\n", "+==============+===========================+===================================+\n", "| 8 | 100% (534 / 534) | 100% (534 / 534) |\n", "+--------------+---------------------------+-----------------------------------+\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "776d8e8cb44446428e50db32618df935", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "if not DENOISING_UNET_PATH.exists():\n", " class DenoisingUNetWrapper(torch.nn.Module):\n", " def __init__(self, denoising_unet):\n", " super().__init__()\n", " self.denoising_unet = denoising_unet\n", " \n", " def forward(\n", " self,\n", " sample,\n", " timestep,\n", " encoder_hidden_states,\n", " pose_cond_fea,\n", " ref_features\n", " ):\n", " return self.denoising_unet(\n", " sample,\n", " timestep,\n", " encoder_hidden_states,\n", " ref_features,\n", " pose_cond_fea=pose_cond_fea,\n", " return_dict=False)\n", "\n", " example_input = {\n", " \"sample\": torch.zeros(2, 4, VIDEO_LENGTH, HEIGHT // 8, WIDTH // 8),\n", " \"timestep\": torch.tensor(999),\n", " \"encoder_hidden_states\": torch.zeros(2,1,768),\n", " \"pose_cond_fea\": torch.zeros(2, 320, VIDEO_LENGTH, HEIGHT // 8, WIDTH // 8),\n", " \"ref_features\": {k: torch.zeros(shape) for k, shape in ref_features_shapes.items()}\n", " }\n", " \n", " denoising_unet.eval()\n", " with torch.no_grad():\n", " ov_denoising_unet = ov.convert_model(\n", " DenoisingUNetWrapper(denoising_unet),\n", " example_input=tuple(example_input.values())\n", " )\n", " ov_denoising_unet.inputs[0].get_node().set_partial_shape(ov.PartialShape((2, 4, VIDEO_LENGTH, HEIGHT // 8, WIDTH // 8)))\n", " ov_denoising_unet.inputs[2].get_node().set_partial_shape(ov.PartialShape((2, 1, 768)))\n", " ov_denoising_unet.inputs[3].get_node().set_partial_shape(ov.PartialShape((2, 320, VIDEO_LENGTH, HEIGHT // 8, WIDTH // 8)))\n", " for ov_input, shape in zip(ov_denoising_unet.inputs[4:], ref_features_shapes.values()):\n", " ov_input.get_node().set_partial_shape(ov.PartialShape(shape))\n", " ov_input.get_node().set_element_type(ov.Type.f32)\n", " ov_denoising_unet.validate_nodes_and_infer_types()\n", " ov_denoising_unet = nncf.compress_weights(ov_denoising_unet)\n", " ov.save_model(ov_denoising_unet, DENOISING_UNET_PATH)\n", " del ov_denoising_unet\n", " cleanup_torchscript_cache()\n", "del denoising_unet\n", "gc.collect()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "06054cfc-690c-4aaf-a05c-7df9d7ad08d2", "metadata": {}, "source": [ "### Pose Guider\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "To ensure pose controllability, a lightweight pose guider is devised to efficiently integrate pose control signals into the denoising process." ] }, { "cell_type": "code", "execution_count": 17, "id": "d4cf6c05-5326-49d0-9e4e-fa9b2081445f", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:nncf:Statistics of the bitwidth distribution:\n", "+--------------+---------------------------+-----------------------------------+\n", "| Num bits (N) | % all parameters (layers) | % ratio-defining parameters |\n", "| | | (layers) |\n", "+==============+===========================+===================================+\n", "| 8 | 100% (8 / 8) | 100% (8 / 8) |\n", "+--------------+---------------------------+-----------------------------------+\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "36ac12f0ac654a90b866dd07b483eeef", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "if not POSE_GUIDER_PATH.exists():\n", " pose_guider.eval()\n", " with torch.no_grad():\n", " ov_pose_guider = ov.convert_model(pose_guider, example_input=torch.zeros(1, 3, VIDEO_LENGTH, HEIGHT, WIDTH))\n", " ov_pose_guider = nncf.compress_weights(ov_pose_guider)\n", " ov.save_model(ov_pose_guider, POSE_GUIDER_PATH)\n", " del ov_pose_guider\n", " cleanup_torchscript_cache()\n", "del pose_guider\n", "gc.collect()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c4db1630-95d4-4cea-927f-fe1d1c259597", "metadata": {}, "source": [ "### Image Encoder\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Pipeline uses CLIP image encoder to generate encoder hidden states required for both reference and denoising UNets." ] }, { "cell_type": "code", "execution_count": 18, "id": "b7331636-b124-4038-9553-2292de72f13e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/itrushkin/.virtualenvs/test/lib/python3.10/site-packages/transformers/modeling_utils.py:4225: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:nncf:Statistics of the bitwidth distribution:\n", "+--------------+---------------------------+-----------------------------------+\n", "| Num bits (N) | % all parameters (layers) | % ratio-defining parameters |\n", "| | | (layers) |\n", "+==============+===========================+===================================+\n", "| 8 | 100% (146 / 146) | 100% (146 / 146) |\n", "+--------------+---------------------------+-----------------------------------+\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4bfabe55e887446fa5ef8c8fea37b33e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%skip not $SHOULD_CONVERT\n", "if not IMAGE_ENCODER_PATH.exists():\n", " image_enc.eval()\n", " with torch.no_grad():\n", " ov_image_encoder = ov.convert_model(image_enc, example_input=torch.zeros(1, 3, 224, 224), input=(1, 3, 224, 224))\n", " ov_image_encoder = nncf.compress_weights(ov_image_encoder)\n", " ov.save_model(ov_image_encoder, IMAGE_ENCODER_PATH)\n", " del ov_image_encoder\n", " cleanup_torchscript_cache()\n", "del image_enc\n", "gc.collect()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4e059780-29fc-4fe0-9376-c83573695fbe", "metadata": {}, "source": [ "## Inference\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "We inherit from the original pipeline modifying the calls to our models to match OpenVINO format." ] }, { "cell_type": "code", "execution_count": 19, "id": "35176d47-9a79-49dd-a61f-75892dba8d3d", "metadata": {}, "outputs": [], "source": [ "core = ov.Core()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7a28383f-8249-4c24-bc4c-8dcb0f75461c", "metadata": {}, "source": [ "### Select inference device\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "For starting work, please select inference device from dropdown list." ] }, { "cell_type": "code", "execution_count": 20, "id": "6f43558f-f244-43b7-9c82-c47b5b0bce23", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3e91c2a792224a4983cb2758f093b775", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Dropdown(description='Device:', index=5, options=('CPU', 'GPU.0', 'GPU.1', 'GPU.2', 'GPU.3', 'AUTO'), value='A…" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device = widgets.Dropdown(\n", " options=core.available_devices + [\"AUTO\"],\n", " value=\"AUTO\",\n", " description=\"Device:\",\n", " disabled=False,\n", ")\n", "\n", "device" ] }, { "cell_type": "code", "execution_count": 21, "id": "2c3cf752-cd75-4063-b56a-cade1894f926", "metadata": {}, "outputs": [], "source": [ "class OVPose2VideoPipeline(Pose2VideoPipeline):\n", " def __init__(\n", " self,\n", " vae_encoder_path=VAE_ENCODER_PATH,\n", " vae_decoder_path=VAE_DECODER_PATH,\n", " image_encoder_path=IMAGE_ENCODER_PATH,\n", " reference_unet_path=REFERENCE_UNET_PATH,\n", " denoising_unet_path=DENOISING_UNET_PATH,\n", " pose_guider_path=POSE_GUIDER_PATH,\n", " device=device.value,\n", " ):\n", " self.vae_encoder = core.compile_model(vae_encoder_path, device)\n", " self.vae_decoder = core.compile_model(vae_decoder_path, device)\n", " self.image_encoder = core.compile_model(image_encoder_path, device)\n", " self.reference_unet = core.compile_model(reference_unet_path, device)\n", " self.denoising_unet = core.compile_model(denoising_unet_path, device)\n", " self.pose_guider = core.compile_model(pose_guider_path, device)\n", " self.scheduler = DDIMScheduler(**OmegaConf.to_container(infer_config.noise_scheduler_kwargs))\n", "\n", " self.vae_scale_factor = 8\n", " self.clip_image_processor = CLIPImageProcessor()\n", " self.ref_image_processor = VaeImageProcessor(do_convert_rgb=True)\n", " self.cond_image_processor = VaeImageProcessor(do_convert_rgb=True, do_normalize=False)\n", "\n", " def decode_latents(self, latents):\n", " video_length = latents.shape[2]\n", " latents = 1 / 0.18215 * latents\n", " latents = rearrange(latents, \"b c f h w -> (b f) c h w\")\n", " # video = self.vae.decode(latents).sample\n", " video = []\n", " for frame_idx in 
tqdm(range(latents.shape[0])):\n", " video.append(torch.from_numpy(self.vae_decoder(latents[frame_idx : frame_idx + 1])[0]))\n", " video = torch.cat(video)\n", " video = rearrange(video, \"(b f) c h w -> b c f h w\", f=video_length)\n", " video = (video / 2 + 0.5).clamp(0, 1)\n", " # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16\n", " video = video.cpu().float().numpy()\n", " return video\n", "\n", " def __call__(\n", " self,\n", " ref_image,\n", " pose_images,\n", " width,\n", " height,\n", " video_length,\n", " num_inference_steps=30,\n", " guidance_scale=3.5,\n", " num_images_per_prompt=1,\n", " eta: float = 0.0,\n", " generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n", " output_type: Optional[str] = \"tensor\",\n", " callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,\n", " callback_steps: Optional[int] = 1,\n", " context_schedule=\"uniform\",\n", " context_frames=24,\n", " context_stride=1,\n", " context_overlap=4,\n", " context_batch_size=1,\n", " interpolation_factor=1,\n", " **kwargs,\n", " ):\n", " do_classifier_free_guidance = guidance_scale > 1.0\n", "\n", " # Prepare timesteps\n", " self.scheduler.set_timesteps(num_inference_steps)\n", " timesteps = self.scheduler.timesteps\n", "\n", " batch_size = 1\n", "\n", " # Prepare clip image embeds\n", " clip_image = self.clip_image_processor.preprocess(ref_image.resize((224, 224)), return_tensors=\"pt\").pixel_values\n", " clip_image_embeds = self.image_encoder(clip_image)[\"image_embeds\"]\n", " clip_image_embeds = torch.from_numpy(clip_image_embeds)\n", " encoder_hidden_states = clip_image_embeds.unsqueeze(1)\n", " uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)\n", "\n", " if do_classifier_free_guidance:\n", " encoder_hidden_states = torch.cat([uncond_encoder_hidden_states, encoder_hidden_states], dim=0)\n", "\n", " latents = self.prepare_latents(\n", " batch_size * num_images_per_prompt,\n", " 4,\n", " width,\n", " height,\n", " video_length,\n", " clip_image_embeds.dtype,\n", " torch.device(\"cpu\"),\n", " generator,\n", " )\n", "\n", " # Prepare extra step kwargs.\n", " extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n", "\n", " # Prepare ref image latents\n", " ref_image_tensor = self.ref_image_processor.preprocess(ref_image, height=height, width=width) # (bs, c, width, height)\n", " ref_image_latents = self.vae_encoder(ref_image_tensor)[0]\n", " ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)\n", " ref_image_latents = torch.from_numpy(ref_image_latents)\n", "\n", " # Prepare a list of pose condition images\n", " pose_cond_tensor_list = []\n", " for pose_image in pose_images:\n", " pose_cond_tensor = self.cond_image_processor.preprocess(pose_image, height=height, width=width)\n", " pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)\n", " pose_cond_tensor_list.append(pose_cond_tensor)\n", " pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=2) # (bs, c, t, h, w)\n", " pose_fea = self.pose_guider(pose_cond_tensor)[0]\n", " pose_fea = torch.from_numpy(pose_fea)\n", "\n", " context_scheduler = get_context_scheduler(context_schedule)\n", "\n", " # denoising loop\n", " num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n", " with self.progress_bar(total=num_inference_steps) as progress_bar:\n", " for i, t in enumerate(timesteps):\n", " noise_pred = torch.zeros(\n", " (\n", " latents.shape[0] * (2 if do_classifier_free_guidance 
else 1),\n", " *latents.shape[1:],\n", " ),\n", " device=latents.device,\n", " dtype=latents.dtype,\n", " )\n", " counter = torch.zeros(\n", " (1, 1, latents.shape[2], 1, 1),\n", " device=latents.device,\n", " dtype=latents.dtype,\n", " )\n", "\n", " # 1. Forward reference image\n", " if i == 0:\n", " ref_features = self.reference_unet(\n", " (\n", " ref_image_latents.repeat((2 if do_classifier_free_guidance else 1), 1, 1, 1),\n", " torch.zeros_like(t),\n", " # t,\n", " encoder_hidden_states,\n", " )\n", " ).values()\n", "\n", " context_queue = list(\n", " context_scheduler(\n", " 0,\n", " num_inference_steps,\n", " latents.shape[2],\n", " context_frames,\n", " context_stride,\n", " 0,\n", " )\n", " )\n", " num_context_batches = math.ceil(len(context_queue) / context_batch_size)\n", "\n", " context_queue = list(\n", " context_scheduler(\n", " 0,\n", " num_inference_steps,\n", " latents.shape[2],\n", " context_frames,\n", " context_stride,\n", " context_overlap,\n", " )\n", " )\n", "\n", " num_context_batches = math.ceil(len(context_queue) / context_batch_size)\n", " global_context = []\n", " for i in range(num_context_batches):\n", " global_context.append(context_queue[i * context_batch_size : (i + 1) * context_batch_size])\n", "\n", " for context in global_context:\n", " # 3.1 expand the latents if we are doing classifier free guidance\n", " latent_model_input = torch.cat([latents[:, :, c] for c in context]).repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)\n", " latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n", " b, c, f, h, w = latent_model_input.shape\n", " latent_pose_input = torch.cat([pose_fea[:, :, c] for c in context]).repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)\n", "\n", " pred = self.denoising_unet(\n", " (\n", " latent_model_input,\n", " t,\n", " encoder_hidden_states[:b],\n", " latent_pose_input,\n", " *ref_features,\n", " )\n", " )[0]\n", "\n", " for j, c in enumerate(context):\n", " noise_pred[:, :, c] = noise_pred[:, :, c] + pred\n", " counter[:, :, c] = counter[:, :, c] + 1\n", "\n", " # perform guidance\n", " if do_classifier_free_guidance:\n", " noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)\n", " noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n", "\n", " latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n", "\n", " if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n", " progress_bar.update()\n", " if callback is not None and i % callback_steps == 0:\n", " step_idx = i // getattr(self.scheduler, \"order\", 1)\n", " callback(step_idx, t, latents)\n", "\n", " if interpolation_factor > 0:\n", " latents = self.interpolate_latents(latents, interpolation_factor, latents.device)\n", " # Post-processing\n", " images = self.decode_latents(latents) # (b, c, f, h, w)\n", "\n", " # Convert to tensor\n", " if output_type == \"tensor\":\n", " images = torch.from_numpy(images)\n", "\n", " return images" ] }, { "cell_type": "code", "execution_count": 22, "id": "ca762050-692d-48ed-b31b-73130e712784", "metadata": {}, "outputs": [], "source": [ "pipe = OVPose2VideoPipeline()" ] }, { "cell_type": "code", "execution_count": 23, "id": "9a34c91c-20f5-4e5f-aed2-f3dc3c226e79", "metadata": {}, "outputs": [], "source": [ "pose_images = read_frames(\"Moore-AnimateAnyone/configs/inference/pose_videos/anyone-video-2_kps.mp4\")\n", "src_fps = 
get_fps(\"Moore-AnimateAnyone/configs/inference/pose_videos/anyone-video-2_kps.mp4\")\n", "ref_image = Image.open(\"Moore-AnimateAnyone/configs/inference/ref_images/anyone-5.png\").convert(\"RGB\")\n", "pose_list = []\n", "for pose_image_pil in pose_images[:VIDEO_LENGTH]:\n", " pose_list.append(pose_image_pil)" ] }, { "cell_type": "code", "execution_count": 24, "id": "a6636887-eb52-494f-86ee-d3da25746c14", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "de88200bdd334dafa57361beb0bcd027", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00 b c (repeat f) h w\", repeat=VIDEO_LENGTH)\n", "pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)\n", "pose_tensor = pose_tensor.transpose(0, 1)\n", "pose_tensor = pose_tensor.unsqueeze(0)\n", "video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)\n", "\n", "save_dir = Path(\"./output\")\n", "save_dir.mkdir(parents=True, exist_ok=True)\n", "date_str = datetime.now().strftime(\"%Y%m%d\")\n", "time_str = datetime.now().strftime(\"%H%M\")\n", "out_path = save_dir / f\"{date_str}T{time_str}.mp4\"\n", "save_videos_grid(\n", " video,\n", " str(out_path),\n", " n_rows=3,\n", " fps=src_fps,\n", ")" ] }, { "cell_type": "code", "execution_count": 26, "id": "08c82ebc-55b7-4899-a34d-dca208567125", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import Video\n", "\n", "Video(out_path, embed=True)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "7becfbd0-4cb2-41ee-a340-470011796add", "metadata": {}, "source": [ "## Interactive inference\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": 27, "id": "2832a501-a8cb-4a44-a249-846f8524e3d6", "metadata": {}, "outputs": [], "source": [ "def generate(\n", " img,\n", " pose_vid,\n", " seed,\n", " guidance_scale,\n", " num_inference_steps,\n", " _=gr.Progress(track_tqdm=True),\n", "):\n", " generator = torch.Generator().manual_seed(seed)\n", " pose_list = read_frames(pose_vid)[:VIDEO_LENGTH]\n", " video = pipe(\n", " img,\n", " pose_list,\n", " width=WIDTH,\n", " height=HEIGHT,\n", " video_length=VIDEO_LENGTH,\n", " generator=generator,\n", " guidance_scale=guidance_scale,\n", " num_inference_steps=num_inference_steps,\n", " )\n", " new_h, new_w = video.shape[-2:]\n", " pose_transform = transforms.Compose([transforms.Resize((new_h, new_w)), transforms.ToTensor()])\n", " pose_tensor_list = []\n", " for pose_image_pil in pose_list:\n", " pose_tensor_list.append(pose_transform(pose_image_pil))\n", "\n", " ref_image_tensor = pose_transform(img) # (c, h, w)\n", " ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)\n", " ref_image_tensor = repeat(ref_image_tensor, \"b c f h w -> b c (repeat f) h w\", repeat=VIDEO_LENGTH)\n", " pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)\n", " pose_tensor = pose_tensor.transpose(0, 1)\n", " pose_tensor = pose_tensor.unsqueeze(0)\n", " video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)\n", "\n", " save_dir = Path(\"./output/gradio\")\n", " save_dir.mkdir(parents=True, exist_ok=True)\n", " date_str = datetime.now().strftime(\"%Y%m%d\")\n", " time_str = datetime.now().strftime(\"%H%M\")\n", " out_path = save_dir / 
f\"{date_str}T{time_str}.mp4\"\n", " save_videos_grid(\n", " video,\n", " str(out_path),\n", " n_rows=3,\n", " fps=12,\n", " )\n", " return out_path\n", "\n", "\n", "demo = gr.Interface(\n", " generate,\n", " [\n", " gr.Image(label=\"Reference Image\", type=\"pil\"),\n", " gr.Video(label=\"Pose video\"),\n", " gr.Slider(\n", " label=\"Seed\",\n", " value=42,\n", " minimum=np.iinfo(np.int32).min,\n", " maximum=np.iinfo(np.int32).max,\n", " ),\n", " gr.Slider(label=\"Guidance scale\", value=3.5, minimum=1.1, maximum=10),\n", " gr.Slider(label=\"Number of inference steps\", value=30, minimum=15, maximum=100),\n", " ],\n", " \"video\",\n", " examples=[\n", " [\n", " \"Moore-AnimateAnyone/configs/inference/ref_images/anyone-2.png\",\n", " \"Moore-AnimateAnyone/configs/inference/pose_videos/anyone-video-2_kps.mp4\",\n", " ],\n", " [\n", " \"Moore-AnimateAnyone/configs/inference/ref_images/anyone-10.png\",\n", " \"Moore-AnimateAnyone/configs/inference/pose_videos/anyone-video-1_kps.mp4\",\n", " ],\n", " [\n", " \"Moore-AnimateAnyone/configs/inference/ref_images/anyone-11.png\",\n", " \"Moore-AnimateAnyone/configs/inference/pose_videos/anyone-video-1_kps.mp4\",\n", " ],\n", " [\n", " \"Moore-AnimateAnyone/configs/inference/ref_images/anyone-3.png\",\n", " \"Moore-AnimateAnyone/configs/inference/pose_videos/anyone-video-2_kps.mp4\",\n", " ],\n", " [\n", " \"Moore-AnimateAnyone/configs/inference/ref_images/anyone-5.png\",\n", " \"Moore-AnimateAnyone/configs/inference/pose_videos/anyone-video-2_kps.mp4\",\n", " ],\n", " ],\n", " allow_flagging=\"never\",\n", ")\n", "try:\n", " demo.queue().launch(debug=True)\n", "except Exception:\n", " demo.queue().launch(debug=True, share=True)\n", "# if you are launching remotely, specify server_name and server_port\n", "# demo.launch(server_name='your server name', server_port='server port in int')\n", "# Read more in the docs: https://gradio.app/docs/\"" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" }, "openvino_notebooks": { "imageUrl": "https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/e11be946-dfa4-4f1d-a8b8-764213de9f1c", "tags": { "categories": [ "Model Demos", "AI Trends" ], "libraries": [], "other": [], "tasks": [ "Image-to-Video" ] } } }, "nbformat": 4, "nbformat_minor": 5 }