| """ | |
| utility helpers for distributed checks | |
| """ | |
| from contextlib import contextmanager | |
| import torch.distributed as dist | |
| from accelerate import Accelerator | |
| accelerate = None # pylint: disable=invalid-name | |


def load_accelerate():
    """Instantiate the module-level Accelerator."""
    global accelerate  # pylint: disable=global-statement
    accelerate = Accelerator()


def is_distributed():
    """
    Check if distributed training is initialized.
    """
    global accelerate  # pylint: disable=global-statement
    if not accelerate:
        accelerate = Accelerator()
    return dist.is_available() and dist.is_initialized()


def barrier():
    """
    Acts as a barrier to wait for all processes. This ensures that all processes
    reach the barrier before proceeding further.
    """
    if is_distributed():
        dist.barrier()


def is_main_process():
    """
    Check if the current process is the main process.
    If not in distributed mode, always return True.
    """
    if not is_distributed():
        return True
    return dist.get_rank() == 0


@contextmanager
def zero_first(is_main):
    """
    Runs the wrapped context so that rank 0 runs first before the other ranks.
    """
    if not is_main:  # other ranks wait first
        barrier()
    yield
    if is_main:  # then rank 0 waits after it has run the context
        barrier()
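

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): in a training
    # script, rank 0 typically prepares a shared resource (for example, a
    # tokenized dataset cache) inside `zero_first` before the other ranks
    # read it, e.g.
    #
    #     with zero_first(is_main_process()):
    #         dataset = prepare_dataset(...)  # `prepare_dataset` is hypothetical
    #
    # Run as a plain script (no distributed launch), the barriers are skipped
    # and the block simply executes once.
    with zero_first(is_main_process()):
        print("is_main_process:", is_main_process())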