import cProfile
import inspect
import io
import itertools
import os
import warnings
from contextlib import contextmanager
from functools import wraps
from pstats import Stats
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._tensor import DTensor

from .api import (
    _is_wrapped_exception,
    _wrap_exception,
    CheckpointException,
    WRAPPED_EXCEPTION,
)
from .metadata import MetadataIndex, STATE_DICT_TYPE

__all__ = ["find_tensor_shard", "find_state_dict_object"]

T = TypeVar("T")
R = TypeVar("R")


def _get_failure_dict(
    results: List[Union[T, WRAPPED_EXCEPTION]]
) -> Dict[int, WRAPPED_EXCEPTION]:
    return cast(
        Dict[int, WRAPPED_EXCEPTION],
        {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
    )


def _all_gather_keys(
    local_dict: Dict[Any, Any], group: Optional[dist.ProcessGroup] = None
) -> List[Any]:
    """Gathers all keys, and returns them sorted."""
    keys = list(local_dict.keys())
    gathered_keys: List[List[Any]] = [None] * dist.get_world_size()  # type: ignore[list-item]
    dist.all_gather_object(gathered_keys, keys, group=group)
    return sorted(set(itertools.chain.from_iterable(gathered_keys)))
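

# Illustrative sketch (not part of this module's API): how _all_gather_keys is
# typically used. It assumes torch.distributed has already been initialized,
# e.g. via dist.init_process_group(); the example state dict below is made up.
def _example_all_gather_keys() -> List[Any]:
    local_state_dict = {"layer.weight": torch.zeros(2), "layer.bias": torch.zeros(2)}
    # Every rank receives the same sorted, de-duplicated union of keys.
    return _all_gather_keys(local_state_dict)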


class _DistWrapper:
    """
    This is a wrapper around a ProcessGroup that provides a series of features around object collectives.

    It works without torch.distributed being initialized, in which case most collectives become no-ops.

    All variants that take functions are exception robust, meaning that if one or more
    ranks raise errors, all ranks will observe those.
    """

    def __init__(
        self,
        group: Optional[dist.ProcessGroup],
        use_dist: bool,
        coordinator_rank: int,
    ):
        self.group = group
        self.use_dist = use_dist
        self.coordinator_rank = coordinator_rank
        if self.use_dist:
            self.rank = dist.get_rank(group)
            self.is_coordinator = self.rank == coordinator_rank
        else:
            self.rank = 0
            self.is_coordinator = True

    def get_rank(self) -> int:
        return self.rank

    def get_world_size(self) -> int:
        if self.use_dist:
            return dist.get_world_size(self.group)
        return 1

    def broadcast_object(self, object: Optional[T]) -> T:
        """Implement functionality similar to c10d::broadcast_object_list but without distributed enabled."""
        object_list = [object]
        if self.use_dist:
            dist.broadcast_object_list(
                object_list=object_list,
                group=self.group,
                src=self.coordinator_rank,
            )
        return cast(T, object_list[0])

    def gather_object(self, object: T) -> Optional[List[T]]:
        """Implement functionality similar to c10d::gather_object but without distributed enabled."""
        if self.use_dist:
            gather_objs = (
                cast(List[T], [None] * dist.get_world_size(self.group))
                if self.is_coordinator
                else None
            )
            dist.gather_object(
                obj=object,
                object_gather_list=gather_objs if self.is_coordinator else None,
                dst=self.coordinator_rank,
                group=self.group,
            )
            result = gather_objs
        else:
            result = [object]
        return result

    def all_gather_object(self, object: T) -> List[T]:
        """Implement functionality similar to c10d::all_gather_object but without distributed enabled."""
        if self.use_dist:
            gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))
            dist.all_gather_object(
                object_list=gather_objs, obj=object, group=self.group
            )
        else:
            gather_objs = [object]
        return gather_objs

    def scatter_object(self, object_list: Optional[List[T]]) -> T:
        """Implement functionality similar to c10d::scatter_object but without distributed enabled."""
        if self.use_dist:
            gather_result = cast(List[T], [None])
            dist.scatter_object_list(
                scatter_object_output_list=gather_result,
                scatter_object_input_list=object_list if self.is_coordinator else None,
                src=self.coordinator_rank,
                group=self.group,
            )
            local_reply = gather_result[0]
        else:
            assert object_list is not None
            local_reply = object_list[0]
        return local_reply

    def reduce_scatter(
        self,
        step: str,
        map_fun: Callable[[], T],
        reduce_fun: Callable[[List[T]], List[R]],
    ) -> R:
        """
        Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            Gather results on rank 0
            Call ``reduce_fun`` on all those values
            Scatter to each rank part of the result.
        """
        local_data: Union[WRAPPED_EXCEPTION, T]
        try:
            local_data = map_fun()
        except BaseException as e:
            local_data = _wrap_exception(e)

        all_data = self.gather_object(local_data)
        all_results: Optional[List[Union[R, CheckpointException]]] = None
        if self.is_coordinator:
            assert all_data is not None
            node_failures = _get_failure_dict(all_data)

            if len(node_failures) == 0:
                try:
                    # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]?
                    all_results = cast(
                        List[Union[R, CheckpointException]],
                        reduce_fun(cast(List[T], all_data)),
                    )
                except BaseException as e:
                    node_failures[self.rank] = _wrap_exception(e)

            if len(node_failures) > 0:
                all_results = [
                    CheckpointException(step, node_failures)
                ] * self.get_world_size()

        result = self.scatter_object(all_results)
        if isinstance(result, CheckpointException):
            raise result
        return result

    def all_reduce(
        self,
        step: str,
        map_fun: Callable[[], T],
        reduce_fun: Callable[[List[T]], R],
    ) -> R:
        """
        Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            Gather results on rank 0
            Call ``reduce_fun`` on all those values
            Broadcast the reduced value to all ranks.
        """
        local_data: Union[T, WRAPPED_EXCEPTION]
        try:
            local_data = map_fun()
        except BaseException as e:
            local_data = _wrap_exception(e)

        all_data = self.gather_object(local_data)
        result: Optional[Union[R, CheckpointException]] = None
        if self.is_coordinator:
            assert all_data is not None
            node_failures = _get_failure_dict(all_data)
            if len(node_failures) == 0:
                try:
                    result = reduce_fun(cast(List[T], all_data))
                except BaseException as e:
                    node_failures[self.rank] = _wrap_exception(e)

            if len(node_failures) > 0:
                result = CheckpointException(step, node_failures)

        final_result = self.broadcast_object(result)
        if isinstance(final_result, CheckpointException):
            raise final_result
        return cast(R, final_result)

    def all_gather(
        self,
        step: str,
        map_fun: Callable[[], T],
    ) -> List[T]:
        """
        Compute a value on each rank, then all_gather them.

        This method operates in the following way:
            Run ``map_fun`` on all ranks
            all_gather the values to all ranks
        """
        result: Union[T, WRAPPED_EXCEPTION]
        try:
            result = map_fun()
        except BaseException as e:
            result = _wrap_exception(e)

        all_results = self.all_gather_object(result)

        node_failures = _get_failure_dict(all_results)
        if len(node_failures) > 0:
            raise CheckpointException(step, node_failures)
        return cast(List[T], all_results)

    def broadcast(
        self,
        step: str,
        map_fun: Callable[[], T],
    ) -> T:
        """
        Compute a value on rank 0 and broadcast it.

        This method operates in the following way:
            Run ``map_fun`` on rank 0
            broadcast the value
        """
        result: Optional[Union[T, CheckpointException]] = None
        if self.is_coordinator:
            try:
                result = map_fun()
            except BaseException as e:
                result = CheckpointException(step, {self.rank: _wrap_exception(e)})
        final_result = self.broadcast_object(result)
        if isinstance(final_result, CheckpointException):
            raise final_result
        return cast(T, final_result)
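

# Illustrative sketch (not part of this module's API): how the exception-robust
# collectives of _DistWrapper are typically driven. With use_dist=False every
# collective degenerates to a local call, so this runs without init_process_group().
def _example_dist_wrapper_all_reduce() -> int:
    dist_wrapper = _DistWrapper(group=None, use_dist=False, coordinator_rank=0)
    # Each rank computes a local value; the coordinator reduces them and the
    # result is broadcast back. If any rank raised, every rank would instead see
    # a CheckpointException tagged with the step name "example-step".
    return dist_wrapper.all_reduce(
        step="example-step",
        map_fun=lambda: dist_wrapper.get_rank() + 1,
        reduce_fun=lambda values: sum(values),
    )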


def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
    if index.offset is None:
        raise ValueError(
            f"Cannot lookup {index.fqn} since it's a ShardedTensor and no offset was provided"
        )

    shards = tensor.local_shards()
    # index fast path
    if index.index is not None:
        if (
            len(shards) > index.index
            and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset
        ):
            return shards[index.index]

    for shard in shards:
        if torch.Size(shard.metadata.shard_offsets) == index.offset:
            return shard
    raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")


def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
    if isinstance(tensor, DTensor):
        return tensor.to_local()
    if isinstance(tensor, ShardedTensor):
        return _find_shard(tensor, index).tensor
    if index.offset is not None:
        # special case looking up a tensor by origin
        if index.offset == torch.Size([0] * len(tensor.size())):
            return tensor
        raise ValueError(
            f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
        )
    return tensor


def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
    if index.fqn not in state_dict:
        raise ValueError(f"Could not find FQN: '{index.fqn}'")
    obj = state_dict[index.fqn]

    if isinstance(obj, torch.Tensor):
        return find_tensor_shard(obj, index)
    elif index.offset is not None:
        raise ValueError(
            f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
        )
    return obj
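

# Illustrative sketch (not part of this module's API): looking up a plain,
# non-sharded tensor through a MetadataIndex. The FQN "layer.weight" and the
# state dict below are made up; for a regular tensor only the all-zeros offset
# (its origin) is accepted.
def _example_find_state_dict_object() -> torch.Tensor:
    state_dict: STATE_DICT_TYPE = {"layer.weight": torch.zeros(4, 4)}
    index = MetadataIndex(fqn="layer.weight", offset=torch.Size([0, 0]))
    return find_state_dict_object(state_dict, index)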


def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> List[int]:
    return [i_a + i_b for i_a, i_b in zip(a, b)]


def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
    return [i_a - i_b for i_a, i_b in zip(a, b)]


class _ReaderView(io.IOBase):
    """Read-only view of ``base_stream`` restricted to ``[offset, offset + len)``."""

    def __init__(self, base_stream: io.IOBase, offset: int, len: int):
        super().__init__()
        self.offset = offset
        self.len = len
        self.base_stream = base_stream
        self.seek(0)

    def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int:
        if __whence == os.SEEK_SET:
            __offset = self.offset + __offset
        elif __whence == os.SEEK_END:
            __whence = os.SEEK_SET
            __offset = (self.offset + self.len) - __offset
        return self.base_stream.seek(__offset, __whence)

    def tell(self) -> int:
        return self.base_stream.tell() - self.offset

    def readable(self) -> bool:
        return self.base_stream.readable()

    def seekable(self) -> bool:
        return self.base_stream.seekable()

    def readinto(self, b):
        # Clamp the destination buffer so reads cannot run past the view.
        max_size = self.len - self.tell()
        if max_size <= 0:
            return 0
        if len(b) > max_size:
            b = memoryview(b)[:max_size]
        return self.base_stream.readinto(b)  # type: ignore[attr-defined]

    def read(self, size=-1):
        # Clamp the requested size so reads cannot run past the view.
        max_size = self.len - self.tell()
        if size == -1 or size > max_size:
            size = max_size
        return self.base_stream.read(size)


def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase:
    # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader
    return _ReaderView(file, offset, length)
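

# Illustrative sketch (not part of this module's API): exposing a byte range of
# an in-memory stream as if it were a standalone file. The payload below is
# arbitrary example data; the view covers bytes [7, 14) of the buffer.
def _example_create_file_view() -> bytes:
    base = io.BytesIO(b"header:payload")
    view = _create_file_view(base, offset=7, length=7)
    return view.read()  # b"payload"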


def _normalize_device_info(device_type: str, device_id: int) -> str:
    """Device info normalization."""
    if device_type == "cpu":
        return "cpu"
    return f"{device_type}:{device_id}"


# TODO: integrate with distributed logging flag
ENABLE_PROFILE = False


@contextmanager
def _profile():
    # Only log the profiling when it is enabled and is on rank0 or dist is not
    # available.
    if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0):
        profiler = cProfile.Profile()
        profiler.enable()
        try:
            yield
        finally:
            profiler.disable()
            stats = Stats(profiler)
            stats.sort_stats("time").print_stats(10)
    else:
        yield
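

# Illustrative sketch (not part of this module's API): wrapping a block of work
# in the profiling context manager. With ENABLE_PROFILE left at False this is a
# no-op; flipping it to True prints the ten most time-consuming calls.
def _example_profile_usage() -> None:
    with _profile():
        torch.randn(8, 8).sum()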


def _api_bc_check(func):
    @wraps(func)
    def inner_func(*args, **kwargs) -> Any:
        if len(args) == 2:
            warnings.warn(
                f"The argument order of {func.__name__} has been changed. "
                "Please check the document to avoid future breakages."
            )
            sig = inspect.signature(func)
            kwonlyargs = [
                p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY
            ]
            if "storage_writer" in kwonlyargs:
                assert "storage_writer" not in kwargs, (args, kwargs)
                kwargs["storage_writer"] = args[1]
            elif "storage_reader" in kwonlyargs:
                assert "storage_reader" not in kwargs, (args, kwargs)
                kwargs["storage_reader"] = args[1]
            else:
                raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}")
            return func(args[0], **kwargs)
        else:
            return func(*args, **kwargs)

    return inner_func
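

# Illustrative sketch (not part of this module's API): _api_bc_check keeps old
# call sites working after the storage argument became keyword-only. The
# decorated function below is hypothetical and only demonstrates how a second
# positional argument is rewritten into the "storage_writer" keyword.
@_api_bc_check
def _example_save(state_dict: STATE_DICT_TYPE, *, storage_writer: Any) -> Any:
    return storage_writer


# Both calls reach _example_save with storage_writer passed as a keyword,
# though the positional form emits a warning about the changed argument order:
#   _example_save({}, "writer")
#   _example_save({}, storage_writer="writer")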