File size: 3,991 Bytes
8c212a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# python3.7
"""Misc utility functions."""
import os
import sys
import subprocess
from importlib import import_module
import argparse
from easydict import EasyDict
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
# Names exported by `from <this module> import *`.
__all__ = [
    'init_dist',
    'bool_parser',
    'DictAction',
    'parse_config',
    'update_config',
]
def init_dist(launcher, backend='nccl', **kwargs):
    """Initializes distributed environment.

    Args:
        launcher: How the processes were launched. `'pytorch'` reads the rank
            from the `RANK` environment variable. `'slurm'` derives rank,
            world size and master address from the `SLURM_*` variables and
            exports them as `RANK`, `WORLD_SIZE`, `MASTER_ADDR` and
            `MASTER_PORT` (port taken from `PORT`, default 29500).
        backend: Backend passed to `torch.distributed.init_process_group()`.
            (default: `'nccl'`)
        **kwargs: Extra keyword arguments forwarded to
            `torch.distributed.init_process_group()` for every launcher.

    Raises:
        NotImplementedError: If `launcher` is not supported.
        subprocess.CalledProcessError: If `scontrol` exits with an error.
        RuntimeError: If `scontrol` prints no host name for the node list.
    """
    # Only choose a start method if none has been set yet in this process.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    if launcher == 'pytorch':
        rank = int(os.environ['RANK'])
        num_gpus = torch.cuda.device_count()
        # Maps the global rank onto a visible GPU; presumably assumes every
        # node exposes the same number of GPUs -- TODO confirm.
        torch.cuda.set_device(rank % num_gpus)
        dist.init_process_group(backend=backend, **kwargs)
    elif launcher == 'slurm':
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(proc_id % num_gpus)
        # Pass the node list as a single argv entry instead of through a
        # shell: compact lists like `node[01-04]` contain glob characters a
        # shell could expand against files in the working directory. Unlike
        # `getoutput`, a failing `scontrol` now raises instead of its error
        # text silently becoming MASTER_ADDR.
        hostnames = subprocess.check_output(
            ['scontrol', 'show', 'hostname', node_list],
            universal_newlines=True).splitlines()
        if not hostnames:
            raise RuntimeError(f'`scontrol` returned no host for node list '
                               f'`{node_list}`!')
        addr = hostnames[0].strip()  # The first node acts as master.
        port = os.environ.get('PORT', 29500)
        os.environ['MASTER_PORT'] = str(port)
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        # Forward `kwargs` here as well; they were previously dropped for
        # slurm, unlike the pytorch branch.
        dist.init_process_group(backend=backend, **kwargs)
    else:
        raise NotImplementedError(f'Not implemented launcher type: '
                                  f'`{launcher}`!')
def bool_parser(arg):
    """Parses an argument to boolean.

    Booleans are returned unchanged. Any other value is converted with
    `str()`, stripped of surrounding whitespace and compared
    case-insensitively against the accepted spellings.

    Args:
        arg: Value to parse, e.g. `'yes'`, `'False'`, `1` or `True`.

    Returns:
        `True` for `1/true/t/yes/y`, `False` for `0/false/f/no/n`.

    Raises:
        argparse.ArgumentTypeError: If `arg` matches neither set.
    """
    if isinstance(arg, bool):
        return arg
    # `str()` lets non-string inputs (e.g. the int `1`, or `None`) yield a
    # result or an `ArgumentTypeError` instead of crashing on `.lower()`.
    normalized = str(arg).strip().lower()
    if normalized in ('1', 'true', 't', 'yes', 'y'):
        return True
    if normalized in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'`{arg}` cannot be converted to boolean!')
