# NOTE(review): removed non-Python artifact lines ("Spaces:", "Runtime error")
# that would break this module at import time.
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
import math | |
import os | |
import random | |
import re | |
from datetime import timedelta | |
from typing import Optional | |
import hydra | |
import numpy as np | |
import omegaconf | |
import torch | |
import torch.distributed as dist | |
from iopath.common.file_io import g_pathmgr | |
from omegaconf import OmegaConf | |
def multiply_all(*args):
    """Return the product of all positional arguments as a Python scalar."""
    product = np.array(args).prod()
    return product.item()
def collect_dict_keys(config):
    """Recursively walk a dataset configuration and collect every ``dict_key``
    declared on a collate-function node.

    A node counts as a collate-function spec when its ``_target_`` matches
    ``.*collate_fn.*``; otherwise all nested configs of the same type
    (including those inside omegaconf list configs) are searched.
    """
    found = []
    # A node pointing at a collate function carries exactly one dict_key.
    if "_target_" in config and re.match(r".*collate_fn.*", config["_target_"]):
        found.append(config["dict_key"])
        return found
    # Otherwise descend into every nested config.
    for value in config.values():
        if isinstance(value, type(config)):
            found += collect_dict_keys(value)
        elif isinstance(value, omegaconf.listconfig.ListConfig):
            found += [
                key
                for item in value
                if isinstance(item, type(config))
                for key in collect_dict_keys(item)
            ]
    return found
class Phase:
    """Namespace of string constants naming the run phases (train/val)."""
    TRAIN = "train"
    VAL = "val"
def register_omegaconf_resolvers():
    """Install the custom OmegaConf interpolation resolvers used by the
    project's configs (arithmetic helpers, hydra lookups, merge, etc.)."""
    resolvers = {
        "get_method": hydra.utils.get_method,
        "get_class": hydra.utils.get_class,
        "add": lambda x, y: x + y,
        "times": multiply_all,
        "divide": lambda x, y: x / y,
        "pow": lambda x, y: x**y,
        "subtract": lambda x, y: x - y,
        "range": lambda x: list(range(x)),
        "int": lambda x: int(x),
        "ceil_int": lambda x: int(math.ceil(x)),
        "merge": lambda *x: OmegaConf.merge(*x),
    }
    for name, resolver in resolvers.items():
        OmegaConf.register_new_resolver(name, resolver)
def setup_distributed_backend(backend, timeout_mins):
    """
    Initialize torch.distributed and return this process's global rank.

    Expects the standard environment variables to be set as per
    https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization
    along with "LOCAL_RANK", which is used to set the CUDA device.
    """
    # Async error handling makes NCCL collectives fail loudly after
    # `timeout_mins` of waiting instead of hanging forever.
    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    timeout = timedelta(minutes=timeout_mins)
    logging.info(f"Setting up torch.distributed with a timeout of {timeout_mins} mins")
    dist.init_process_group(backend=backend, timeout=timeout)
    return dist.get_rank()
def get_machine_local_and_dist_rank():
    """
    Get the local (per-machine) and distributed (global) rank of the
    current GPU process from the LOCAL_RANK / RANK environment variables.

    Returns:
        tuple[int, int]: (local_rank, distributed_rank).

    Raises:
        AssertionError: if either environment variable is unset.
    """
    # Fetch the raw strings first: converting with int() before the None
    # check would raise TypeError and mask the helpful assertion message.
    local_rank = os.environ.get("LOCAL_RANK", None)
    distributed_rank = os.environ.get("RANK", None)
    assert (
        local_rank is not None and distributed_rank is not None
    ), "Please set the RANK and LOCAL_RANK environment variables."
    return int(local_rank), int(distributed_rank)
def print_cfg(cfg):
    """
    Log the resolved config as YAML under an INFO banner.
    Supports printing both Hydra DictConfig and also the AttrDict config
    (anything OmegaConf.to_yaml accepts).
    """
    logging.info("Training with config:")
    logging.info(OmegaConf.to_yaml(cfg))
