# AvatarArtist: DiT_VAE/train_vae.py (triplane VAE training script)
import argparse
import math
import os
import sys

# Add the bundled Next3d directory to the import path.
current_path = os.path.abspath(__file__)
father_path = os.path.abspath(os.path.dirname(current_path))
sys.path.append(os.path.join(father_path, 'Next3d'))
from typing import Dict, Optional, Tuple
from omegaconf import OmegaConf
import torch
import logging
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import inspect
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
import dnnlib
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
from vae.triplane_vae import AutoencoderKL, AutoencoderKLRollOut
from vae.data.dataset_online_vae import TriplaneDataset
from einops import rearrange
from vae.utils.common_utils import instantiate_from_config
from Next3d.training_avatar_texture.triplane_generation import TriPlaneGenerator
import Next3d.legacy as legacy
from torch_utils import misc
import datetime
logger = get_logger(__name__, log_level="INFO")
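

# Note: collate_fn below assembles a batch from TriplaneDataset items. It only
# relies on the keys 'data_model_name', 'data_z' and 'data_vert'; the exact
# tensor shapes are determined by the dataset implementation, not by this file.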
def collate_fn(data):
    model_names = [example["data_model_name"] for example in data]
    zs = torch.cat([example["data_z"] for example in data], dim=0)
    verts = torch.cat([example["data_vert"] for example in data], dim=0)
    return {
        'model_names': model_names,
        'zs': zs,
        'verts': verts
    }
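

# The two helpers below convert between the stacked triplane layout
# (B, C, 3, H, W) and a rolled-out layout in which the three planes sit side by
# side along the width axis, so a standard 2D VAE can treat them as one
# image-like tensor. unrollout_fn is intended as the inverse mapping and is
# only applied to the AutoencoderKLRollOut reconstructions in the rollout
# branch of the training loop.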
def rollout_fn(triplane):
    triplane = rearrange(triplane, "b c f h w -> b f c h w")
    b, f, c, h, w = triplane.shape
    triplane = triplane.permute(0, 2, 3, 1, 4).reshape(-1, c, h, f * w)
    return triplane


def unrollout_fn(triplane):
    res = triplane.shape[-2]
    ch = triplane.shape[1]
    triplane = triplane.reshape(-1, ch // 3, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, 3, ch, res, res)
    triplane = rearrange(triplane, "b f c h w -> b c f h w")
    return triplane
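

# triplane_generate runs a frozen generator: map z and the shared conditioning
# parameters to W space (with truncation), synthesize the triplane, and
# standardize it with the precomputed dataset mean/std so the VAE always sees
# normalized inputs.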
def triplane_generate(G_model, z, conditioning_params, std, mean, truncation_psi=0.7, truncation_cutoff=14):
    w = G_model.mapping(z, conditioning_params, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
    triplane = G_model.synthesis(w, noise_mode='const')
    triplane = (triplane - mean) / std
    return triplane
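

# gan_model builds a dictionary of frozen triplane generators, one per entry in
# the GAN-model config. Each network pickle provides a G_ema whose parameters
# and buffers are copied into a fresh TriPlaneGenerator (eval mode, gradients
# disabled), so the generators act purely as data sources during VAE training.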
def gan_model(gan_models, device, gan_model_base_dir):
    gan_model_dict = gan_models
    gan_model_load = {}
    for model_name in gan_model_dict.keys():
        model_pkl = os.path.join(gan_model_base_dir, model_name + '.pkl')
        with dnnlib.util.open_url(model_pkl) as f:
            G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
        G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device)
        misc.copy_params_and_buffers(G, G_new, require_all=True)
        G_new.neural_rendering_resolution = G.neural_rendering_resolution
        G_new.rendering_kwargs = G.rendering_kwargs
        gan_model_load[model_name] = G_new
    return gan_model_load


def main(vae_config: str,
         gan_model_config: str,
         output_dir: str,
         std_dir: str,
         mean_dir: str,
         conditioning_params_dir: str,
         gan_model_base_dir: str,
         train_data: Dict,
         train_batch_size: int = 2,
         max_train_steps: int = 500,
         learning_rate: float = 3e-5,
         scale_lr: bool = False,
         lr_scheduler: str = "constant",
         lr_warmup_steps: int = 0,
         adam_beta1: float = 0.5,
         adam_beta2: float = 0.9,
         adam_weight_decay: float = 1e-2,
         adam_epsilon: float = 1e-08,
         max_grad_norm: float = 1.0,
         gradient_accumulation_steps: int = 1,
         gradient_checkpointing: bool = True,
         checkpointing_steps: int = 500,
         pretrained_model_path_zero123: Optional[str] = None,
         resume_from_checkpoint: Optional[str] = None,
         mixed_precision: Optional[str] = "fp16",
         use_8bit_adam: bool = False,
         rollout: bool = False,
         enable_xformers_memory_efficient_attention: bool = True,
         seed: Optional[int] = None):
    *_, config = inspect.getargvalues(inspect.currentframe())
    base_dir = output_dir

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
    )
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    # If passed along, set the training seed now.
    if seed is not None:
        set_seed(seed)

    if accelerator.is_main_process:
        now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        output_dir = os.path.join(output_dir, now)
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/samples", exist_ok=True)
        os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
    config_vae = OmegaConf.load(vae_config)
    if rollout:
        vae = AutoencoderKLRollOut(ddconfig=config_vae['ddconfig'], lossconfig=config_vae['lossconfig'], embed_dim=8)
    else:
        vae = AutoencoderKL(ddconfig=config_vae['ddconfig'], lossconfig=config_vae['lossconfig'], embed_dim=8)
    print(f"VAE total params = {sum(p.numel() for p in vae.parameters())}")

    if 'perceptual_weight' in config_vae['lossconfig']['params'].keys():
        config_vae['lossconfig']['params']['device'] = str(accelerator.device)
    loss_fn = instantiate_from_config(config_vae['lossconfig'])

    conditioning_params = torch.load(conditioning_params_dir).to(accelerator.device)
    data_std = torch.load(std_dir).to(accelerator.device).reshape(1, -1, 1, 1, 1)
    data_mean = torch.load(mean_dir).to(accelerator.device).reshape(1, -1, 1, 1, 1)
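    # data_std / data_mean are precomputed triplane statistics reshaped to
    # (1, -1, 1, 1, 1) so they broadcast over the 5-D triplanes returned by the
    # frozen generators; their exact layout is defined by the saved tensors.
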
    # define the gan model
    print("########## gan model load ##########")
    config_gan_model = OmegaConf.load(gan_model_config)
    gan_model_all = gan_model(config_gan_model['gan_models'], str(accelerator.device), gan_model_base_dir)
    print("########## gan model loaded ##########")

    if scale_lr:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
        )
    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        vae.parameters(),
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )
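
    # Note: the default betas above are (adam_beta1=0.5, adam_beta2=0.9), which
    # differ from AdamW's usual (0.9, 0.999); they can be overridden through the
    # YAML config that is unpacked into main().
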
    train_dataset = TriplaneDataset(**train_data)

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size, collate_fn=collate_fn, shuffle=True, num_workers=2
    )

    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    vae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        vae, optimizer, train_dataloader, lr_scheduler
    )
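    # accelerator.prepare wraps the model, optimizer, dataloader and scheduler
    # for the current launch configuration (device placement, DDP, mixed
    # precision), so the training loop below stays device-agnostic.
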
    weight_dtype = torch.float32
    # Move the VAE to the target device and cast it to the mixed-precision dtype.
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    vae.to(accelerator.device, dtype=weight_dtype)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process (init_trackers is
    # a no-op unless the Accelerator was created with `log_with=...`).
    if accelerator.is_main_process:
        accelerator.init_trackers("trainvae", config=dict(config))
    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_train_steps}")
    global_step = 0
    first_epoch = 0
    # Potentially load in the weights and states from a previous save
    if resume_from_checkpoint:
        if resume_from_checkpoint != "latest":
            path = os.path.basename(resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1]
        accelerator.print(f"Resuming from checkpoint {path}")
        if resume_from_checkpoint != "latest":
            accelerator.load_state(resume_from_checkpoint)
        else:
            accelerator.load_state(os.path.join(output_dir, path))
        global_step = int(path.split("-")[1])
        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = global_step % num_update_steps_per_epoch
    else:
        # Otherwise, scan earlier timestamped runs under base_dir and resume
        # from the checkpoint with the highest global step, if any exist.
        all_final_training_dirs = []
        dirs = os.listdir(base_dir)
        if len(dirs) != 0:
            dirs = [d for d in dirs if d.startswith("2024")]  # run directories are timestamped (hard-coded year prefix)
            if len(dirs) != 0:
                base_resume_paths = [os.path.join(base_dir, d) for d in dirs]
                for base_resume_path in base_resume_paths:
                    checkpoint_file_names = os.listdir(base_resume_path)
                    checkpoint_file_names = [d for d in checkpoint_file_names if d.startswith("checkpoint")]
                    if len(checkpoint_file_names) != 0:
                        for checkpoint_file_name in checkpoint_file_names:
                            final_training_dir = os.path.join(base_resume_path, checkpoint_file_name)
                            all_final_training_dirs.append(final_training_dir)
                if len(all_final_training_dirs) != 0:
                    # Sort by the step encoded in each "checkpoint-<step>" directory name.
                    sorted_all_final_training_dirs = sorted(
                        all_final_training_dirs, key=lambda x: int(os.path.basename(x).split("-")[1])
                    )
                    latest_dir = sorted_all_final_training_dirs[-1]
                    path = os.path.basename(latest_dir)
                    accelerator.print(f"Resuming from checkpoint {path}")
                    accelerator.load_state(latest_dir)
                    global_step = int(path.split("-")[1])
                    first_epoch = global_step // num_update_steps_per_epoch
                    resume_step = global_step % num_update_steps_per_epoch
                else:
                    accelerator.print("Training from scratch")
            else:
                accelerator.print("Training from scratch")
        else:
            accelerator.print("Training from scratch")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    for epoch in range(first_epoch, num_train_epochs):
        vae.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # Optionally skip steps until the resumed step is reached (disabled):
            # if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            #     if step % gradient_accumulation_steps == 0:
            #         progress_bar.update(1)
            #     continue
            with accelerator.accumulate(vae):
                # Generate the target triplanes with the frozen GAN generators.
                z_values = batch["zs"].to(weight_dtype)
                model_names = batch["model_names"]
                triplane_values = []
                with torch.no_grad():
                    for z_id in range(z_values.shape[0]):
                        z_value = z_values[z_id].unsqueeze(0)
                        model_name = model_names[z_id]
                        triplane_value = triplane_generate(gan_model_all[model_name], z_value,
                                                           conditioning_params, data_std, data_mean)
                        triplane_values.append(triplane_value)
                triplane_values = torch.cat(triplane_values, dim=0)
                vert_values = batch["verts"].to(weight_dtype)
                triplane_values = rearrange(triplane_values, "b f c h w -> b c f h w")

                if rollout:
                    triplane_values_roll = rollout_fn(triplane_values.clone())
                    reconstructions, posterior = vae(triplane_values_roll)
                    reconstructions_unroll = unrollout_fn(reconstructions)
                    loss, log_dict_ae = loss_fn(triplane_values, reconstructions_unroll, posterior, vert_values,
                                                split="train")
                else:
                    reconstructions, posterior = vae(triplane_values)
                    loss, log_dict_ae = loss_fn(triplane_values, reconstructions, posterior, vert_values,
                                                split="train")

                # Gather the per-process losses so the logged train_loss reflects
                # the full batch, averaged over gradient-accumulation steps.
                avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
                train_loss += avg_loss.item() / gradient_accumulation_steps

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(vae.parameters(), max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            # logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            logs = log_dict_ae
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= max_train_steps:
                break

    accelerator.wait_for_everyone()
    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/triplane_vae.yaml")
    args = parser.parse_args()
    main(**OmegaConf.load(args.config))
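
# Example launch (paths are illustrative; point --config at your own YAML):
#   accelerate launch DiT_VAE/train_vae.py --config ./configs/triplane_vae.yaml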