Spaces:
Runtime error
Runtime error
import os | |
import argparse | |
from shutil import copyfile | |
import torch.distributed as dist | |
import torch | |
import importlib | |
import datetime | |
from utils.dist_utils import ( | |
get_world_size, | |
) | |
from omegaconf import OmegaConf | |
from utils.utils import seed_all | |
# Command-line interface. Parsing happens at import time because
# main_worker() reads the module-level ``args`` directly.
parser = argparse.ArgumentParser(description='VFI')
# Path to the experiment configuration file loaded by OmegaConf.
parser.add_argument(
    '-c', '--config',
    type=str,
)
# NOTE(review): --port is not read anywhere in the visible code; presumably
# consumed by the launcher or another module -- confirm.
parser.add_argument(
    '-p', '--port',
    type=str,
    default='23455',
)
# Kept as a string here; the entry point converts it with int() and uses it
# as the CUDA device index.
parser.add_argument(
    '--local_rank',
    default='0',
)
args = parser.parse_args()
def main_worker(rank, config):
    """Prepare one process for training and run the configured trainer.

    Fills in rank and device entries on ``config``, initialises the NCCL
    process group when CUDA is available and ``config['distributed']`` is
    truthy, creates the experiment directory (only on rank 0, or in a
    non-distributed run), then imports ``trainers.<trainer_type>`` and calls
    ``Trainer(config).train()``.

    The path of the configuration file is read from the module-level
    ``args.config``.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        config: Mapping-like configuration, mutated in place. Must provide
            ``distributed``, ``exp_name`` and ``save_dir``; ``trainer_type``
            is optional and defaults to ``'base_trainer'``.
    """
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank

    if torch.cuda.is_available():
        print(f'Rank {rank} is available')
        config['device'] = f"cuda:{rank}"
        if config['distributed']:
            # 5400 s = 90 min timeout for collective operations.
            dist.init_process_group(
                backend='nccl',
                timeout=datetime.timedelta(seconds=5400))
    else:
        config['device'] = 'cpu'

    # Resolve the config file name once with os.path.basename so that the
    # experiment prefix and the copied file agree, and so that paths using
    # the platform's native separator are handled (a plain split('/') is
    # not).
    config_filename = os.path.basename(args.config)
    # The experiment prefix is the file name up to its first '.'.
    cfg_name = config_filename.split('.')[0]
    config['exp_name'] = cfg_name + '_' + config['exp_name']
    config['save_dir'] = os.path.join(config['save_dir'], config['exp_name'])

    # Only one process touches the filesystem in a distributed run.
    if (not config['distributed']) or rank == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        os.makedirs(os.path.join(config['save_dir'], 'ckpts'), exist_ok=True)
        config_path = os.path.join(config['save_dir'], config_filename)
        # Keep an existing snapshot of the config rather than overwriting it.
        if not os.path.isfile(config_path):
            copyfile(args.config, config_path)
        print('[**] create folder {}'.format(config['save_dir']))

    trainer_name = config.get('trainer_type', 'base_trainer')
    print(f'using GPU {rank} for training')
    if rank == 0:
        print(trainer_name)
    # The trainer implementation is selected by name from the ``trainers``
    # package and must expose a ``Trainer`` class.
    trainer_pack = importlib.import_module('trainers.' + trainer_name)
    trainer = trainer_pack.Trainer(config)

    trainer.train()
if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    cfg = OmegaConf.load(args.config)
    seed_all(cfg.seed)

    rank = int(args.local_rank)
    # Bind a CUDA device only when one exists. main_worker falls back to the
    # CPU when CUDA is unavailable, but an unconditional set_device would
    # raise before that fallback is ever reached.
    if torch.cuda.is_available():
        torch.cuda.set_device(torch.device(f'cuda:{rank}'))

    # setting distributed configurations
    cfg['world_size'] = get_world_size()
    cfg['local_rank'] = rank
    if rank == 0:
        print('world_size: ', cfg['world_size'])
    main_worker(rank, cfg)