def set_seeds(seed_value, max_epochs, dist_rank):
    """
    Seed python's `random`, numpy and torch (plus every CUDA device when
    available) so training is deterministic per GPU.
    """
    # Offset by rank and scale by max_epochs: the pytorch sampler increments
    # the seed by 1 every epoch, so this keeps per-rank streams disjoint.
    machine_seed = (seed_value + dist_rank) * max_epochs
    logging.info(f"MACHINE SEED: {machine_seed}")
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(machine_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(machine_seed)
def makedir(dir_path):
    """
    Create the directory (and parents) if it does not exist.

    Returns:
        bool: True on success (including when the directory already
        exists), False when creation failed.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except Exception:
        # Narrowed from BaseException so KeyboardInterrupt/SystemExit still
        # propagate; logging.exception records the traceback instead of
        # silently dropping the cause. Best-effort: caller inspects the
        # False return value.
        logging.exception(f"Error creating directory: {dir_path}")
    return is_success
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available AND a process
    group has already been initialized."""
    return dist.is_available() and dist.is_initialized()
def get_amp_type(amp_type: Optional[str] = None):
    """Map an AMP config string to its torch dtype; None passes through.

    Raises:
        AssertionError: if amp_type is neither "bfloat16" nor "float16".
    """
    if amp_type is None:
        return None
    assert amp_type in ["bfloat16", "float16"], "Invalid Amp type."
    return torch.bfloat16 if amp_type == "bfloat16" else torch.float16
def log_env_variables():
    """Dump every environment variable, sorted by name, to the INFO log."""
    st = "".join(f"{k}={os.environ[k]}\n" for k in sorted(os.environ.keys()))
    logging.info("Logging ENV_VARIABLES")
    logging.info(st)
class AverageMeter:
    """Tracks the latest value and running average of a metric."""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.device = device
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val, self.avg = 0, 0
        self.sum, self.count = 0, 0
        self._allow_updates = True

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "{name}: {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(**self.__dict__)
class MemMeter:
    """Tracks per-iteration peak GPU memory usage (whole GB), its running
    average, and the lifetime peak."""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.device = device
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val = 0  # Per iteration max usage
        self.avg = 0  # Avg per iteration max usage
        self.peak = 0  # Peak usage for lifetime of program
        self.sum = 0
        self.count = 0
        self._allow_updates = True

    def update(self, n=1, reset_peak_usage=True):
        """Sample torch's peak-allocated counter and fold it into the stats."""
        # // 1e9 floors to whole decimal GB (not GiB).
        self.val = torch.cuda.max_memory_allocated() // 1e9
        self.sum += self.val * n
        self.count += n
        self.avg = self.sum / self.count
        self.peak = max(self.peak, self.val)
        if reset_peak_usage:
            # Restart the counter so the next update sees only the coming
            # iteration's peak.
            torch.cuda.reset_peak_memory_stats()

    def __str__(self):
        template = (
            "{name}: {val"
            + self.fmt
            + "} ({avg"
            + self.fmt
            + "}/{peak"
            + self.fmt
            + "})"
        )
        return template.format(**self.__dict__)
def human_readable_time(time_seconds):
    """Format a duration in seconds as 'DDd HHh MMm' (seconds truncated)."""
    total_minutes, _ = divmod(int(time_seconds), 60)
    total_hours, minutes = divmod(total_minutes, 60)
    days, hours = divmod(total_hours, 24)
    return f"{days:02}d {hours:02}h {minutes:02}m"
class DurationMeter:
    """Accumulates a duration in seconds and renders it human-readably."""

    def __init__(self, name, device, fmt=":f"):
        self.name = name
        self.device = device
        self.fmt = fmt
        self.val = 0

    def reset(self):
        """Set the stored duration back to zero."""
        self.val = 0

    def update(self, val):
        """Overwrite the stored duration with `val`."""
        self.val = val

    def add(self, val):
        """Increase the stored duration by `val`."""
        self.val += val

    def __str__(self):
        return f"{self.name}: {human_readable_time(self.val)}"
class ProgressMeter:
    """Builds and logs one progress line per batch from a batch counter,
    a list of plain meters, and a dict of computed ("real") meters."""

    def __init__(self, num_batches, meters, real_meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.real_meters = real_meters
        self.prefix = prefix

    def display(self, batch, enable_print=False):
        """Log (and optionally print) the progress line for `batch`."""
        pieces = [self.prefix + self.batch_fmtstr.format(batch)]
        pieces.extend(str(meter) for meter in self.meters)
        for name, meter in self.real_meters.items():
            computed = meter.compute()
            pieces.append(
                " | ".join(
                    f"{os.path.join(name, subname)}: {val:.4f}"
                    for subname, val in computed.items()
                )
            )
        line = " | ".join(pieces)
        logging.info(line)
        if enable_print:
            print(line)

    def _get_batch_fmtstr(self, num_batches):
        # Right-align the batch index to the width of the total count,
        # e.g. num_batches=100 -> "[{:3d}/100]".
        width = len(str(num_batches // 1))
        fmt = "{:" + str(width) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"
def get_resume_checkpoint(checkpoint_save_dir):
    """Return the path of `checkpoint.pt` inside `checkpoint_save_dir`,
    or None when the directory or the checkpoint file does not exist."""
    if not g_pathmgr.isdir(checkpoint_save_dir):
        return None
    candidate = os.path.join(checkpoint_save_dir, "checkpoint.pt")
    return candidate if g_pathmgr.isfile(candidate) else None