class DictAction(argparse.Action):
    """Argparse action to split `KEY=VALUE` arguments into a dictionary.

    Each value is split on `,`. Every piece is converted to `int`, else
    `float`, else `bool` (for `true`/`false`, case-insensitive), else kept as
    a string. A single piece is stored as a scalar and several pieces as a
    list. For example, `--opts a=1 b=x,2.5` yields
    `{'a': 1, 'b': ['x', 2.5]}`.

    NOTE: This class is borrowed from
    https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
    """

    @staticmethod
    def _parse_int_float_bool(val):
        """Converts `val` to int, float or bool if possible, else returns it."""
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ['true', 'false']:
            return val.lower() == 'true'
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        """Parses the `KEY=VALUE` items and stores the dict on `namespace`.

        Only the first `=` separates key from value, so values may contain
        `=`.

        Raises:
            argparse.ArgumentError: If an item contains no `=`; argparse
                reports this as a usage error rather than a traceback.
        """
        # Without `nargs`, argparse hands over one string instead of a list;
        # iterating it directly would walk over its characters.
        if isinstance(values, str):
            values = [values]
        options = {}
        for kv in values:
            key, sep, val = kv.partition('=')
            if not sep:
                raise argparse.ArgumentError(
                    self, f'expected `KEY=VALUE`, but got `{kv}`!')
            parsed = [self._parse_int_float_bool(v) for v in val.split(',')]
            options[key] = parsed[0] if len(parsed) == 1 else parsed
        setattr(namespace, self.dest, options)
def parse_config(config_file):
    """Parses configuration from python file.

    The file is executed as a standalone module. Every global whose name does
    not start with `__` is copied into the returned `EasyDict`.

    The config's directory is prepended to `sys.path` while it executes, so
    it can import sibling modules. That entry is removed again even if the
    config raises.

    The module is loaded straight from `config_file` and is not registered in
    `sys.modules`. A config whose file name matches an already-imported module
    (e.g. `config.py`, `utils.py`) is therefore read from the given file
    rather than resolved to the cached module, and the cached module is left
    untouched.

    Args:
        config_file: Path to a `.py` configuration file.

    Returns:
        An `EasyDict` with the config's non-dunder globals.

    Raises:
        AssertionError: If `config_file` is not an existing `.py` file.
    """
    # Local import: only `import_module` is imported from `importlib` at file
    # level.
    import importlib.util  # pylint: disable=import-outside-toplevel

    assert os.path.isfile(config_file), (
        f'Config file `{config_file}` does not exist!')
    directory = os.path.dirname(config_file)
    filename = os.path.basename(config_file)
    module_name, extension = os.path.splitext(filename)
    assert extension == '.py', (
        f'Config file should be a `.py` file, but got `{config_file}`!')

    spec = importlib.util.spec_from_file_location(module_name, config_file)
    module = importlib.util.module_from_spec(spec)
    sys.path.insert(0, directory)
    try:
        spec.loader.exec_module(module)
    finally:
        # Undo the insertion above even if the config raised. Guarded in case
        # the config itself removed the entry.
        if directory in sys.path:
            sys.path.remove(directory)

    config = EasyDict()
    for key, value in vars(module).items():
        if not key.startswith('__'):
            config[key] = value
    return config
def update_config(config, new_config):
    """Updates configuration in a hierarchical level, in place.

    Every key of `new_config` is a `.`-separated path into `config`. For
    example, the pair `{'a.b.c.d': v}` performs

        config['a']['b']['c']['d'] = v

    Intermediate keys are looked up, not created, so a missing one raises
    `KeyError`. Only the final path segment is assigned.

    Args:
        config: Nested dict to update; mutated in place.
        new_config: Mapping from dotted key paths to new values, or `None`
            to leave `config` untouched.

    Returns:
        The same `config` object.
    """
    if new_config is None:
        return config
    assert isinstance(config, dict)
    assert isinstance(new_config, dict)

    for dotted_key, new_value in new_config.items():
        *parent_keys, leaf_key = dotted_key.split('.')
        node = config
        for parent_key in parent_keys:
            node = node[parent_key]
        node[leaf_key] = new_value
    return config
|