# Copyright (c) Meta Platforms, Inc. and affiliates

import dataclasses
from typing import cast, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.distributed as dist
from torch._utils import _get_device_module
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import (
    TensorProperties as ShardTensorProperties,
)
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import (
    ChunkShardingSpec,
)
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import (
    _create_read_items,
    create_read_items_for_chunk_list,
)
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.distributed.checkpoint.utils import (
    _element_wise_add,
    _element_wise_sub,
    _normalize_device_info,
)
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.remote_device import _remote_device

STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]


# TODO: Update docstrings for optimizer.py
__all__ = [
    "load_sharded_optimizer_state_dict",
]


def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
    if device_type == "cpu":
        return "cpu"
    device_module = _get_device_module(device_type)
    if device_module.is_available():
        return _normalize_device_info(
            device_type, global_rank % device_module.device_count()
        )
    return "cpu"


def _create_colwise_spec(
    pg: Optional[dist.ProcessGroup] = None,
) -> ChunkShardingSpec:
    pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type
    if pg is None:
        placements = [
            f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}"
            for idx in range(dist.get_world_size())
        ]
    else:
        placements = [
            f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}"
            for idx in range(pg.size())
        ]
    return ChunkShardingSpec(
        dim=0,
        placements=cast(List[Union[_remote_device, str]], placements),
    )
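# Illustration (assumes a hypothetical 2-rank CUDA job; not asserted by this module):
# the spec built above would look roughly like
#     ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
# i.e. tensors are chunked along dim 0 with one chunk placed on each rank's device.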


def _is_nested_tensor(val: torch.Tensor) -> bool:
    if type(val) is ShardedTensor:
        if len(val.local_shards()) == 0:
            return False
        if type(val.local_shards()[0].tensor) is ShardedTensor:
            return True
        if type(val.local_shards()[0].tensor) is DTensor:
            raise ValueError("Cannot handle DTensor nested inside ShardedTensor")
    elif type(val) is DTensor and (
        type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor
    ):
        raise ValueError("Cannot handle nested DTensor")
    return False


def _alloc_tensor(
    props: TensorProperties, size: Sequence[int], device_type: str = "cuda"
) -> torch.Tensor:
    return torch.empty(
        size=size,
        dtype=props.dtype,
        layout=props.layout,
        requires_grad=props.requires_grad,
        pin_memory=props.pin_memory,
        device=cast(torch.device, _get_device_module(device_type).current_device()),
    )


def _get_state_dict_2d_layout(
    state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]:
    """
    Load the right TP slice of the optimizer state.

    This is not easy since the per-tensor slicing can't be inferred from checkpoint
    metadata. We take advantage of the model state_dict producing a sliced ST to
    figure out what we need to load. This is pretty fragile and it might be easier
    for FSDP to compute this info for us.

    Returns a dictionary whose keys are the same as those of ``state_dict`` and
    whose values are ``(offset, size)`` tuples describing the current rank's TP
    slice (see the illustrative sketch after this function), along with the
    data-parallel process group found on the nested sharded tensors (or ``None``).

    N.B. The state_dict *MUST* come from FSDP.sharded_state_dict.
    """
    specs: STATE_DICT_2D_LAYOUT = {}
    dp_pg: Optional[dist.ProcessGroup] = None
    for key, value in state_dict.items():
        specs[key] = (None, value.size())
        if _is_nested_tensor(value):
            assert (
                len(value.local_shards()) == 1
            ), "Cannot handle ST with multiple shards"
            assert isinstance(
                value, ShardedTensor
            ), "Can only handle nested ShardedTensor"
            shard = value.local_shards()[0]
            specs[key] = (
                shard.metadata.shard_offsets,
                shard.metadata.shard_sizes,
            )
            dp_pg = shard.tensor._process_group  # type: ignore[attr-defined]

    return (
        specs,
        dp_pg,
    )
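# Illustration (the FQNs, shapes, and offsets below are hypothetical): on a rank
# whose TP slice of "net1.weight" starts at row 64, the helper above might return
#     layout, dp_pg = _get_state_dict_2d_layout(model_state_dict)
#     # layout == {
#     #     "net1.weight": ([64, 0], [64, 32]),     # (offset, size) of this rank's TP slice
#     #     "net1.bias": (None, torch.Size([32])),  # not a nested ShardedTensor, so no offset
#     # }
# Entries with a None offset are read whole; entries with an offset feed the
# fqn_to_offset map consumed by _ReaderWithOffset below.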


class _ReaderWithOffset(DefaultLoadPlanner):
    """
    Load planner that shifts read requests for local shards by a per-FQN offset,
    so that each rank of a 2D (FSDP + TP) setup reads its slice of the flat
    checkpoint tensor, and records the reverse mapping so the data is copied back
    into the original in-memory shard.
    """

    translation: Dict[MetadataIndex, MetadataIndex]
    state_dict: STATE_DICT_TYPE
    metadata: Metadata

    def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None:
        super().__init__()
        self.fqn_to_offset = fqn_to_offset
        self.metadata = Metadata({})
        self.state_dict = {}
        self.translation = {}

    def create_local_plan(self) -> LoadPlan:
        requests = []
        self.translation = {}
        for fqn, obj in self.state_dict.items():
            md = self.metadata.state_dict_metadata[fqn]
            if not isinstance(obj, ShardedTensor):
                requests += _create_read_items(fqn, md, obj)
                continue

            if fqn not in self.fqn_to_offset:
                requests += _create_read_items(fqn, md, obj)
                continue

            offset = self.fqn_to_offset[fqn]

            assert len(obj.local_shards()) == 1
            original_shard = obj.local_shards()[0]
            local_chunks = [
                ChunkStorageMetadata(
                    offsets=torch.Size(
                        _element_wise_add(original_shard.metadata.shard_offsets, offset)
                    ),
                    sizes=torch.Size(original_shard.metadata.shard_sizes),
                )
            ]

            reqs = create_read_items_for_chunk_list(
                fqn, cast(TensorStorageMetadata, md), local_chunks
            )
            # TODO: The ReadItems will have a displaced MetadataIndex, fix it.
            # TODO: we should change _create_sharded_read_items to have more ergonomic API
            for ri in reqs:
                assert ri.dest_index.offset is not None
                original_offset = _element_wise_sub(ri.dest_index.offset, offset)
                original_index = dataclasses.replace(
                    ri.dest_index, offset=torch.Size(original_offset)
                )
                self.translation[ri.dest_index] = original_index

            requests += reqs
        return LoadPlan(requests)

    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
        return super().lookup_tensor(self.translation.get(index, index))
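# Illustration (hypothetical shapes): if the checkpoint stores a [128, 32] tensor,
# this rank's TP slice starts at row 64 (fqn_to_offset[fqn] == [64, 0]), and the
# FSDP local shard covers rows [16, 32) of that slice, then create_local_plan
# requests checkpoint offset [64 + 16, 0] == [80, 0], while the recorded translation
# lets lookup_tensor resolve that request back to the in-memory shard at [16, 0].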


def load_sharded_optimizer_state_dict(
    model_state_dict: STATE_DICT_TYPE,
    optimizer_key: str,
    storage_reader: StorageReader,
    planner: Optional[LoadPlanner] = None,
) -> STATE_DICT_TYPE:
    """
    Load a state_dict in conjunction with FSDP sharded optimizer state.
    This is the current recommended way to checkpoint FSDP.

    >>> # xdoctest: +SKIP
    >>> import torch.distributed.checkpoint as dist_cp
    >>> model: torch.nn.Module
    >>> optim_params = model.parameters()
    >>> optim = torch.optim.SGD(optim_params, lr=0.01)
    >>> # Save
    >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    >>>     state_dict = {
    >>>         "optimizer": FSDP.optim_state_dict(model, optim),
    >>>         "model": model.state_dict(),
    >>>     }
    >>>     dist_cp.save_state_dict(
    >>>         state_dict=state_dict,
    >>>         storage_writer=dist_cp.FileSystemWriter("checkpoint"),
    >>>         planner=dist_cp.DefaultSavePlanner(),
    >>>     )
    >>>
    >>> # Load
    >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT):
    >>>     model_state_dict = model_tp.state_dict()
    >>>     checkpoint = {
    >>>         "model": model_state_dict,
    >>>     }
    >>>     dist_cp.load_state_dict(
    >>>         state_dict=checkpoint,
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>         planner=dist_cp.DefaultLoadPlanner(),
    >>>     )
    >>>     model.load_state_dict(checkpoint["model"])
    >>>
    >>>     optim_state = dist_cp.load_sharded_optimizer_state_dict(
    >>>         model_state_dict,
    >>>         optimizer_key="optimizer",
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>     )
    >>>
    >>>     flattened_osd = FSDP.optim_state_dict_to_load(
    >>>         model, optim, optim_state["optimizer"]
    >>>     )
    >>>
    >>>     optim.load_state_dict(flattened_osd)
    """
    metadata = storage_reader.read_metadata()

    layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
    dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type
    device_module = _get_device_module(dp_pg_device_type)

    if dp_pg is None:
        placements = []
        for i in range(dist.get_world_size()):
            device_info = _normalize_device_info(
                dp_pg_device_type, i % device_module.device_count()
            )
            placements.append(f"rank:{i}/{device_info}")
        sharding_spec = ChunkShardingSpec(dim=0, placements=placements)  # type: ignore[arg-type]
    else:
        sharding_spec = _create_colwise_spec(dp_pg)

    # Create a state_dict for optimizer state
    state_dict: STATE_DICT_TYPE = {}

    fqn_to_offset: Dict[str, Sequence[int]] = {}
    for key, value in metadata.state_dict_metadata.items():
        key_path = metadata.planner_data[key]
        if key_path[0] != optimizer_key:
            continue

        if isinstance(value, BytesStorageMetadata):
            state_dict[key] = "<bytes_io>"
            continue

        # value: TensorStorageMetadata
        if value.size.numel() == 1:
            state_dict[key] = _alloc_tensor(
                value.properties, value.size, dp_pg_device_type
            )
        elif dp_pg is None:
            state_dict[key] = _create_chunk_sharded_tensor(
                _alloc_tensor(value.properties, value.size, dp_pg_device_type),
                rank=dist.get_rank(),
                world_size=dist.get_world_size(),
                num_devices_per_node=device_module.device_count(),
                pg=_get_default_group(),
            )
        else:
            spec_key = key_path[2]
            alloc_size = layout_specs.get(spec_key, (None, value.size))[1]

            properties = ShardTensorProperties(
                dtype=value.properties.dtype,
                layout=value.properties.layout,
                requires_grad=value.properties.requires_grad,
                memory_format=value.properties.memory_format,
                pin_memory=value.properties.pin_memory,
            )

            st_md = sharding_spec.build_metadata(torch.Size(alloc_size), properties)
            local_shards = []
            current_rank = dist.get_rank(dp_pg)
            for shard_md in st_md.shards_metadata:
                if cast(_remote_device, shard_md.placement).rank() != current_rank:
                    continue
                local_shards.append(
                    Shard(
                        tensor=_alloc_tensor(
                            value.properties, shard_md.shard_sizes, dp_pg_device_type
                        ),
                        metadata=shard_md,
                    )
                )

            st = ShardedTensor._init_from_local_shards_and_global_metadata(
                local_shards, st_md, process_group=dp_pg
            )

            if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
                fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])

            state_dict[key] = st

    # Whether we unflatten before or after loading doesn't matter
    load_state_dict(
        state_dict=state_dict,
        storage_reader=storage_reader,
        # FIXME the type of planner is wrong in load_state_dict
        planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else planner,
    )

    state_dict = unflatten_state_dict(state_dict, metadata.planner_data)

    return state_dict