import importlib
import torch
import torch.distributed as dist
from .avg_meter import AverageMeter
from collections import defaultdict, OrderedDict
import os
import socket
from mmcv.utils import collect_env as collect_base_env
try:
    from mmcv.utils import get_git_hash
except ImportError:
    from mmengine.utils import get_git_hash
# import mono.mmseg as mmseg
# import mmseg
import time
import datetime
import logging


def main_process() -> bool:
    return get_rank() == 0
    # return not cfg.distributed or \
    #     (cfg.distributed and cfg.local_rank == 0)


def get_world_size() -> int:
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def _find_free_port():
    # refer to https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(('', 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def _is_free_port(port):
    ips = socket.gethostbyname_ex(socket.gethostname())[-1]
    ips.append('localhost')
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return all(s.connect_ex((ip, port)) != 0 for ip in ips)
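
# Usage sketch (illustrative, not part of the original module): pick the
# preferred port if it is free, otherwise ask the OS for one.
#
#   port = 16500 if _is_free_port(16500) else _find_free_port()
#   os.environ['MASTER_PORT'] = str(port)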


# def collect_env():
#     """Collect the information of the running environments."""
#     env_info = collect_base_env()
#     env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
#     return env_info


def init_env(launcher, cfg):
    """Initialize the distributed training environment.

    If ``cfg.dist_params.dist_url`` is specified as 'env://', the master port
    is read from the ``MASTER_PORT`` environment variable. If ``MASTER_PORT``
    is not set, port ``16500`` is used when available; otherwise a free port
    is chosen automatically.
    """
    if launcher == 'slurm':
        _init_dist_slurm(cfg)
    elif launcher == 'ror':
        _init_dist_ror(cfg)
    elif launcher == 'None':
        _init_none_dist(cfg)
    else:
        raise RuntimeError(f'launcher {launcher} is not supported!')
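
# Usage sketch (illustrative; assumes a config object exposing a
# `dist_params` namespace, as used throughout this module):
#
#   init_env('slurm', cfg)   # populates cfg.dist_params and MASTER_* env vars
#   if main_process():
#       print('world size:', cfg.dist_params.world_size)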


def _init_none_dist(cfg):
    cfg.dist_params.num_gpus_per_node = 1
    cfg.dist_params.world_size = 1
    cfg.dist_params.nnodes = 1
    cfg.dist_params.node_rank = 0
    cfg.dist_params.global_rank = 0
    cfg.dist_params.local_rank = 0
    os.environ["WORLD_SIZE"] = str(1)


def _init_dist_ror(cfg):
    from ac2.ror.comm import (get_local_rank, get_world_rank, get_local_size,
                              get_node_rank, get_world_size)
    cfg.dist_params.num_gpus_per_node = get_local_size()
    cfg.dist_params.world_size = get_world_size()
    cfg.dist_params.nnodes = get_world_size() // get_local_size()
    cfg.dist_params.node_rank = get_node_rank()
    cfg.dist_params.global_rank = get_world_rank()
    cfg.dist_params.local_rank = get_local_rank()
    os.environ["WORLD_SIZE"] = str(get_world_size())


def _init_dist_slurm(cfg):
    if 'NNODES' not in os.environ:
        os.environ['NNODES'] = str(cfg.dist_params.nnodes)
    if 'NODE_RANK' not in os.environ:
        os.environ['NODE_RANK'] = str(cfg.dist_params.node_rank)

    num_gpus = torch.cuda.device_count()
    world_size = int(os.environ['NNODES']) * num_gpus
    os.environ['WORLD_SIZE'] = str(world_size)

    # config port
    if 'MASTER_PORT' in os.environ:
        master_port = str(os.environ['MASTER_PORT'])  # use MASTER_PORT from the environment
    else:
        # if port 16500 is available, use it; otherwise find a free port
        if _is_free_port(16500):
            master_port = '16500'
        else:
            master_port = str(_find_free_port())
        os.environ['MASTER_PORT'] = master_port

    # config addr
    if 'MASTER_ADDR' in os.environ:
        master_addr = str(os.environ['MASTER_ADDR'])  # use MASTER_ADDR from the environment
    # elif cfg.dist_params.dist_url is not None:
    #     master_addr = ':'.join(str(cfg.dist_params.dist_url).split(':')[:2])
    else:
        master_addr = '127.0.0.1'  # 'tcp://127.0.0.1'
    os.environ['MASTER_ADDR'] = master_addr

    # set dist_url to 'env://'
    cfg.dist_params.dist_url = 'env://'  # f"{master_addr}:{master_port}"

    cfg.dist_params.num_gpus_per_node = num_gpus
    cfg.dist_params.world_size = world_size
    cfg.dist_params.nnodes = int(os.environ['NNODES'])
    cfg.dist_params.node_rank = int(os.environ['NODE_RANK'])

    # if int(os.environ['NNODES']) > 1 and cfg.dist_params.dist_url.startswith("file://"):
    #     raise Warning("file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://")
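
# Usage sketch (illustrative, not part of the original module): once
# `_init_dist_slurm(cfg)` has set dist_url to 'env://' and exported the
# MASTER_* variables, the process group can be created with the standard
# torch.distributed entry point:
#
#   dist.init_process_group(backend='nccl',
#                           init_method=cfg.dist_params.dist_url,
#                           world_size=cfg.dist_params.world_size,
#                           rank=cfg.dist_params.global_rank)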


def get_func(func_name):
    """
    Helper to return a function object by name. func_name must identify
    a function in this module or the full dotted path to a function in an
    importable module.
    @ func_name: function name.
    """
    if func_name == '':
        return None
    try:
        parts = func_name.split('.')
        # Refers to a function in this module
        if len(parts) == 1:
            return globals()[parts[0]]
        # Otherwise, assume a full dotted module path
        module_name = '.'.join(parts[:-1])
        module = importlib.import_module(module_name)
        return getattr(module, parts[-1])
    except Exception as e:
        raise RuntimeError(f'Failed to find function: {func_name}') from e
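
# Usage sketch (illustrative, not part of the original module):
#
#   join = get_func('os.path.join')   # resolved via importlib
#   assert join('a', 'b') == os.path.join('a', 'b')
#   rank_fn = get_func('get_rank')    # resolved from this module's globals()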


class Timer(object):
    """A simple timer."""

    def __init__(self):
        self.reset()

    def tic(self):
        # using time.time instead of time.clock because time.clock
        # does not normalize for multithreading
        self.start_time = time.time()

    def toc(self, average=True):
        self.diff = time.time() - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
        if average:
            return self.average_time
        else:
            return self.diff

    def reset(self):
        self.total_time = 0.
        self.calls = 0
        self.start_time = 0.
        self.diff = 0.
        self.average_time = 0.
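
# Usage sketch (illustrative, not part of the original module):
#
#   timer = Timer()
#   for _ in range(10):
#       timer.tic()
#       ...                              # timed work goes here
#       last = timer.toc(average=False)  # duration of this interval only
#   print(timer.average_time)            # mean over the 10 tic/toc calls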


class TrainingStats(object):
    """Track vital training statistics."""

    def __init__(self, log_period, tensorboard_logger=None):
        self.log_period = log_period
        self.tblogger = tensorboard_logger
        self.tb_ignored_keys = ['iter', 'eta', 'epoch', 'time']
        self.iter_timer = Timer()
        # Window size for smoothing tracked values (with median filtering)
        self.filter_size = log_period

        def create_smoothed_value():
            return AverageMeter()

        self.smoothed_losses = defaultdict(create_smoothed_value)
        # self.smoothed_metrics = defaultdict(create_smoothed_value)
        # self.smoothed_total_loss = AverageMeter()

    def IterTic(self):
        self.iter_timer.tic()

    def IterToc(self):
        return self.iter_timer.toc(average=False)

    def reset_iter_time(self):
        self.iter_timer.reset()

    def update_iter_stats(self, losses_dict):
        """Update tracked iteration statistics."""
        for k, v in losses_dict.items():
            self.smoothed_losses[k].update(float(v), 1)

    def log_iter_stats(self, cur_iter, optimizer, max_iters, val_err={}):
        """Log the tracked statistics."""
        if cur_iter % self.log_period == 0:
            stats = self.get_stats(cur_iter, optimizer, max_iters, val_err)
            log_stats(stats)
            if self.tblogger:
                self.tb_log_stats(stats, cur_iter)
            for k, v in self.smoothed_losses.items():
                v.reset()

    def tb_log_stats(self, stats, cur_iter):
        """Log the tracked statistics to tensorboard."""
        for k in stats:
            # ignore some logs
            if k not in self.tb_ignored_keys:
                v = stats[k]
                if isinstance(v, dict):
                    self.tb_log_stats(v, cur_iter)
                else:
                    self.tblogger.add_scalar(k, v, cur_iter)

    def get_stats(self, cur_iter, optimizer, max_iters, val_err={}):
        eta_seconds = self.iter_timer.average_time * (max_iters - cur_iter)
        eta = str(datetime.timedelta(seconds=int(eta_seconds)))
        stats = OrderedDict(
            iter=cur_iter,  # 1-indexed
            time=self.iter_timer.average_time,
            eta=eta,
        )
        optimizer_state_dict = optimizer.state_dict()
        lr = {}
        for i, param_group in enumerate(optimizer_state_dict['param_groups']):
            lr['group%d_lr' % i] = param_group['lr']
        stats['lr'] = OrderedDict(lr)
        for k, v in self.smoothed_losses.items():
            stats[k] = v.avg
        stats['val_err'] = OrderedDict(val_err)
        stats['max_iters'] = max_iters
        return stats
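
# Usage sketch (illustrative; `optimizer` and the loss values are hypothetical,
# and the loss dict follows this module's conventions, including the
# 'total_loss' entry that log_stats expects):
#
#   stats = TrainingStats(log_period=50, tensorboard_logger=None)
#   for it in range(1, max_iters + 1):
#       stats.IterTic()
#       losses = {'total_loss': ..., 'silog_loss': ...}  # from the model
#       stats.update_iter_stats(losses)
#       stats.IterToc()
#       stats.log_iter_stats(it, optimizer, max_iters)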


def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that the process
    with rank 0 has the reduced results.
    Args:
        @input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        @average (bool): whether to do average or sum
    Returns:
        a dict with the same keys as input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only the main process gets the accumulated values, so only
            # divide by world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
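
# Usage sketch (illustrative, not part of the original module; assumes an
# initialized process group and scalar CUDA tensors, per the docstring):
#
#   losses = {'total_loss': loss.detach(), 'silog_loss': silog.detach()}
#   reduced = reduce_dict(losses)   # rank 0 holds the mean across ranks
#   if main_process():
#       print({k: v.item() for k, v in reduced.items()})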


def log_stats(stats):
    """Log training statistics to the terminal."""
    logger = logging.getLogger()
    lines = "[Step %d/%d]\n" % (stats['iter'], stats['max_iters'])
    lines += "\t\tloss: %.3f, time: %.6f, eta: %s\n" % (
        stats['total_loss'], stats['time'], stats['eta'])
    # log individual losses
    lines += "\t\t"
    lines += ", ".join("%s: %.3f" % (k, v) for k, v in stats.items()
                       if 'loss' in k.lower() and 'total_loss' not in k.lower())
    lines += '\n'
    # validation criteria
    lines += "\t\tlast val err: " + ", ".join(
        "%s: %.6f" % (k, v) for k, v in stats['val_err'].items())
    lines += '\n'
    # lr in different groups
    lines += "\t\t" + ", ".join("%s: %.8f" % (k, v) for k, v in stats['lr'].items())
    lines += '\n'
    logger.info(lines[:-1])  # remove the trailing newline