# Training script: fine-tunes VMemModel with HuggingFace Accelerate (optional LoRA) on RealEstate10K.
import argparse
import multiprocessing
import os
import random
import time
from datetime import datetime

# Set multiprocessing start method to 'spawn' to avoid CUDA initialization
# issues in forked processes.
multiprocessing.set_start_method('spawn', force=True)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import set_seed
from omegaconf import OmegaConf
from peft import get_peft_model, LoraConfig
from torch.optim.lr_scheduler import SequentialLR, LambdaLR, CosineAnnealingLR, ExponentialLR
from torch.utils.data import DataLoader
from tqdm.auto import tqdm  # Progress bar

from data.dataset import RealEstatePoseImageSevaDataset
from modeling import VMemModel
from modeling.modules.autoencoder import AutoEncoder
from modeling.modules.conditioner import CLIPConditioner
from modeling.sampling import DDPMDiscretization, DiscreteDenoiser, create_samplers
from utils.training_utils import DiffusionTrainer, load_pretrained_model
# Seed every random number generator at import time for reproducibility.
_SEED = 42
torch.manual_seed(_SEED)
random.seed(_SEED)
np.random.seed(_SEED)
def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings; ``None`` (the default) falls
            back to ``sys.argv[1:]``, so existing callers are unaffected.
            Accepting an explicit list makes the parser unit-testable.

    Returns:
        argparse.Namespace with a ``config`` attribute — the path to the
        OmegaConf YAML config file (required).
    """
    parser = argparse.ArgumentParser(description='Train a model')
    # ``default`` is ignored on a required argument, so it is omitted here.
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')
    return parser.parse_args(argv)
def generate_current_datetime():
    """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    timestamp_format = "%Y-%m-%d_%H-%M-%S"
    now = datetime.now()
    return now.strftime(timestamp_format)
def prepare_model(unet, config):
    """Configure which parameters of ``unet`` are trainable.

    When ``config.training.lora_flag`` is set, the model is wrapped with LoRA
    adapters (peft) targeting transformer / embedding / layer modules;
    otherwise every parameter is unfrozen for full fine-tuning.

    Args:
        unet: The ``VMemModel`` instance to prepare.
        config: OmegaConf config; reads ``training.lora_flag`` and, for the
            LoRA path, ``training.lora_r``, ``training.lora_alpha`` and
            ``training.lora_dropout``.

    Returns:
        The prepared model (LoRA-wrapped when ``lora_flag`` is set, otherwise
        the same instance with all parameters trainable).
    """
    assert isinstance(unet, VMemModel), "unet should be an instance of VMemModel"
    if config.training.lora_flag:
        # Collect target module names (parameter names with the ".weight" /
        # ".bias" suffix stripped). Normalization layers and the first entry
        # of in/out layer stacks are deliberately excluded.
        target_modules = []
        seen = set()
        for name, _ in unet.named_parameters():
            if ("transformer" in name or "emb" in name or "layers" in name) \
                    and "norm" not in name and "in_layers.0" not in name and "out_layers.0" not in name:
                module_name = name.replace(".weight", "").replace(".bias", "")
                if module_name not in seen:
                    seen.add(module_name)
                    target_modules.append(module_name)
        lora_config = LoraConfig(
            r=config.training.lora_r,
            lora_alpha=config.training.lora_alpha,
            target_modules=target_modules,
            lora_dropout=config.training.lora_dropout,
        )
        unet = get_peft_model(unet, lora_config)
        unet.print_trainable_parameters()
    else:
        # Full fine-tuning: unfreeze everything and report the (trivially
        # 100%) trainable-parameter ratio for the log.
        for _, param in unet.named_parameters():
            param.requires_grad = True
        trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
        total = sum(p.numel() for p in unet.parameters())
        print("trainable parameters percentage: ", trainable / total)
    return unet
