# Copyright (c) Meta Platforms, Inc. and affiliates

import dataclasses
import io
import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint._nested_dict import (
    FLATTEN_MAPPING,
    flatten_state_dict,
)
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
    WriteItemType,
)
from torch.distributed.checkpoint.planner_helpers import (
    _create_default_metadata_only_plan,
    _create_read_items,
    _create_write_items,
    _init_state_dict,
)
from torch.distributed.checkpoint.utils import find_state_dict_object

logger: logging.Logger = logging.getLogger(__name__)

__all__ = [
    "DefaultSavePlanner",
    "DefaultLoadPlanner",
    "create_default_local_load_plan",
    "create_default_global_load_plan",
    "create_default_local_save_plan",
    "create_default_global_save_plan",
]


# TODO: Update docstrings for default_planner.py
class DefaultSavePlanner(SavePlanner):
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        dedup_replicated_tensors: Optional[bool] = None,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.mappings = {}

        if dedup_replicated_tensors is not None:
            logger.warning(
                "DefaultSavePlanner's `dedup_replicated_tensors` argument is being "
                "deprecated, and no longer has any effect. Please remove this argument "
                "from your call."
            )

    def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)
        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)
        self.state_dict = state_dict
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> SavePlan:
        plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
        if self.flatten_state_dict:
            plan = dataclasses.replace(plan, planner_data=self.mappings)
        self.plan = plan
        return self.plan

    def create_global_plan(
        self, all_plans: List[SavePlan]
    ) -> Tuple[List[SavePlan], Metadata]:
        all_plans = dedup_save_plans(all_plans)

        global_plan, metadata = create_default_global_save_plan(all_plans)

        if self.flatten_state_dict:
            # | does not work for Python 3.8 or older version.
            # merged_mappings = reduce(
            #     lambda x, y: x | y, (p.planner_data for p in global_plan)
            # )
            planner_data_dict = [p.planner_data for p in global_plan]
            merged_mappings = dict(ChainMap(*planner_data_dict))
            metadata = dataclasses.replace(metadata, planner_data=merged_mappings)

        if not _validate_global_plan(global_plan, metadata):
            raise ValueError("Failed to validate global plan")

        self.global_plan = global_plan
        self.metadata = metadata
        return self.global_plan, self.metadata

    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
        self.plan = new_plan
        return new_plan

    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
        object = self.lookup_object(write_item.index)
        return self.transform_object(write_item, object)

    def lookup_object(self, index: MetadataIndex) -> Any:
        """Extension from the planner interface to make it easy to extend the default planner."""
        return find_state_dict_object(self.state_dict, index)

    def transform_object(self, write_item: WriteItem, object: Any):
        """Extension from the planner interface to make it easy to extend the default planner."""
        if write_item.type == WriteItemType.BYTE_IO:
            bytes = io.BytesIO()
            torch.save(object, bytes)
            object = bytes
        return object

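# NOTE: a minimal, illustrative sketch of how DefaultSavePlanner is typically wired into
# a checkpoint save; nothing in this module calls it. It assumes a process group is
# already initialized and that `checkpoint_dir` (a hypothetical path argument) points at
# a writable location shared by all ranks.
def _example_save_with_default_planner(
    state_dict: STATE_DICT_TYPE, checkpoint_dir: str
) -> None:
    import torch.distributed.checkpoint as dcp

    dcp.save(
        state_dict,
        storage_writer=dcp.FileSystemWriter(checkpoint_dir),
        # flatten_state_dict=True lets nested dicts be addressed by dotted keys;
        # replicated tensors are deduplicated across ranks during global planning.
        planner=DefaultSavePlanner(flatten_state_dict=True),
    )
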
class DefaultLoadPlanner(LoadPlanner):
    """
    DefaultLoadPlanner adds multiple features on top of LoadPlanner.

    In particular it adds the following:

    flatten_state_dict: Handle state_dicts with nested dicts by flattening them
        into dotted keys.
    flatten_sharded_tensors: Flatten nested ShardedTensors, as produced by FSDP
        in 2D parallel mode.
    """

    original_state_dict: STATE_DICT_TYPE
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.original_state_dict = {}
        self.mappings = {}

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        _init_state_dict(state_dict)
        self.original_state_dict = state_dict

        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)

        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)

        self.state_dict = state_dict
        self.metadata = metadata
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> LoadPlan:
        return create_default_local_load_plan(self.state_dict, self.metadata)

    def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        return create_default_global_load_plan(global_plan)

    def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
        return new_plan

    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
        if self.flatten_state_dict:
            set_element(
                self.original_state_dict,
                self.mappings[read_item.dest_index.fqn],
                torch.load(value),
            )
        else:
            self.state_dict[read_item.dest_index.fqn] = torch.load(value)

    def resolve_tensor(self, read_item: ReadItem):
        tensor = self.lookup_tensor(read_item.dest_index)
        return self.transform_tensor(read_item, tensor)

    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
        pass

    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
        """Extension from the planner interface to make it easy to extend the default planner."""
        return find_state_dict_object(self.state_dict, index)

    def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
        """Extension from the planner interface to make it easy to extend the default planner."""
        return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)

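# NOTE: a minimal, illustrative sketch of loading with DefaultLoadPlanner; nothing in this
# module calls it. It assumes the checkpoint at `checkpoint_dir` (a hypothetical path
# argument) was saved via DefaultSavePlanner and that `state_dict` is already populated
# with tensors of the desired (possibly resharded) shapes, which are filled in place.
def _example_load_with_default_planner(
    state_dict: STATE_DICT_TYPE, checkpoint_dir: str
) -> None:
    import torch.distributed.checkpoint as dcp

    dcp.load(
        state_dict,
        storage_reader=dcp.FileSystemReader(checkpoint_dir),
        planner=DefaultLoadPlanner(flatten_state_dict=True),
    )
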
def create_default_local_load_plan(
    state_dict: Dict[str, Any],
    metadata: Metadata,
) -> LoadPlan:
    """
    Create the ``LoadPlan`` used by DefaultLoadPlanner.

    It produces one read item per value in ``state_dict`` using the metadata in ``metadata``.
    The default behavior is to match keys exactly between ``state_dict`` and ``metadata``.
    It handles resharding by issuing multiple read requests against storage in order to match
    load requirements.
    """
    requests = []
    for fqn, obj in state_dict.items():
        md = metadata.state_dict_metadata[fqn]
        # Since DTensor supports submesh, add an extra check to ensure _create_read_items()
        # gets called only when the current rank is part of the mesh for the corresponding DTensor.
        if isinstance(obj, DTensor):
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_read_items(fqn, md, obj)
        else:
            requests += _create_read_items(fqn, md, obj)

    return LoadPlan(requests)

def create_default_global_load_plan(
    all_plans: List[LoadPlan],
) -> List[LoadPlan]:
    """
    Create the global load plan used by DefaultLoadPlanner.

    The default load behavior involves no global coordination, and this function
    currently does not change the local plans.
    """
    return all_plans

def create_default_local_save_plan(
    state_dict: Dict[str, Any], is_coordinator: bool
) -> SavePlan:
    """
    Create the ``SavePlan`` used by DefaultSavePlanner.

    On non-coordinator ranks, this function only produces writes for tensors
    (including DTensors whose mesh contains the current rank). On the coordinator
    rank, it produces writes for all values, including non-tensor objects.
    """
    requests = []
    for fqn, obj in state_dict.items():
        # Since DTensor supports submesh, add an extra check to ensure _create_write_items()
        # gets called only when the current rank is part of the mesh for the corresponding DTensor.
        if isinstance(obj, DTensor):
            if obj.device_mesh.get_coordinate() is not None:
                requests += _create_write_items(fqn, obj)
        elif isinstance(obj, torch.Tensor) or is_coordinator:
            requests += _create_write_items(fqn, obj)

    return SavePlan(requests)

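# NOTE: an illustrative, single-process sketch of the per-rank planning step above; the
# keys "weight" and "epoch" are arbitrary. On the coordinator rank the non-tensor value
# also becomes a BYTE_IO write item, while non-coordinator ranks only plan tensor writes.
def _example_local_save_plan(is_coordinator: bool = True) -> SavePlan:
    state_dict = {"weight": torch.zeros(4, 4), "epoch": 3}
    return create_default_local_save_plan(state_dict, is_coordinator)
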
def create_default_global_save_plan(
    all_plans: List[SavePlan],
    rewrite_index_hints: bool = True,
) -> Tuple[List[SavePlan], Metadata]:
    """
    Create the global plan and metadata used by DefaultSavePlanner.

    Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans.

    The only global planning change is to update index hints in all ``MetadataIndex`` objects if
    ``rewrite_index_hints`` is True.
    """
    md: Dict[str, STORAGE_TYPES] = {}
    new_plans = []
    for plan in all_plans:
        new_items = []
        for item in plan.items:
            if not item.type == WriteItemType.SHARD:
                assert item.index.fqn not in md

            if item.type == WriteItemType.BYTE_IO:
                md[item.index.fqn] = BytesStorageMetadata()
                new_items.append(item)
            else:
                assert item.tensor_data is not None
                tensor_md = cast(
                    TensorStorageMetadata,
                    md.setdefault(
                        item.index.fqn,
                        TensorStorageMetadata(
                            properties=item.tensor_data.properties,
                            size=item.tensor_data.size,
                            chunks=[],
                        ),
                    ),
                )
                new_item = item
                if rewrite_index_hints:
                    new_index = dataclasses.replace(
                        item.index, index=len(tensor_md.chunks)
                    )
                    new_item = dataclasses.replace(item, index=new_index)
                new_items.append(new_item)

                assert (
                    item.tensor_data.chunk is not None
                ), f"""
                    Cannot create MD for tensor without bounds.
                    FQN: {item.index.fqn}
                """
                tensor_md.chunks.append(item.tensor_data.chunk)
        new_plans.append(dataclasses.replace(plan, items=new_items))
    return (new_plans, Metadata(md))

def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
    """Return the ``Metadata`` that would be produced if DefaultSavePlanner were used to checkpoint ``state_dict``."""
    plan = _create_default_metadata_only_plan(state_dict)
    _, md = create_default_global_save_plan([plan])
    return md

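# NOTE: an illustrative sketch of inspecting the metadata produced for a purely local
# state_dict; the key "weight" is arbitrary and nothing in this module calls this helper.
def _example_inspect_metadata() -> Metadata:
    md = _create_default_local_metadata({"weight": torch.zeros(8, 8)})
    # md.state_dict_metadata["weight"] is a TensorStorageMetadata whose single chunk
    # covers the full [8, 8] tensor.
    return md
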
def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
    """Check whether two boxes, each given by its ``offsets`` and ``sizes``, overlap."""
    # For each dim, check whether one box lies entirely beyond the other box along
    # that dim. For a 2D shard this amounts to checking whether one shard is
    # strictly above or strictly to the left of the other shard.
    ndims = len(box0.offsets)
    for i in range(ndims):
        if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]:
            return False
        if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]:
            return False

    return True

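# NOTE: a small illustrative sketch of the overlap test above: two 2x2 chunks of a 2D
# tensor that both cover the cell at offset (1, 1) are reported as overlapping.
def _example_box_overlap() -> bool:
    chunk_a = ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([2, 2]))
    chunk_b = ChunkStorageMetadata(offsets=torch.Size([1, 1]), sizes=torch.Size([2, 2]))
    return _check_box_overlap(chunk_a, chunk_b)  # True
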
def _check_box_bounds(
    outer_box_size: torch.Size, inner_box: ChunkStorageMetadata
) -> bool:
    for i in range(len(outer_box_size)):
        if inner_box.offsets[i] < 0:
            return False
        if inner_box.sizes[i] < 0:
            return False
        if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]:
            return False

    return True

def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
    all_good = True
    for key, value in metadata.state_dict_metadata.items():
        if isinstance(value, BytesStorageMetadata):
            continue
        if len(value.size) == 0:
            continue
        chunks_volume = 0
        for chunk_idx, chunk0 in enumerate(value.chunks):
            # Check that the chunk lies within the bounds of the tensor.
            if not _check_box_bounds(value.size, chunk0):
                logger.warning(
                    """
                        key:%s has out of bounds chunk:
                        tensor-size:%s chunk: %s
                    """,
                    key,
                    value.size,
                    chunk0,
                )
                all_good = False

            # Accumulate the volume of this chunk.
            chunks_volume += reduce(operator.mul, chunk0.sizes, 1)

            # Check for overlap with the remaining chunks.
            for chunk1 in value.chunks[chunk_idx + 1 :]:
                if _check_box_overlap(chunk0, chunk1):
                    logger.warning(
                        "key:%s has overlapping chunks: %s %s", key, chunk0, chunk1
                    )
                    all_good = False

        # Check whether the combined chunks cover the whole tensor.
        tensor_volume = reduce(operator.mul, value.size, 1)
        if chunks_volume != tensor_volume:
            logger.warning(
                """
                    key:%s invalid fill tensor-volume:
                    %s chunks-volume: %s
                """,
                key,
                tensor_volume,
                chunks_volume,
            )
            all_good = False

    return all_good

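# NOTE: an illustrative sketch of the validation above catching incomplete coverage: a
# single 2x2 chunk cannot cover a 4x4 tensor, so the chunk volume (4) does not match the
# tensor volume (16) and validation fails. The key "weight" is arbitrary and nothing in
# this module calls this helper.
def _example_validate_incomplete_coverage() -> bool:
    from torch.distributed.checkpoint.metadata import TensorProperties

    incomplete = Metadata(
        state_dict_metadata={
            "weight": TensorStorageMetadata(
                properties=TensorProperties(dtype=torch.float32),
                size=torch.Size([4, 4]),
                chunks=[
                    ChunkStorageMetadata(
                        offsets=torch.Size([0, 0]), sizes=torch.Size([2, 2])
                    )
                ],
            )
        }
    )
    return _validate_global_plan([SavePlan([])], incomplete)  # False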