Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import subprocess | |
def setup_for_distributed(is_master):
    """
    Disable printing when not in the master process.

    Replaces ``builtins.print`` with a wrapper that forwards to the real
    ``print`` only when ``is_master`` is true or the caller passes
    ``force=True``.  The ``force`` keyword is consumed by the wrapper and is
    never passed on to the real ``print``.

    The function is safe to call more than once: every call wraps the
    original ``print`` rather than a wrapper installed by a previous call,
    so the most recent ``is_master`` value always wins and ``force=True``
    keeps working.

    Args:
        is_master: whether the current process is allowed to print.
    """
    import builtins as __builtin__

    # If an earlier call already installed a wrapper, recover the real print
    # it remembered; otherwise wrappers would nest and a suppressed inner
    # wrapper would swallow everything (including ``force=True`` output).
    current_print = __builtin__.print
    builtin_print = getattr(current_print, '_builtin_print', current_print)

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    # Remember the real print on the wrapper so later calls can unwrap it.
    print._builtin_print = builtin_print
    __builtin__.print = print
def init_distributed_mode(args):
    """
    Initialise ``torch.distributed`` from the launcher's environment.

    Three cases are handled, in this order:

    1. ``RANK`` and ``WORLD_SIZE`` are set (e.g. by ``torchrun``): rank,
       world size and local GPU index are read from ``RANK``,
       ``WORLD_SIZE`` and ``LOCAL_RANK``.
    2. ``SLURM_PROCID`` is set: rank and world size come from
       ``SLURM_PROCID`` / ``SLURM_NTASKS``, the master address is the first
       host of ``SLURM_NODELIST`` (resolved with ``scontrol``), and the
       ``MASTER_ADDR``, ``MASTER_PORT``, ``WORLD_SIZE``, ``RANK``,
       ``LOCAL_RANK`` and ``LOCAL_SIZE`` environment variables are exported.
    3. Otherwise: ``args.distributed`` is set to ``False`` and the function
       returns without touching ``torch.distributed``.

    In the distributed cases the function sets ``args.rank``,
    ``args.world_size``, ``args.gpu``, ``args.dist_url``,
    ``args.dist_backend`` and ``args.distributed``, selects the CUDA device,
    initialises the NCCL process group, waits on a barrier and finally
    silences ``print`` on every rank except rank 0.

    Args:
        args: mutable namespace-like object that receives the attributes
            listed above.

    Raises:
        RuntimeError: under SLURM, if no CUDA device is visible or
            ``scontrol`` returns no hostname.
        subprocess.CalledProcessError: under SLURM, if ``scontrol`` exits
            with a non-zero status.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
    elif 'SLURM_PROCID' in os.environ:
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        if num_gpus == 0:
            # Fail with a readable message instead of a ZeroDivisionError
            # from the ``proc_id % num_gpus`` computations below.
            raise RuntimeError(
                'SLURM distributed mode requires at least one visible CUDA '
                'device, but torch.cuda.device_count() returned 0')

        # Pass the node list as a single argv entry: compressed lists such as
        # "node[01-04]" contain shell glob characters and must not be
        # interpolated into a shell command line.
        result = subprocess.run(
            ['scontrol', 'show', 'hostname', node_list],
            stdout=subprocess.PIPE, universal_newlines=True, check=True)
        hostnames = result.stdout.split()
        if not hostnames:
            raise RuntimeError(
                'scontrol returned no hostname for SLURM_NODELIST={!r}'.format(
                    node_list))
        # The first expanded host acts as the rendezvous master.
        addr = hostnames[0]

        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = ntasks
        args.rank = proc_id
        args.gpu = proc_id % num_gpus
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    # From here on only rank 0 prints (unless print(..., force=True) is used).
    setup_for_distributed(args.rank == 0)