Spaces:
Sleeping
Sleeping
from typing import List, Tuple | |
from torch.distributed.checkpoint.metadata import ChunkStorageMetadata | |
__all__: List[str] = [] | |
def _check_shard_metadata_pair_overlap(
    shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
):
    """
    Report whether two shards share at least one common index.

    Each shard is treated as a box described by per-dimension ``offsets``
    and ``sizes``. The number of dimensions examined is taken from
    ``shard1.offsets``.

    Returns:
        ``True`` when the half-open ranges ``[offset, offset + size)`` of the
        two shards intersect along every dimension, ``False`` as soon as one
        dimension separates them.
    """
    rank = len(shard1.offsets)
    # Two boxes are disjoint as soon as, along any single dimension, one of
    # them begins at or beyond the point where the other one ends (in 2D:
    # one shard lies entirely below or entirely to the right of the other).
    return not any(
        shard1.offsets[axis] >= shard2.offsets[axis] + shard2.sizes[axis]
        or shard2.offsets[axis] >= shard1.offsets[axis] + shard1.sizes[axis]
        for axis in range(rank)
    )
def _shards_get_overlap_region_wrt_saved_tensor(
    saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
) -> List[Tuple[int, int, int, int]]:
    """
    Compute the region shared by ``saved_shard`` and ``current_shard``.

    The returned list holds one entry per dimension of the shards. Each
    entry is a tuple laid out as:

        (dimension, offset into `saved_shard`, offset into `current_shard`, length)

    Both offsets are expressed relative to the start of their own shard, not
    relative to the full tensor. No overlap check is performed here.
    """
    per_dim = zip(
        saved_shard.offsets,
        current_shard.offsets,
        saved_shard.sizes,
        current_shard.sizes,
    )

    regions = []
    for axis, (saved_lo, current_lo, saved_len, current_len) in enumerate(per_dim):
        # The shared range stops at whichever shard ends first and starts at
        # whichever shard begins last.
        overlap_hi = min(saved_lo + saved_len, current_lo + current_len)
        overlap_len = overlap_hi - max(current_lo, saved_lo)

        # The shard that begins later starts the overlap at its own origin
        # (relative offset 0); the other shard has to skip the gap between
        # the two starting points.
        if saved_lo > current_lo:
            saved_rel, current_rel = 0, saved_lo - current_lo
        else:
            saved_rel, current_rel = current_lo - saved_lo, 0

        regions.append((axis, saved_rel, current_rel, overlap_len))
    return regions