def main():
    """Entry point: load the config and run diffusion training.

    Loads the OmegaConf YAML pointed to by ``--config``, builds the model
    (optionally LoRA-wrapped), the RealEstate10K datasets, the optimizer and
    LR schedulers, and the sampling components, wraps them with HuggingFace
    Accelerate, then hands everything to ``DiffusionTrainer.train``.
    """
    args = parse_args()
    config = OmegaConf.load(args.config)

    # --- unpack training hyper-parameters -------------------------------
    num_epochs = config.training.num_epochs
    batch_size = config.training.batch_size
    learning_rate = config.training.learning_rate
    gradient_accumulation_steps = config.training.gradient_accumulation_steps
    num_workers = config.training.num_workers
    warmup_epochs = config.training.warmup_epochs
    max_grad_norm = config.training.max_grad_norm
    validation_interval = config.training.validation_interval
    visualization_flag = config.training.visualization_flag
    visualize_every = config.training.visualize_every
    random_seed = config.training.random_seed
    save_flag = config.training.save_flag
    use_wandb = config.training.use_wandb
    samples_dir = config.training.samples_dir
    weights_save_dir = config.training.weights_save_dir
    resume = config.training.resume  # falsy, or a checkpoint path to resume from

    # Per-run output directories keyed by the launch timestamp.
    exp_id = generate_current_datetime()
    if visualization_flag:
        run_visualization_dir = f"{samples_dir}/{exp_id}"
        os.makedirs(run_visualization_dir, exist_ok=True)
    else:
        run_visualization_dir = None
    if save_flag:
        run_weights_save_dir = f"{weights_save_dir}/{exp_id}"
        os.makedirs(run_weights_save_dir, exist_ok=True)
    else:
        run_weights_save_dir = None

    accelerator = Accelerator(
        mixed_precision="fp16",
        gradient_accumulation_steps=gradient_accumulation_steps,
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=False)],
    )
    if random_seed is not None:
        # device_specific=True derives a distinct seed per process.
        set_seed(random_seed, device_specific=True)
    device = accelerator.device

    model = load_pretrained_model(cache_dir=config.model.cache_dir, device=device)
    model = prepare_model(model, config)
    if resume:
        # NOTE(review): torch.load on an untrusted checkpoint can execute
        # arbitrary code via pickle; consider weights_only=True if supported.
        model.load_state_dict(torch.load(resume, map_location='cpu'), strict=False)
        torch.cuda.empty_cache()

    train_dataset = RealEstatePoseImageSevaDataset(
        rgb_data_dir=config.dataset.realestate10k.rgb_data_dir,
        meta_info_dir=config.dataset.realestate10k.meta_info_dir,
        num_sample_per_episode=config.dataset.realestate10k.num_sample_per_episode,
        mode='train')
    val_dataset = RealEstatePoseImageSevaDataset(
        rgb_data_dir=config.dataset.realestate10k.rgb_data_dir,
        meta_info_dir=config.dataset.realestate10k.meta_info_dir,
        num_sample_per_episode=config.dataset.realestate10k.val_num_sample_per_episode,
        mode='test')
    # 'spawn' worker context matches the start method set at import time.
    # NOTE(review): the validation loader is also shuffled — confirm intended.
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=num_workers, multiprocessing_context='spawn')
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True,
                                num_workers=num_workers, multiprocessing_context='spawn')

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate,
                                  weight_decay=config.training.weight_decay)

    # Cosine decay over the post-warmup steps, optionally preceded by a
    # linear warmup combined via SequentialLR.
    train_steps_per_epoch = len(train_dataloader)
    total_train_steps = num_epochs * train_steps_per_epoch
    warmup_steps = warmup_epochs * train_steps_per_epoch
    lr_scheduler = CosineAnnealingLR(
        optimizer, T_max=total_train_steps - warmup_steps, eta_min=0
    )
    if warmup_epochs > 0:
        def warmup_lambda(current_step):
            # Linear ramp from 0 to the base LR over the warmup steps.
            return float(current_step) / float(max(1, warmup_steps))
        warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda)
        lr_scheduler = SequentialLR(
            optimizer, schedulers=[warmup_scheduler, lr_scheduler], milestones=[warmup_steps]
        )

    # Frozen auxiliary modules used by the trainer (not optimized, and not
    # passed through accelerator.prepare except the VAE).
    vae = AutoEncoder(chunk_size=1).to(device)
    vae.eval()
    conditioner = CLIPConditioner().to(device)
    discretization = DDPMDiscretization()
    denoiser = DiscreteDenoiser(discretization=discretization, num_idx=1000, device=device)
    sampler = create_samplers(guider_types=config.training.guider_types,
                              discretization=discretization,
                              num_frames=config.model.num_frames,
                              num_steps=config.training.inference_num_steps,
                              cfg_min=config.training.cfg_min,
                              device=device)

    (model, vae, train_dataloader, val_dataloader,
     optimizer, lr_scheduler) = accelerator.prepare(
        model, vae, train_dataloader, val_dataloader, optimizer, lr_scheduler,
    )

    trainer = DiffusionTrainer(network=model,
                               ae=vae,
                               conditioner=conditioner,
                               denoiser=denoiser,
                               sampler=sampler,
                               discretization=discretization,
                               cfg=config.training.cfg,
                               optimizer=optimizer,
                               lr_scheduler=lr_scheduler,
                               ema_decay=config.training.ema_decay,
                               device=device,
                               accelerator=accelerator,
                               max_grad_norm=max_grad_norm,
                               save_flag=save_flag,
                               visualize_flag=visualization_flag)
    trainer.train(train_dataloader,
                  num_epochs,
                  unconditional_prob=config.training.uncond_prob,
                  log_every=10,
                  validation_dataloader=val_dataloader,
                  validation_interval=validation_interval,
                  save_dir=run_weights_save_dir,
                  save_interval=config.training.save_every,
                  visualize_every=visualize_every,
                  visualize_dir=run_visualization_dir,
                  use_wandb=use_wandb)
# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()