""" | |
Copyright (c) Meta Platforms, Inc. and affiliates. | |
All rights reserved. | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import json
import os

import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from data_loaders.get_data import get_dataset_loader, load_local_data
from train.train_platforms import ClearmlPlatform, NoPlatform, TensorboardPlatform
from train.training_loop import TrainLoop
from utils.diff_parser_utils import train_args
from utils.misc import cleanup, fixseed, setup_dist
from utils.model_util import create_model_and_diffusion
def main(rank: int, world_size: int) -> None:
    """Run the training loop on one worker process.

    Args:
        rank: Process index for this worker; also used as the CUDA device id
            (0 when running single-GPU / single-process).
        world_size: Total number of spawned worker processes. When > 1 the
            model is wrapped in DistributedDataParallel.

    Raises:
        FileNotFoundError: if ``--save_dir`` was not given (rank 0 only).
        FileExistsError: if ``save_dir`` exists and ``--overwrite`` is unset
            (rank 0 only).
        ValueError: if ``--train_platform_type`` is not a known platform name.
    """
    args = train_args()
    fixseed(args.seed)
    # Resolve the platform class from an explicit allow-list instead of
    # eval() on a CLI-supplied string (eval would execute arbitrary code).
    platform_classes = {
        "ClearmlPlatform": ClearmlPlatform,
        "NoPlatform": NoPlatform,
        "TensorboardPlatform": TensorboardPlatform,
    }
    try:
        train_platform_type = platform_classes[args.train_platform_type]
    except KeyError:
        raise ValueError(
            "unknown train_platform_type: {!r}; expected one of {}".format(
                args.train_platform_type, sorted(platform_classes)
            )
        ) from None
    train_platform = train_platform_type(args.save_dir)
    train_platform.report_args(args, name="Args")
    setup_dist(args.device)
    # Only rank 0 validates/creates the output directory and dumps the args,
    # so concurrent workers do not race on the filesystem.
    if rank == 0:
        if args.save_dir is None:
            raise FileNotFoundError("save_dir was not specified.")
        elif os.path.exists(args.save_dir) and not args.overwrite:
            raise FileExistsError("save_dir [{}] already exists.".format(args.save_dir))
        elif not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
        args_path = os.path.join(args.save_dir, "args.json")
        with open(args_path, "w") as fw:
            json.dump(vars(args), fw, indent=4, sort_keys=True)
    # Fall back to the mirrored data location when the primary path is absent.
    if not os.path.exists(args.data_root):
        args.data_root = args.data_root.replace("/home/", "/derived/")
    data_dict = load_local_data(args.data_root, audio_per_frame=1600)
    print("creating data loader...")
    data = get_dataset_loader(args=args, data_dict=data_dict)
    print("creating logger...")
    writer = SummaryWriter(args.save_dir)
    print("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(args, split_type="train")
    model.to(rank)
    if world_size > 1:
        # find_unused_parameters=True: not all parameters necessarily receive
        # gradients every step, so DDP must tolerate unused ones.
        model = DDP(
            model, device_ids=[rank], output_device=rank, find_unused_parameters=True
        )
    # DDP wraps the model, so trainable params live on model.module.
    params = (
        model.module.parameters_w_grad()
        if world_size > 1
        else model.parameters_w_grad()
    )
    print("Total params: %.2fM" % (sum(p.numel() for p in params) / 1000000.0))
    print("Training...")
    TrainLoop(
        args, train_platform, model, diffusion, data, writer, rank, world_size
    ).run_loop()
    train_platform.close()
    cleanup()
if __name__ == "__main__": | |
world_size = torch.cuda.device_count() | |
print(f"using {world_size} gpus") | |
if world_size > 1: | |
mp.spawn(main, args=(world_size,), nprocs=world_size, join=True) | |
else: | |
main(rank=0, world_size=1) | |