# audio2photoreal/train/train_diffusion.py
# lybxin: "Upload folder using huggingface_hub" (commit 66b7c56, verified)
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import json
import os
import torch
import torch.multiprocessing as mp
from data_loaders.get_data import get_dataset_loader, load_local_data
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from train.train_platforms import ClearmlPlatform, NoPlatform, TensorboardPlatform
from train.training_loop import TrainLoop
from utils.diff_parser_utils import train_args
from utils.misc import cleanup, fixseed, setup_dist
from utils.model_util import create_model_and_diffusion
def main(rank: int, world_size: int):
    """Run diffusion training in one process.

    Parses the training arguments, creates the reporting platform, sets up
    the distributed backend, stores the arguments as ``args.json`` (rank 0
    only), builds the data loader, the model and the diffusion object, and
    runs ``TrainLoop`` until it returns.

    Args:
        rank: Index of this process. It is also the device the model is moved
            to and the device id handed to ``DistributedDataParallel``.
        world_size: Total number of training processes. When greater than 1
            the model is wrapped in ``DistributedDataParallel``.

    Raises:
        ValueError: If ``args.train_platform_type`` is not the name of one of
            the platform classes imported by this module.
        FileNotFoundError: On rank 0, if ``args.save_dir`` is ``None``.
        FileExistsError: On rank 0, if ``args.save_dir`` already exists and
            ``args.overwrite`` is falsy.
    """
    args = train_args()
    fixseed(args.seed)

    # Resolve the platform class through an explicit table rather than eval():
    # the name comes from the command line and eval() would execute whatever
    # expression was passed there. Only classes imported here could ever have
    # resolved successfully, so the accepted names are unchanged.
    platform_classes = {
        "ClearmlPlatform": ClearmlPlatform,
        "NoPlatform": NoPlatform,
        "TensorboardPlatform": TensorboardPlatform,
    }
    try:
        train_platform_type = platform_classes[args.train_platform_type]
    except KeyError:
        raise ValueError(
            "Unknown train_platform_type [{}]; expected one of {}.".format(
                args.train_platform_type, sorted(platform_classes)
            )
        ) from None
    train_platform = train_platform_type(args.save_dir)
    train_platform.report_args(args, name="Args")
    setup_dist(args.device)

    # Only rank 0 validates the output directory and writes the arguments, so
    # multi-process runs do not write the same file several times.
    if rank == 0:
        if args.save_dir is None:
            raise FileNotFoundError("save_dir was not specified.")
        if os.path.exists(args.save_dir) and not args.overwrite:
            raise FileExistsError("save_dir [{}] already exists.".format(args.save_dir))
        # exist_ok tolerates the directory already being there (overwrite
        # mode) or appearing between the check above and this call.
        os.makedirs(args.save_dir, exist_ok=True)
        args_path = os.path.join(args.save_dir, "args.json")
        with open(args_path, "w") as fw:
            json.dump(vars(args), fw, indent=4, sort_keys=True)

    # If the configured data root is missing, retry with "/home/" swapped for
    # "/derived/" in the path.
    if not os.path.exists(args.data_root):
        args.data_root = args.data_root.replace("/home/", "/derived/")

    data_dict = load_local_data(args.data_root, audio_per_frame=1600)
    print("creating data loader...")
    data = get_dataset_loader(args=args, data_dict=data_dict)

    print("creating logger...")
    writer = SummaryWriter(args.save_dir)

    print("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(args, split_type="train")
    model.to(rank)

    if world_size > 1:
        model = DDP(
            model, device_ids=[rank], output_device=rank, find_unused_parameters=True
        )

    # DDP hides the wrapped network behind ``.module``; unwrap it to reach the
    # model's own ``parameters_w_grad`` helper.
    core_model = model.module if world_size > 1 else model
    params = core_model.parameters_w_grad()
    print("Total params: %.2fM" % (sum(p.numel() for p in params) / 1000000.0))

    print("Training...")
    TrainLoop(
        args, train_platform, model, diffusion, data, writer, rank, world_size
    ).run_loop()
    train_platform.close()
    cleanup()
if __name__ == "__main__":
    # Spawn one training process per CUDA device reported by torch; with at
    # most one device, train directly in the current process as rank 0.
    world_size = torch.cuda.device_count()
    print(f"using {world_size} gpus")
    if world_size <= 1:
        main(rank=0, world_size=1)
    else:
        mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)