from typing import Any, cast, List

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
from torch.utils._pytree import tree_map_only

from .metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from .planner import (
    LoadItemType,
    ReadItem,
    SavePlan,
    TensorWriteData,
    WriteItem,
    WriteItemType,
)
from .resharding import (
    _check_shard_metadata_pair_overlap,
    _shards_get_overlap_region_wrt_saved_tensor,
)

__all__: List[str] = ["create_read_items_for_chunk_list"]


def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()
    )


def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size(shard_md.shard_offsets),
        sizes=torch.Size(shard_md.shard_sizes),
    )


def _sharded_tensor_metadata(
    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
    shard_properties = sharded_tensor.metadata().tensor_properties
    properties = TensorProperties(
        dtype=shard_properties.dtype,
        layout=shard_properties.layout,
        requires_grad=shard_properties.requires_grad,
        memory_format=shard_properties.memory_format,
        pin_memory=shard_properties.pin_memory,
    )
    return TensorWriteData(
        chunk=_chunk_for_shard(shard_md),
        properties=properties,
        size=sharded_tensor.metadata().size,
    )


def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
    sizes, offsets = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    sizes, offsets = torch.Size(sizes), torch.Size(offsets)

    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(
                offsets=offsets,
                sizes=sizes,
            ),
            properties=TensorProperties.create_from_tensor(tensor.to_local()),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_shard(
    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
    offsets = torch.Size(shard_md.shard_offsets)
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
    )


def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
    offsets = torch.Size([0] * len(tensor.size()))
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.TENSOR,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
            properties=TensorProperties.create_from_tensor(tensor),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_bytesio(fqn: str, bytes: Any) -> WriteItem:
    return WriteItem(
        index=MetadataIndex(fqn),
        type=WriteItemType.BYTE_IO,
    )


def _create_read_item_for_byteio(
    dest_index: MetadataIndex,
    dest_offset: int,
    storage_index: MetadataIndex,
    storage_offset: int,
    length: int,
) -> ReadItem:
    return ReadItem(
        type=LoadItemType.BYTE_IO,
        dest_index=dest_index,
        dest_offsets=torch.Size((dest_offset,)),
        storage_index=storage_index,
        storage_offsets=torch.Size((storage_offset,)),
        lengths=torch.Size((length,)),
    )


def _create_read_item_for_tensor(
    dest_index: MetadataIndex,
    dest_offsets: List[int],
    storage_index: MetadataIndex,
    storage_offsets: List[int],
    lengths: List[int],
) -> ReadItem:
    return ReadItem(
        type=LoadItemType.TENSOR,
        dest_index=dest_index,
        dest_offsets=torch.Size(dest_offsets),
        storage_index=storage_index,
        storage_offsets=torch.Size(storage_offsets),
        lengths=torch.Size(lengths),
    )


def create_read_items_for_chunk_list(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_chunks: List[ChunkStorageMetadata],
) -> List[ReadItem]:
    """
    Create a list of ``ReadItem`` based on the checkpoint and local chunks.

    This applies the resharding algorithm and computes the reads needed
    to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.

    Args:
        fqn (str): The state_dict FQN to pass to ``ReadItem``.
        checkpoint_md (TensorStorageMetadata): Metadata for a given tensor
            from a checkpoint.
        local_chunks (List[ChunkStorageMetadata]): Local chunks that need to be
            loaded.

    Returns:
        A list of ``ReadItem`` that will satisfy all input chunks.
    """
    read_items = []
    # This is a naive quadratic algorithm that can be optimized later.
    for idx, shard in enumerate(local_chunks):
        for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
            if not _check_shard_metadata_pair_overlap(shard, storage_md):
                continue

            storage_offsets = []
            dest_offsets = []
            lengths = []
            for (
                dim,
                offset_for_saved_tensor,
                offset_for_current_tensor,
                length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=storage_md, current_shard=shard
            ):
                storage_offsets.append(offset_for_saved_tensor)
                dest_offsets.append(offset_for_current_tensor)
                lengths.append(length)

            read_items.append(
                _create_read_item_for_tensor(
                    dest_index=MetadataIndex(fqn, shard.offsets, idx),
                    dest_offsets=dest_offsets,
                    storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx),
                    storage_offsets=storage_offsets,
                    lengths=lengths,
                )
            )
    return read_items


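# Illustrative sketch (not part of the module API): build the read plan for a
# hypothetical 1-D tensor of 8 elements that was saved as two chunks of 4 and
# is being loaded into a single local chunk covering elements [2, 6). The FQN
# "weight" and the chunk layout are assumptions made for this example only.
def _example_read_plan() -> List[ReadItem]:
    checkpoint_md = TensorStorageMetadata(
        properties=TensorProperties(dtype=torch.float32),
        size=torch.Size([8]),
        chunks=[
            ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([4])),
            ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([4])),
        ],
    )
    local_chunks = [
        ChunkStorageMetadata(offsets=torch.Size([2]), sizes=torch.Size([4]))
    ]
    # Two ReadItems result: elements [2, 4) come from the first saved chunk
    # and elements [4, 6) from the second.
    return create_read_items_for_chunk_list("weight", checkpoint_md, local_chunks)

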
def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
    requests = []
    for fqn, obj in state_dict.items():
        if isinstance(obj, DTensor):
            requests.append(_create_write_items_for_dtensor(fqn, obj))
        elif isinstance(obj, ShardedTensor):
            for shard_md in obj.metadata().shards_metadata:
                requests.append(_create_write_item_for_shard(fqn, obj, shard_md))
        elif isinstance(obj, torch.Tensor):
            requests.append(_create_write_item_for_tensor(fqn, obj))
        else:
            requests.append(_create_write_item_for_bytesio(fqn, obj))
    return SavePlan(requests)


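# Illustrative sketch (not part of the module API): a metadata-only save plan
# for a state dict with one tensor and one non-tensor object contains one
# write request per entry; the non-tensor value falls through to the byte-IO
# path. The keys "weight" and "step" are hypothetical.
def _example_metadata_only_plan() -> SavePlan:
    plan = _create_default_metadata_only_plan(
        {"weight": torch.zeros(2, 2), "step": 10}
    )
    assert len(plan.items) == 2
    return plan

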
def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
    if isinstance(object, DTensor):
        return [_create_write_items_for_dtensor(fqn, object)]
    elif isinstance(object, ShardedTensor):
        return [
            _create_write_item_for_shard(fqn, object, shard.metadata)
            for shard in object.local_shards()
        ]
    elif isinstance(object, torch.Tensor):
        return [_create_write_item_for_tensor(fqn, object)]
    else:
        return [_create_write_item_for_bytesio(fqn, object)]


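# Illustrative sketch (not part of the module API): _create_write_items is the
# per-object dispatch used when building a save plan. A plain tensor yields a
# single WriteItem spanning the whole tensor, while a ShardedTensor yields one
# WriteItem per local shard. The FQN "layer.weight" is hypothetical.
def _example_write_items() -> List[WriteItem]:
    items = _create_write_items("layer.weight", torch.zeros(4, 4))
    assert len(items) == 1
    assert items[0].type == WriteItemType.TENSOR
    return items

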
def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
    sizes, offsets = compute_local_shape_and_global_offset(
        tensor.shape, tensor.device_mesh, tensor.placements
    )
    sizes, offsets = torch.Size(sizes), torch.Size(offsets)
    return ChunkStorageMetadata(
        offsets=offsets,
        sizes=sizes,
    )


def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
    if isinstance(tensor, DTensor):
        local_chunks = [_create_chunk_from_dtensor(tensor)]
    elif isinstance(tensor, ShardedTensor):
        local_chunks = [
            _chunk_for_shard(shard.metadata) for shard in tensor.local_shards()
        ]
    elif isinstance(tensor, torch.Tensor):
        local_chunks = [_create_chunk_from_tensor(tensor)]
    else:
        raise ValueError(
            "Unsupported type, expecting one of [Tensor, DTensor, ShardedTensor] "
            f"but got {type(tensor)}"
        )

    return local_chunks


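# Illustrative sketch (not part of the module API): a regular 4x4 tensor is
# described by a single chunk that starts at offset (0, 0) and spans the
# whole tensor.
def _example_chunk_list() -> List[ChunkStorageMetadata]:
    chunks = _create_chunk_list(torch.zeros(4, 4))
    assert chunks == [
        ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([4, 4]))
    ]
    return chunks

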
def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
    if not isinstance(md, BytesStorageMetadata):
        try:
            local_chunks = _create_chunk_list(obj)
        except ValueError as ex:
            raise ValueError(
                f"Invalid checkpoint metadata for {fqn}, "
                f"expected BytesStorageMetadata but found {type(md)}"
            ) from ex

        return create_read_items_for_chunk_list(fqn, md, local_chunks)
    else:
        return [
            _create_read_item_for_byteio(
                dest_index=MetadataIndex(fqn),
                dest_offset=0,
                storage_index=MetadataIndex(fqn),
                storage_offset=0,
                # Byte IO objects are read in full, so the offsets and length
                # here are placeholders.
                length=0,
            )
        ]


def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
    state_dict_assigned_storage = tree_map_only(
        torch.Tensor, lambda v: _init_meta_tensor(v), state_dict
    )
    # The in-place version of tree_map_only, tree_map_only_, doesn't seem to
    # work, so we temporarily update each element in the state dict with the
    # materialized value.
    for k in state_dict.keys():
        state_dict[k] = state_dict_assigned_storage[k]


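# Illustrative sketch (not part of the module API): tree_map_only returns a
# new container rather than mutating in place, which is why _init_state_dict
# copies the mapped values back into the caller's dict.
def _example_tree_map_only() -> None:
    d = {"a": torch.zeros(2), "b": "not a tensor"}
    mapped = tree_map_only(torch.Tensor, lambda t: t + 1, d)
    assert mapped is not d  # a new dict is returned
    assert mapped["b"] == "not a tensor"  # non-tensor leaves pass through
    assert bool((mapped["a"] == 1).all())  # tensor leaves were mapped

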
def _init_meta_tensor(value: Any) -> Any:
    """
    Materialize a tensor or DTensor that lives on the meta device by
    allocating uninitialized storage for it on the current device.
    Values not on the meta device are returned unchanged.
    """
    device = getattr(value, "device", None)
    # DCP does the initialization if it's a meta tensor/DTensor.
    if device == torch.device("meta"):
        device_type = dist.distributed_c10d._get_pg_default_device().type
        device = cast(torch.device, _get_device_module(device_type).current_device())
        if isinstance(value, DTensor):
            new_local_tensor = torch.empty_like(value.to_local(), device=device)
            # We need to pass shape and stride explicitly, since the DTensor
            # might be sharded unevenly.
            dtensor = DTensor.from_local(
                new_local_tensor,
                device_mesh=value.device_mesh,
                placements=value.placements,
                shape=value.size(),
                stride=value.stride(),
            )
            return dtensor
        elif isinstance(value, torch.Tensor):
            tensor = torch.empty_like(value, device=device)
            return tensor
        else:
            raise RuntimeError(
                f"Found unsupported type {type(value)} for meta device loading."
            )
    else:
        return value


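# Illustrative sketch (not part of the module API): materialize a state dict
# whose tensors live on the meta device before loading a checkpoint into it.
# _init_meta_tensor derives the target device from the default process group,
# so this assumes one has already been initialized.
def _example_init_meta_state_dict() -> STATE_DICT_TYPE:
    state_dict: STATE_DICT_TYPE = {"weight": torch.empty(4, 4, device="meta")}
    _init_state_dict(state_dict)
    # "weight" is now an uninitialized tensor allocated on the process
    # group's default device, ready to receive checkpoint data.
    return state_dict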