import argparse
import math
import os
import sys

current_path = os.path.abspath(__file__)
father_path = os.path.abspath(os.path.dirname(current_path) + os.path.sep + ".")
sys.path.append(os.path.join(father_path, 'Next3d'))

from typing import Dict, Optional, Tuple

from omegaconf import OmegaConf
import torch
import logging
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import inspect
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
import dnnlib
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
from vae.triplane_vae import AutoencoderKL, AutoencoderKLRollOut
from vae.data.dataset_online_vae import TriplaneDataset
from einops import rearrange
from vae.utils.common_utils import instantiate_from_config
from Next3d.training_avatar_texture.triplane_generation import TriPlaneGenerator
import Next3d.legacy as legacy
from torch_utils import misc
import datetime

logger = get_logger(__name__, log_level="INFO")

def collate_fn(data):
    model_names = [example["data_model_name"] for example in data]
    zs = torch.cat([example["data_z"] for example in data], dim=0)
    verts = torch.cat([example["data_vert"] for example in data], dim=0)
    return {
        'model_names': model_names,
        'zs': zs,
        'verts': verts
    }

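# rollout_fn/unrollout_fn convert between the stacked triplane layout (b, c, 3, h, w)
# and a single-image layout (b, c, h, 3*w) in which the three planes sit side by side
# along the width, so a 2D VAE can process them as one feature map. Shapes here assume
# the usual 3-plane convention; e.g. (B, 32, 3, 256, 256) rolls out to (B, 32, 256, 768)
# and unrolls back to (B, 32, 3, 256, 256).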
def rollout_fn(triplane):
    triplane = rearrange(triplane, "b c f h w -> b f c h w")
    b, f, c, h, w = triplane.shape
    triplane = triplane.permute(0, 2, 3, 1, 4).reshape(-1, c, h, f * w)
    return triplane

def unrollout_fn(triplane):
    # Inverse of rollout_fn: split the width back into the 3 planes and restore (b, c, 3, h, w).
    res = triplane.shape[-2]
    ch = triplane.shape[1]
    triplane = triplane.reshape(-1, ch, res, 3, res).permute(0, 3, 1, 2, 4).reshape(-1, 3, ch, res, res)
    triplane = rearrange(triplane, "b f c h w -> b c f h w")
    return triplane

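# triplane_generate runs a frozen Next3D generator: map z (plus the shared conditioning
# parameters) to w with truncation, synthesize the triplane features, then normalize
# them with the precomputed dataset mean/std so the VAE sees roughly unit-scale inputs.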
def triplane_generate(G_model, z, conditioning_params, std, mean, truncation_psi=0.7, truncation_cutoff=14):
    w = G_model.mapping(z, conditioning_params, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
    triplane = G_model.synthesis(w, noise_mode='const')
    triplane = (triplane - mean) / std
    return triplane

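# gan_model loads each <model_name>.pkl listed in the GAN config, copies its G_ema
# weights into a fresh TriPlaneGenerator, and freezes it. The returned dict maps model
# names to generators so each training sample can be rendered by its own GAN.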
def gan_model(gan_models, device, gan_model_base_dir):
    gan_model_dict = gan_models
    gan_model_load = {}
    for model_name in gan_model_dict.keys():
        model_pkl = os.path.join(gan_model_base_dir, model_name + '.pkl')
        with dnnlib.util.open_url(model_pkl) as f:
            G = legacy.load_network_pkl(f)['G_ema'].to(device)  # type: ignore
        G_new = TriPlaneGenerator(*G.init_args, **G.init_kwargs).eval().requires_grad_(False).to(device)
        misc.copy_params_and_buffers(G, G_new, require_all=True)
        G_new.neural_rendering_resolution = G.neural_rendering_resolution
        G_new.rendering_kwargs = G.rendering_kwargs
        gan_model_load[model_name] = G_new
    return gan_model_load

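# The GAN model config (loaded in main() from gan_model_config) is expected to contain a
# `gan_models` mapping whose keys are checkpoint basenames under gan_model_base_dir.
# A minimal sketch; the entry names are hypothetical, only the `gan_models` key and the
# <name>.pkl layout are assumed from the loader above:
#
#   gan_models:
#     next3d_ffhq_512: {}
#     next3d_custom_a: {}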
def main(vae_config: str,
         gan_model_config: str,
         output_dir: str,
         std_dir: str,
         mean_dir: str,
         conditioning_params_dir: str,
         gan_model_base_dir: str,
         train_data: Dict,
         train_batch_size: int = 2,
         max_train_steps: int = 500,
         learning_rate: float = 3e-5,
         scale_lr: bool = False,
         lr_scheduler: str = "constant",
         lr_warmup_steps: int = 0,
         adam_beta1: float = 0.5,
         adam_beta2: float = 0.9,
         adam_weight_decay: float = 1e-2,
         adam_epsilon: float = 1e-08,
         max_grad_norm: float = 1.0,
         gradient_accumulation_steps: int = 1,
         gradient_checkpointing: bool = True,
         checkpointing_steps: int = 500,
         pretrained_model_path_zero123: str = None,
         resume_from_checkpoint: Optional[str] = None,
         mixed_precision: Optional[str] = "fp16",
         use_8bit_adam: bool = False,
         rollout: bool = False,
         enable_xformers_memory_efficient_attention: bool = True,
         seed: Optional[int] = None):
    *_, config = inspect.getargvalues(inspect.currentframe())

    base_dir = output_dir
    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    # If passed along, set the training seed now.
    if seed is not None:
        set_seed(seed)

    if accelerator.is_main_process:
        now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        output_dir = os.path.join(output_dir, now)
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/samples", exist_ok=True)
        os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))

    config_vae = OmegaConf.load(vae_config)
    if rollout:
        vae = AutoencoderKLRollOut(ddconfig=config_vae['ddconfig'], lossconfig=config_vae['lossconfig'], embed_dim=8)
    else:
        vae = AutoencoderKL(ddconfig=config_vae['ddconfig'], lossconfig=config_vae['lossconfig'], embed_dim=8)
    print(f"VAE total params = {sum(p.numel() for p in vae.parameters())}")

    if 'perceptual_weight' in config_vae['lossconfig']['params'].keys():
        config_vae['lossconfig']['params']['device'] = str(accelerator.device)
    loss_fn = instantiate_from_config(config_vae['lossconfig'])

    conditioning_params = torch.load(conditioning_params_dir).to(str(accelerator.device))
    data_std = torch.load(std_dir).to(str(accelerator.device)).reshape(1, -1, 1, 1, 1)
    data_mean = torch.load(mean_dir).to(str(accelerator.device)).reshape(1, -1, 1, 1, 1)

    # Load the frozen GAN generators that produce the training triplanes online.
    print("########## gan model load ##########")
    config_gan_model = OmegaConf.load(gan_model_config)
    gan_model_all = gan_model(config_gan_model['gan_models'], str(accelerator.device), gan_model_base_dir)
    print("########## gan model loaded ##########")

    if scale_lr:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        vae.parameters(),
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    train_dataset = TriplaneDataset(**train_data)

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size, collate_fn=collate_fn, shuffle=True, num_workers=2
    )

    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    vae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        vae, optimizer, train_dataloader, lr_scheduler
    )

    # Move the VAE to the device and cast it to the mixed-precision weight dtype.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    vae.to(accelerator.device, dtype=weight_dtype)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        # `config` holds the function arguments captured at the top of main().
        accelerator.init_trackers("trainvae", config=config)

    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if resume_from_checkpoint:
        if resume_from_checkpoint != "latest":
            path = os.path.basename(resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1]
        accelerator.print(f"Resuming from checkpoint {path}")
        if resume_from_checkpoint != "latest":
            accelerator.load_state(resume_from_checkpoint)
        else:
            accelerator.load_state(os.path.join(output_dir, path))
        global_step = int(path.split("-")[1])
        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = global_step % num_update_steps_per_epoch
    else:
        # No explicit checkpoint was given: scan previous timestamped runs under
        # base_dir and resume from the newest checkpoint found, if any.
        all_final_training_dirs = []
        dirs = os.listdir(base_dir)
        if len(dirs) != 0:
            dirs = [d for d in dirs if d.startswith("2024")]  # timestamped run directories (specific year)
            if len(dirs) != 0:
                base_resume_paths = [os.path.join(base_dir, d) for d in dirs]
                for base_resume_path in base_resume_paths:
                    checkpoint_file_names = os.listdir(base_resume_path)
                    checkpoint_file_names = [d for d in checkpoint_file_names if d.startswith("checkpoint")]
                    if len(checkpoint_file_names) != 0:
                        for checkpoint_file_name in checkpoint_file_names:
                            final_training_dir = os.path.join(base_resume_path, checkpoint_file_name)
                            all_final_training_dirs.append(final_training_dir)
                if len(all_final_training_dirs) != 0:
                    # Sort by the global step encoded in the "checkpoint-<step>" basename.
                    sorted_all_final_training_dirs = sorted(
                        all_final_training_dirs, key=lambda x: int(os.path.basename(x).split("-")[1])
                    )
                    latest_dir = sorted_all_final_training_dirs[-1]
                    path = os.path.basename(latest_dir)
                    accelerator.print(f"Resuming from checkpoint {path}")
                    accelerator.load_state(latest_dir)
                    global_step = int(path.split("-")[1])
                    first_epoch = global_step // num_update_steps_per_epoch
                    resume_step = global_step % num_update_steps_per_epoch
                else:
                    accelerator.print("Training from start")
            else:
                accelerator.print("Training from start")
        else:
            accelerator.print("Training from start")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

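    # Each step: take the batch's latent codes, render target triplanes with the matching
    # frozen GAN generator, optionally roll the three planes out into one wide feature map,
    # reconstruct with the VAE, and optimize the VAE on the loss returned by loss_fn
    # (which also receives the mesh vertices from the batch).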
    for epoch in range(first_epoch, num_train_epochs):
        vae.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # Skipping the first `resume_step` batches after a resume is left disabled:
            # if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            #     if step % gradient_accumulation_steps == 0:
            #         progress_bar.update(1)
            #     continue
            with accelerator.accumulate(vae):
                # Generate the target triplanes online from the frozen GAN generators.
                z_values = batch["zs"].to(weight_dtype)
                model_names = batch["model_names"]
                triplane_values = []
                with torch.no_grad():
                    for z_id in range(z_values.shape[0]):
                        z_value = z_values[z_id].unsqueeze(0)
                        model_name = model_names[z_id]
                        triplane_value = triplane_generate(gan_model_all[model_name], z_value,
                                                           conditioning_params, data_std, data_mean)
                        triplane_values.append(triplane_value)
                triplane_values = torch.cat(triplane_values, dim=0)
                vert_values = batch["verts"].to(weight_dtype)
                triplane_values = rearrange(triplane_values, "b f c h w -> b c f h w")

                if rollout:
                    triplane_values_roll = rollout_fn(triplane_values.clone())
                    reconstructions, posterior = vae(triplane_values_roll)
                    reconstructions_unroll = unrollout_fn(reconstructions)
                    loss, log_dict_ae = loss_fn(triplane_values, reconstructions_unroll, posterior, vert_values,
                                                split="train")
                else:
                    reconstructions, posterior = vae(triplane_values)
                    loss, log_dict_ae = loss_fn(triplane_values, reconstructions, posterior, vert_values,
                                                split="train")

                # Accumulate the loss that is reported once per optimization step.
                train_loss += loss.detach().item() / gradient_accumulation_steps

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(vae.parameters(), max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            # logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            logs = log_dict_ae
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= max_train_steps:
                break

    accelerator.wait_for_everyone()
    accelerator.end_training()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/triplane_vae.yaml")
    args = parser.parse_args()
    main(**OmegaConf.load(args.config))
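# Typical multi-GPU launch (the script filename below is a placeholder for this file):
#   accelerate launch train_triplane_vae.py --config ./configs/triplane_vae.yaml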