"""
Gradio web UI for audio-driven talking-face generation.

The script takes a reference image and an audio clip and lets you configure the
inference parameters such as cfg_scale, pose_weight, face_weight, lip_weight, etc.

Usage:
    Run this script from the command line:

    python scripts/app.py
"""

import argparse
import copy
import logging
import math
import os
import random
import sys
import time
import uuid
import warnings
from datetime import datetime
from typing import List, Tuple

import diffusers
import gradio as gr
import mlflow
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
from omegaconf import OmegaConf
from torch import nn
from tqdm.auto import tqdm

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from joyhallo.animate.face_animate import FaceAnimatePipeline
from joyhallo.datasets.audio_processor import AudioProcessor
from joyhallo.datasets.image_processor import ImageProcessor
from joyhallo.datasets.talk_video import TalkingVideoDataset
from joyhallo.models.audio_proj import AudioProjModel
from joyhallo.models.face_locator import FaceLocator
from joyhallo.models.image_proj import ImageProjModel
from joyhallo.models.mutual_self_attention import ReferenceAttentionControl
from joyhallo.models.unet_2d_condition import UNet2DConditionModel
from joyhallo.models.unet_3d import UNet3DConditionModel
from joyhallo.utils.util import (compute_snr, delete_additional_ckpt,
                                 import_filename, init_output_dir,
                                 load_checkpoint, save_checkpoint,
                                 seed_everything, tensor_to_video)

warnings.filterwarnings("ignore")

check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Net(nn.Module):
    """
    The Net class defines a neural network model that combines a reference UNet2DConditionModel,
    a denoising UNet3DConditionModel, a face locator, and other components to animate a face in a static image.

    Args:
        reference_unet (UNet2DConditionModel): The reference UNet2DConditionModel used for face animation.
        denoising_unet (UNet3DConditionModel): The denoising UNet3DConditionModel used for face animation.
        face_locator (FaceLocator): The face locator model used for face animation.
        reference_control_writer: The reference control writer component.
        reference_control_reader: The reference control reader component.
        imageproj: The image projection model.
        audioproj: The audio projection model.

    Forward method:
        noisy_latents (torch.Tensor): The noisy latents tensor.
        timesteps (torch.Tensor): The timesteps tensor.
        ref_image_latents (torch.Tensor): The reference image latents tensor.
        face_emb (torch.Tensor): The face embeddings tensor.
        audio_emb (torch.Tensor): The audio embeddings tensor.
        mask (torch.Tensor): Hard face mask for the face locator.
        full_mask (torch.Tensor): Pose mask.
        face_mask (torch.Tensor): Face mask.
        lip_mask (torch.Tensor): Lip mask.
        uncond_img_fwd (bool): A flag indicating whether to perform a reference-image unconditional forward pass.
        uncond_audio_fwd (bool): A flag indicating whether to perform an audio unconditional forward pass.

    Returns:
        torch.Tensor: The output tensor of the neural network model.
    """
    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        face_locator: FaceLocator,
        reference_control_writer,
        reference_control_reader,
        imageproj,
        audioproj,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.face_locator = face_locator
        self.reference_control_writer = reference_control_writer
        self.reference_control_reader = reference_control_reader
        self.imageproj = imageproj
        self.audioproj = audioproj

    def forward(
        self,
        noisy_latents: torch.Tensor,
        timesteps: torch.Tensor,
        ref_image_latents: torch.Tensor,
        face_emb: torch.Tensor,
        audio_emb: torch.Tensor,
        mask: torch.Tensor,
        full_mask: torch.Tensor,
        face_mask: torch.Tensor,
        lip_mask: torch.Tensor,
        uncond_img_fwd: bool = False,
        uncond_audio_fwd: bool = False,
    ):
        """
        Run one denoising step.

        Projects the face and audio embeddings, encodes the reference image
        through the reference UNet (skipped when `uncond_img_fwd` is True),
        zeroes the audio embedding when `uncond_audio_fwd` is True, and predicts
        the noise with the denoising UNet.
        """
        face_emb = self.imageproj(face_emb)
        mask = mask.to(device=device)
        mask_feature = self.face_locator(mask)
        audio_emb = audio_emb.to(
            device=self.audioproj.device, dtype=self.audioproj.dtype)
        audio_emb = self.audioproj(audio_emb)

        if not uncond_img_fwd:
            ref_timesteps = torch.zeros_like(timesteps)
            ref_timesteps = repeat(
                ref_timesteps,
                "b -> (repeat b)",
                repeat=ref_image_latents.size(0) // ref_timesteps.size(0),
            )
            self.reference_unet(
                ref_image_latents,
                ref_timesteps,
                encoder_hidden_states=face_emb,
                return_dict=False,
            )
            self.reference_control_reader.update(self.reference_control_writer)

        if uncond_audio_fwd:
            audio_emb = torch.zeros_like(audio_emb).to(
                device=audio_emb.device, dtype=audio_emb.dtype
            )

        model_pred = self.denoising_unet(
            noisy_latents,
            timesteps,
            mask_cond_fea=mask_feature,
            encoder_hidden_states=face_emb,
            audio_embedding=audio_emb,
            full_mask=full_mask,
            face_mask=face_mask,
            lip_mask=lip_mask
        ).sample

        return model_pred


def get_attention_mask(mask: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
    """
    Rearrange the mask tensors to the required format.

    Args:
        mask (torch.Tensor or List[torch.Tensor]): The input mask tensor(s) of shape (b, f, 1, h, w).
        weight_dtype (torch.dtype): The data type to cast the mask tensor(s) to.

    Returns:
        torch.Tensor or List[torch.Tensor]: The rearranged mask tensor(s) of shape (b*f, h*w).
    """
    if isinstance(mask, list):
        _mask = []
        for m in mask:
            _mask.append(
                rearrange(m, "b f 1 h w -> (b f) (h w)").to(weight_dtype))
        return _mask
    mask = rearrange(mask, "b f 1 h w -> (b f) (h w)").to(weight_dtype)
    return mask
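

# Illustrative note (not part of the original script): a minimal sketch, with
# synthetic shapes chosen purely for illustration, of how get_attention_mask
# flattens (b, f, 1, h, w) masks into (b*f, h*w).
def _example_get_attention_mask() -> None:
    """Show the shapes produced by get_attention_mask() for a tensor and a list."""
    m = torch.ones(2, 24, 1, 8, 8)
    flat = get_attention_mask(m, torch.float16)
    assert flat.shape == (48, 64)  # (2*24, 8*8)
    flat_list = get_attention_mask([m, m], torch.float16)
    assert [t.shape for t in flat_list] == [(48, 64), (48, 64)]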


def get_noise_scheduler(cfg: argparse.Namespace) -> Tuple[DDIMScheduler, DDIMScheduler]:
    """
    Create the noise schedulers for training and validation.

    Args:
        cfg (argparse.Namespace): Configuration object.

    Returns:
        Tuple[DDIMScheduler, DDIMScheduler]: Train noise scheduler and validation noise scheduler.
    """
    sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
    if cfg.enable_zero_snr:
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    val_noise_scheduler = DDIMScheduler(**sched_kwargs)
    sched_kwargs.update({"beta_schedule": "scaled_linear"})
    train_noise_scheduler = DDIMScheduler(**sched_kwargs)

    return train_noise_scheduler, val_noise_scheduler
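

# Illustrative note (not part of the original script): `cfg.noise_scheduler_kwargs`
# is expected to hold plain DDIMScheduler keyword arguments. A minimal sketch of a
# hypothetical configuration; the exact values used by this project may differ.
#
#     train_sched, val_sched = get_noise_scheduler(_example_noise_scheduler_cfg())
def _example_noise_scheduler_cfg() -> argparse.Namespace:
    """Build a hypothetical cfg accepted by get_noise_scheduler()."""
    return argparse.Namespace(
        enable_zero_snr=True,
        noise_scheduler_kwargs=OmegaConf.create(
            {
                "num_train_timesteps": 1000,  # assumed values, for illustration only
                "beta_start": 0.00085,
                "beta_end": 0.012,
                "beta_schedule": "linear",
                "clip_sample": False,
                "steps_offset": 1,
            }
        ),
    )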


def process_audio_emb(audio_emb: torch.Tensor) -> torch.Tensor:
    """
    Process the audio embedding so that each frame carries a five-frame context
    window (two frames before, the frame itself, and two frames after), clamped
    at the sequence boundaries.

    Parameters:
        audio_emb (torch.Tensor): The audio embedding tensor to process, with the
            frame dimension first.

    Returns:
        torch.Tensor: The stacked tensor of shape (num_frames, 5, ...).
    """
    concatenated_tensors = []

    for i in range(audio_emb.shape[0]):
        vectors_to_concat = [
            audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
        concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))

    audio_emb = torch.stack(concatenated_tensors, dim=0)

    return audio_emb
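

# Illustrative note (not part of the original script): a minimal sketch showing the
# effect of process_audio_emb on a synthetic embedding. The shapes are assumptions
# for illustration; the real wav2vec features have more dimensions.
def _example_process_audio_emb() -> None:
    """Show the sliding five-frame window produced by process_audio_emb()."""
    emb = torch.arange(6, dtype=torch.float32).unsqueeze(-1)  # shape (6, 1), frames 0..5
    windowed = process_audio_emb(emb)                         # shape (6, 5, 1)
    assert windowed.shape == (6, 5, 1)
    # The first frame is padded by clamping to the start of the sequence.
    assert windowed[0, :, 0].tolist() == [0.0, 0.0, 0.0, 1.0, 2.0]
    # The last frame clamps to the end of the sequence.
    assert windowed[-1, :, 0].tolist() == [3.0, 4.0, 5.0, 5.0, 5.0]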


def log_validation(
    accelerator: Accelerator,
    vae: AutoencoderKL,
    net: Net,
    scheduler: DDIMScheduler,
    width: int,
    height: int,
    clip_length: int = 24,
    generator: torch.Generator = None,
    cfg: dict = None,
    save_dir: str = None,
    global_step: int = 0,
    times: int = None,
    face_analysis_model_path: str = "",
):
    """
    Build the inference pipeline and the pre-processing components.

    Unlike the training script, this function does not run validation itself: it
    unwraps the model, assembles the FaceAnimatePipeline, and creates the image
    and audio processors so that `inference` can be called repeatedly from the UI.

    Args:
        accelerator (Accelerator): The accelerator for distributed training.
        vae (AutoencoderKL): The autoencoder model.
        net (Net): The main neural network model.
        scheduler (DDIMScheduler): The noise scheduler.
        width (int): The width of the input images.
        height (int): The height of the input images.
        clip_length (int): The length of the video clips. Defaults to 24.
        generator (torch.Generator): The random number generator. Defaults to None.
        cfg (dict): The configuration dictionary. Defaults to None.
        save_dir (str): The directory to save validation results. Defaults to None.
        global_step (int): The current global step in training. Defaults to 0.
        times (int): The number of inference times. Defaults to None.
        face_analysis_model_path (str): The path to the face analysis model. Defaults to "".

    Returns:
        Tuple containing (cfg, image_processor, audio_processor, pipeline,
        audioproj, save_dir, global_step, clip_length).
    """
    ori_net = accelerator.unwrap_model(net)
    reference_unet = ori_net.reference_unet
    denoising_unet = ori_net.denoising_unet
    face_locator = ori_net.face_locator
    imageproj = ori_net.imageproj
    audioproj = ori_net.audioproj
    tmp_denoising_unet = copy.deepcopy(denoising_unet)

    pipeline = FaceAnimatePipeline(
        vae=vae,
        reference_unet=reference_unet,
        denoising_unet=tmp_denoising_unet,
        face_locator=face_locator,
        image_proj=imageproj,
        scheduler=scheduler,
    )
    pipeline = pipeline.to(device)

    image_processor = ImageProcessor((width, height), face_analysis_model_path)
    audio_processor = AudioProcessor(
        cfg.data.sample_rate,
        cfg.data.fps,
        cfg.wav2vec_config.model_path,
        cfg.wav2vec_config.features == "last",
        os.path.dirname(cfg.audio_separator.model_path),
        os.path.basename(cfg.audio_separator.model_path),
        os.path.join(save_dir, '.cache', "audio_preprocess"),
        device=device,
    )
    return cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length


def inference(cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length):
    """
    Run audio-driven animation for a single image/audio pair.

    The reference image and the audio clip are pre-processed, the audio embedding
    is split into clips of `clip_length` frames, and each clip is generated with
    the FaceAnimatePipeline. The last `n_motion_frames` frames of the previous
    clip are fed back as motion context for the next one. The stitched result is
    written to `cfg.output` and its path is returned.
    """
    ref_img_path = cfg.ref_img_path
    audio_path = cfg.audio_path
    source_image_pixels, \
        source_image_face_region, \
        source_image_face_emb, \
        source_image_full_mask, \
        source_image_face_mask, \
        source_image_lip_mask = image_processor.preprocess(
            ref_img_path, os.path.join(save_dir, '.cache'), cfg.face_expand_ratio)
    audio_emb, audio_length = audio_processor.preprocess(
        audio_path, clip_length)

    audio_emb = process_audio_emb(audio_emb)

    source_image_pixels = source_image_pixels.unsqueeze(0)
    source_image_face_region = source_image_face_region.unsqueeze(0)
    source_image_face_emb = source_image_face_emb.reshape(1, -1)
    source_image_face_emb = torch.tensor(source_image_face_emb)

    source_image_full_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_full_mask
    ]
    source_image_face_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_face_mask
    ]
    source_image_lip_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_lip_mask
    ]

    times = audio_emb.shape[0] // clip_length
    tensor_result = []
    generator = torch.manual_seed(42)
    for t in range(times):
        print(f"[{t + 1}/{times}]")

        if len(tensor_result) == 0:
            # The first clip has no previous frames, so use copies of the
            # reference image as motion context.
            motion_zeros = source_image_pixels.repeat(
                cfg.data.n_motion_frames, 1, 1, 1)
            motion_zeros = motion_zeros.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_zeros], dim=0)
        else:
            # Reuse the last n_motion_frames frames of the previous clip,
            # rescaled to [-1, 1], as motion context.
            motion_frames = tensor_result[-1][0]
            motion_frames = motion_frames.permute(1, 0, 2, 3)
            motion_frames = motion_frames[0 - cfg.data.n_motion_frames:]
            motion_frames = motion_frames * 2.0 - 1.0
            motion_frames = motion_frames.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_frames], dim=0)

        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)

        audio_tensor = audio_emb[
            t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
        ]
        audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.to(
            device=audioproj.device, dtype=audioproj.dtype)
        audio_tensor = audioproj(audio_tensor)

        pipeline_output = pipeline(
            ref_image=pixel_values_ref_img,
            audio_tensor=audio_tensor,
            face_emb=source_image_face_emb,
            face_mask=source_image_face_region,
            pixel_values_full_mask=source_image_full_mask,
            pixel_values_face_mask=source_image_face_mask,
            pixel_values_lip_mask=source_image_lip_mask,
            width=cfg.data.train_width,
            height=cfg.data.train_height,
            video_length=clip_length,
            num_inference_steps=cfg.inference_steps,
            guidance_scale=cfg.cfg_scale,
            generator=generator,
        )

        tensor_result.append(pipeline_output.videos)

    # Concatenate the per-clip videos along the frame dimension and trim to the
    # true audio length before writing the output video.
    tensor_result = torch.cat(tensor_result, dim=2)
    tensor_result = tensor_result.squeeze(0)
    tensor_result = tensor_result[:, :audio_length]
    output_file = cfg.output
    tensor_to_video(tensor_result, output_file, audio_path)
    return output_file
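

# Illustrative note (not part of the original script): a minimal sketch of the
# motion-frame carry-over used in the loop above, with synthetic tensor sizes
# that are assumptions chosen only for illustration.
def _example_motion_frame_carry_over() -> None:
    """Show how the last n_motion_frames frames of a clip become the next context."""
    n_motion_frames = 2
    # pipeline_output.videos is laid out as (batch, channels, frames, height, width).
    videos = torch.rand(1, 3, 14, 64, 64)
    motion_frames = videos[0]                           # (3, 14, 64, 64)
    motion_frames = motion_frames.permute(1, 0, 2, 3)   # (14, 3, 64, 64)
    motion_frames = motion_frames[-n_motion_frames:]    # (2, 3, 64, 64)
    motion_frames = motion_frames * 2.0 - 1.0           # rescale to [-1, 1]
    assert motion_frames.shape == (2, 3, 64, 64)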


def get_model(cfg: argparse.Namespace):
    """
    Build the models and the accelerator from the given configuration (cfg).

    Args:
        cfg (argparse.Namespace): The configuration containing the model, data,
            and solver parameters.

    Notes:
        - This function initializes the components originally used for training
          (accelerator, schedulers, optimizer, dataloader) but only returns what
          the UI needs for inference.
        - The model weights are loaded from cfg.audio_ckpt_dir.

    Returns:
        Tuple containing (accelerator, vae, net, val_noise_scheduler, cfg,
        validation_dir).
    """
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        mixed_precision=cfg.solver.mixed_precision,
        log_with="mlflow",
        project_dir="./mlruns",
        kwargs_handlers=[kwargs],
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if cfg.seed is not None:
        seed_everything(cfg.seed)

    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    validation_dir = save_dir
    if accelerator.is_main_process:
        init_output_dir([save_dir])

    accelerator.wait_for_everyone()

    if cfg.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif cfg.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    elif cfg.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(
            f"Unsupported weight dtype: {cfg.weight_dtype}"
        )

    # Fall back to fp32 when no GPU is available.
    if not torch.cuda.is_available():
        weight_dtype = torch.float32

    vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
        device=device, dtype=weight_dtype
    )
    reference_unet = UNet2DConditionModel.from_pretrained(
        cfg.base_model_path,
        subfolder="unet",
    ).to(device=device, dtype=weight_dtype)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        cfg.base_model_path,
        cfg.mm_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(
            cfg.unet_additional_kwargs),
        use_landmark=False
    ).to(device=device, dtype=weight_dtype)
    imageproj = ImageProjModel(
        cross_attention_dim=denoising_unet.config.cross_attention_dim,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    ).to(device=device, dtype=weight_dtype)
    face_locator = FaceLocator(
        conditioning_embedding_channels=320,
    ).to(device=device, dtype=weight_dtype)
    audioproj = AudioProjModel(
        seq_len=5,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
    ).to(device=device, dtype=weight_dtype)

    # Freeze everything except the audio projection model; selected submodules of
    # the denoising UNet are unfrozen below according to cfg.trainable_para.
    vae.requires_grad_(False)
    imageproj.requires_grad_(False)
    reference_unet.requires_grad_(False)
    denoising_unet.requires_grad_(False)
    face_locator.requires_grad_(False)
    audioproj.requires_grad_(True)

    trainable_modules = cfg.trainable_para
    for name, module in denoising_unet.named_modules():
        if any(trainable_mod in name for trainable_mod in trainable_modules):
            for params in module.parameters():
                params.requires_grad_(True)

    reference_control_writer = ReferenceAttentionControl(
        reference_unet,
        do_classifier_free_guidance=False,
        mode="write",
        fusion_blocks="full",
    )
    reference_control_reader = ReferenceAttentionControl(
        denoising_unet,
        do_classifier_free_guidance=False,
        mode="read",
        fusion_blocks="full",
    )

    net = Net(
        reference_unet,
        denoising_unet,
        face_locator,
        reference_control_writer,
        reference_control_reader,
        imageproj,
        audioproj,
    ).to(dtype=weight_dtype)

    missing, unexpected = net.load_state_dict(
        torch.load(
            cfg.audio_ckpt_dir,
            map_location="cpu",
        ),
    )
    assert len(missing) == 0 and len(unexpected) == 0, "Failed to load the checkpoint correctly."
    print("loaded weight from", cfg.audio_ckpt_dir)

    _, val_noise_scheduler = get_noise_scheduler(cfg)

    if cfg.solver.enable_xformers_memory_efficient_attention and torch.cuda.is_available():
        if is_xformers_available():
            reference_unet.enable_xformers_memory_efficient_attention()
            denoising_unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if cfg.solver.gradient_checkpointing:
        reference_unet.enable_gradient_checkpointing()
        denoising_unet.enable_gradient_checkpointing()

    if cfg.solver.scale_lr:
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    optimizer_cls = torch.optim.AdamW

    trainable_params = list(
        filter(lambda p: p.requires_grad, net.parameters()))

    optimizer = optimizer_cls(
        trainable_params,
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )

    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps
        * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps
        * cfg.solver.gradient_accumulation_steps,
    )

    train_dataset = TalkingVideoDataset(
        img_size=(cfg.data.train_width, cfg.data.train_height),
        sample_rate=cfg.data.sample_rate,
        n_sample_frames=cfg.data.n_sample_frames,
        n_motion_frames=cfg.data.n_motion_frames,
        audio_margin=cfg.data.audio_margin,
        data_meta_paths=cfg.data.train_meta_paths,
        wav2vec_cfg=cfg.wav2vec_config,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=16
    )

    (
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    )

    return accelerator, vae, net, val_noise_scheduler, cfg, validation_dir


def load_config(config_path: str) -> dict:
    """
    Load the configuration file.

    Args:
        config_path (str): Path to the configuration file.

    Returns:
        dict: The configuration dictionary.
    """
    if config_path.endswith(".yaml"):
        return OmegaConf.load(config_path)
    if config_path.endswith(".py"):
        return import_filename(config_path).cfg
    raise ValueError("Unsupported format for config file")


args = argparse.Namespace()
_config = load_config('configs/inference/inference.yaml')
for key, value in _config.items():
    setattr(args, key, value)

# Build the models, processors, and pipeline once at module load so that every
# Gradio request can reuse them.
accelerator, vae, net, val_noise_scheduler, cfg, validation_dir = get_model(args)
cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length = log_validation(
    accelerator=accelerator,
    vae=vae,
    net=net,
    scheduler=val_noise_scheduler,
    width=cfg.data.train_width,
    height=cfg.data.train_height,
    clip_length=cfg.data.n_sample_frames,
    cfg=cfg,
    save_dir=validation_dir,
    global_step=0,
    times=cfg.single_inference_times,
    face_analysis_model_path=cfg.face_analysis_model_path
)


def predict(image, audio, pose_weight, face_weight, lip_weight, face_expand_ratio, progress=gr.Progress(track_tqdm=True)):
    """
    Run inference for a single Gradio request.

    The uploaded image and audio paths plus the UI weights are written into the
    shared `cfg` object, and the path of the generated video is returned so that
    Gradio can display it.
    """
    _ = progress
    unique_id = uuid.uuid4()
    config = {
        'ref_img_path': image,
        'audio_path': audio,
        'pose_weight': pose_weight,
        'face_weight': face_weight,
        'lip_weight': lip_weight,
        'face_expand_ratio': face_expand_ratio,
        'config': 'configs/inference/inference.yaml',
        'checkpoint': None,
        'output': f'output-{unique_id}.mp4'
    }
    global cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length
    for key, value in config.items():
        setattr(cfg, key, value)

    return inference(cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length)
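

# Illustrative note (not part of the original script): a minimal sketch of how
# `predict` could be exposed through a Gradio interface, consistent with the
# module docstring's `python scripts/app.py` usage. Labels and default values
# below are assumptions for illustration, not the project's actual UI layout.
if __name__ == "__main__":
    _demo = gr.Interface(
        fn=predict,
        inputs=[
            gr.Image(type="filepath", label="Reference image"),
            gr.Audio(type="filepath", label="Driving audio"),
            gr.Number(value=1.0, label="pose_weight"),
            gr.Number(value=1.0, label="face_weight"),
            gr.Number(value=1.0, label="lip_weight"),
            gr.Number(value=1.2, label="face_expand_ratio"),
        ],
        outputs=gr.Video(label="Generated video"),
    )
    _demo.launch()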