from typing import Optional

import torch.distributed as dist


def get_rank(group: Optional[dist.ProcessGroup] = None):
    # Rank of this process, or 0 if torch.distributed is not initialized.
    return dist.get_rank(group) if dist.is_initialized() else 0


def get_world_size(group: Optional[dist.ProcessGroup] = None):
    # Number of processes, or 1 if torch.distributed is not initialized.
    return dist.get_world_size(group) if dist.is_initialized() else 1


def barrier(group: Optional[dist.ProcessGroup] = None):
    # Synchronize all processes; a no-op if torch.distributed is not initialized.
    if dist.is_initialized():
        dist.barrier(group)


class rank_gate:
    '''
    Execute the function on rank 0 first, followed by all other ranks.
    Useful when caches may need to be populated in a distributed environment.

    Can be used either as a decorator (wrapping a function) or as a context manager.
    '''

    def __init__(self, func=None):
        self.func = func

    def __call__(self, *args, **kwargs):
        rank = get_rank()
        if rank == 0:
            # Rank 0 runs first so that shared state (e.g. a cache) is populated
            # before any other rank attempts the same work.
            result = self.func(*args, **kwargs)
        barrier()
        if rank > 0:
            result = self.func(*args, **kwargs)
        return result

    def __enter__(self, *args, **kwargs):
        # Non-zero ranks wait here; rank 0 proceeds into the block immediately.
        if get_rank() > 0:
            barrier()

    def __exit__(self, *args, **kwargs):
        # Rank 0 reaches the barrier on exit, releasing the waiting ranks.
        if get_rank() == 0:
            barrier()
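
# Minimal usage sketch of both forms of `rank_gate`. `download_to_cache` is a
# hypothetical placeholder for any cache-populating call; the ordering guarantee
# (rank 0 first, then the rest) is what the class provides. In a single-process,
# non-distributed run every barrier is a no-op and the code simply executes in order.
if __name__ == "__main__":
    @rank_gate
    def download_to_cache(path: str) -> str:
        # Rank 0 populates the cache first; ranks > 0 re-run this after the barrier
        # and find the cache already populated.
        print(f"[rank {get_rank()}] populating cache at {path}")
        return path

    download_to_cache("/tmp/example-cache")

    # Context-manager form: ranks > 0 wait in __enter__ until rank 0 leaves the block.
    with rank_gate():
        print(f"[rank {get_rank()}] inside the gated block")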