Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| from typing import List | |
| from torch import distributed | |
def barrier() -> None:
    """Synchronize all processes of the default process group.

    Calls ``torch.distributed.barrier()`` when a process group is active.
    It is a no-op when this PyTorch build lacks distributed support or no
    process group has been initialized (a plain single-process run).
    """
    # ``is_available()`` must be checked first: builds without distributed
    # support expose no other ``torch.distributed`` APIs, so calling
    # ``is_initialized()`` there would raise AttributeError.
    if distributed.is_available() and distributed.is_initialized():
        distributed.barrier()
def broadcast(data, src) -> None:
    """Broadcast ``data`` in place from rank ``src`` to every process.

    Args:
        data: Tensor that is sent from rank ``src`` and overwritten in place
            on all other ranks.
        src: Rank of the sending process.

    When distributed support is unavailable or no process group is
    initialized, this is a no-op and ``data`` is left untouched.
    """
    # Check availability before initialization: without distributed support
    # ``torch.distributed`` exposes no other APIs.
    if distributed.is_available() and distributed.is_initialized():
        distributed.broadcast(data, src)
def all_gather(data: List, src) -> None:
    """Gather the tensor ``src`` from every process into the list ``data``.

    Args:
        data: Output list. It is presumably pre-sized to the world size with
            tensors shaped like ``src``, as ``torch.distributed.all_gather``
            requires. Verify against callers.
        src: This process's input tensor. Despite the name, this is the
            tensor to contribute, not a source rank.

    Single-process fallback (distributed unavailable or not initialized):
    ``data[0]`` is rebound to ``src`` itself, not a copy, so it aliases the
    caller's tensor. ``data`` must be non-empty, otherwise IndexError.
    """
    # Check availability before initialization: without distributed support
    # ``torch.distributed`` exposes no other APIs.
    if distributed.is_available() and distributed.is_initialized():
        distributed.all_gather(data, src)
    else:
        data[0] = src
def get_rank() -> int:
    """Return this process's rank in the default process group.

    Returns 0 when distributed support is unavailable or no process group
    is initialized, so single-process code behaves like rank 0.
    """
    # Check availability before initialization: without distributed support
    # ``torch.distributed`` exposes no other APIs.
    active = distributed.is_available() and distributed.is_initialized()
    return distributed.get_rank() if active else 0
def get_world_size() -> int:
    """Return the number of processes in the default process group.

    Returns 1 when distributed support is unavailable or no process group
    is initialized, i.e. a plain single-process run.
    """
    # Check availability before initialization: without distributed support
    # ``torch.distributed`` exposes no other APIs.
    active = distributed.is_available() and distributed.is_initialized()
    return distributed.get_world_size() if active else 1
def chunk_size(size, rank, world_size):
    """Return how many of ``size`` items the process ``rank`` should handle.

    Items are split as evenly as possible over ``world_size`` processes. The
    leftover ``size % world_size`` items go one each to the lowest-numbered
    ranks. For non-negative ``size`` and ``0 <= rank < world_size``, the
    per-rank counts therefore sum to ``size``.
    """
    base = size // world_size
    remainder = size % world_size
    # Ranks below the remainder each absorb one of the leftover items.
    return base + 1 if rank < remainder else base
