import dataclasses
import traceback
from typing import Any, Callable, Container, Dict, List, Optional, OrderedDict, Tuple, TypeVar, overload

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel._functions import _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple
from torch.nn.utils.rnn import PackedSequence

__all__ = []  # type: ignore[var-annotated]


def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
    """
    Turn an argument list into separate value and key lists (``_unpack_kwargs`` does the opposite).

    Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70

    Usage::

        flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
        assert flat_args == (1, 2, 3, 4)
        assert kwarg_keys == ("a", "b")
        args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}

    Returns:
        Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element gives
        both the positional args and the kwarg values, where the positional args
        precede the kwarg values and the kwarg values are ordered consistently
        with the kwarg keys. The second tuple element gives the kwarg keys.
        The second tuple element's length is at most the first tuple element's length.
    """
    kwarg_keys: List[str] = []
    flat_args: List[Any] = list(args)
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)
    return tuple(flat_args), tuple(kwarg_keys)


def _cast_forward_inputs(
    dtype: Optional[torch.dtype],
    *args: Any,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """
    Cast floating point tensors in ``args`` and ``kwargs`` to ``dtype``.

    This respects the existing ``requires_grad`` on the tensors.
    """
    if dtype is None:
        return args, kwargs

    def cast_fn(x: torch.Tensor) -> torch.Tensor:
        if not torch.is_floating_point(x) or x.dtype == dtype:
            return x
        return x.to(dtype)

    return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))
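
# Example (illustrative sketch, not part of this module): casting a module's
# forward inputs to a low-precision dtype, as mixed-precision wrappers do.
# Only floating point tensors change dtype; integer tensors and non-tensor
# values pass through unchanged.
#
#     x = torch.randn(4, 8)            # float32
#     lengths = torch.tensor([4, 3])   # int64, left untouched
#     args, kwargs = _cast_forward_inputs(torch.float16, x, lengths, mask=None)
#     # args[0].dtype == torch.float16, args[1].dtype == torch.int64
#     # kwargs == {"mask": None}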


def _unpack_kwargs(flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """See _pack_kwargs."""
    assert len(kwarg_keys) <= len(
        flat_args
    ), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
    return args, kwargs
| S = TypeVar("S", dict, list, tuple) | |
| T = TypeVar("T", torch.Tensor, PackedSequence) | |
| def _recursive_to(inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool) -> List[S]: | |
| ... | |
| def _recursive_to(inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool) -> Tuple[T]: | |
| ... | |


def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
    r"""Recursively move the input to ``target_device``."""

    def to_map(obj):
        if isinstance(obj, (torch.Tensor, PackedSequence)):
            device = obj.data.device if isinstance(obj, PackedSequence) else obj.device
            if device == target_device:
                return (obj,)
            if not use_side_stream_for_tensor_copies:
                return (obj.to(target_device),)
            else:
                # If the custom module is not registered to torch, stream is not used for acceleration
                device_mod = getattr(torch, device.type, None)
                if device.type == "cpu" or device_mod is None:
                    return (obj.to(target_device),)
                # Perform CPU -> target_device copies in a background stream. This code is
                # motivated from similar logic in torch/nn/parallel/_functions.py
                stream = _get_stream(target_device)
                with device_mod.stream(stream):
                    output = obj.to(target_device)
                # synchronize with the copy stream
                with device_mod.device(target_device.index):
                    current_stream = device_mod.current_stream()
                    # Sync the current stream with the copy stream
                    current_stream.wait_stream(stream)
                    # Ensure tensor memory is not reused until work on
                    # main stream is complete
                    if isinstance(obj, PackedSequence):
                        output.data.record_stream(current_stream)  # type: ignore[arg-type]
                    else:
                        assert isinstance(output, torch.Tensor)
                        output.record_stream(current_stream)  # type: ignore[arg-type]
                return (output,)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(to_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(to_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(to_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
        return [obj]

    # Avoid reference cycle
    try:
        res = to_map(inputs)
    finally:
        to_map = None  # type: ignore[assignment]
    return res
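
# Example (illustrative sketch, not part of this module): moving a nested
# argument structure onto a device while preserving its container types,
# roughly how DDP prepares per-replica inputs. The device "cuda:0" is assumed
# to exist for the sake of the example.
#
#     batch = {"x": torch.randn(2, 3), "meta": ("tag", torch.zeros(1))}
#     (moved,) = _recursive_to(batch, torch.device("cuda:0"), use_side_stream_for_tensor_copies=False)
#     # moved is a dict with the same keys; moved["x"] lives on cuda:0 and the
#     # nested tuple under "meta" keeps its non-tensor element "tag" as-is.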


def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
    """Alternate to ``assert`` when in the backward context to print the error message ``s`` since otherwise, it is swallowed."""
    if not cond:
        print(s)
        traceback.print_stack()
        if raise_assertion_error:
            raise AssertionError(s)
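
# Example (illustrative sketch, not part of this module): using _p_assert inside
# an autograd backward hook, where a plain ``assert`` failure would otherwise be
# swallowed without a printed message. ``expected_numel`` is a hypothetical value
# captured from the surrounding scope.
#
#     def _post_backward_hook(grad: torch.Tensor) -> torch.Tensor:
#         _p_assert(
#             grad.numel() == expected_numel,
#             f"Unexpected gradient numel: {grad.numel()} vs. {expected_numel}",
#         )
#         return grad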


def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
    """
    Allocate storage for ``tensor`` with the given size.

    This is a no-op if the storage is already allocated to ``size.numel()``
    elements; otherwise, the storage (which must currently be freed, i.e. of
    size 0) is resized to ``size.numel()`` elements.
    """
    with torch.no_grad():
        if (
            not torch.distributed._functional_collectives.is_torchdynamo_compiling()
        ):
            already_allocated = tensor._typed_storage()._size() == size.numel()
            if not already_allocated:
                tensor_storage_size = tensor._typed_storage()._size()
                _p_assert(
                    tensor_storage_size == 0,
                    f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
                )
                tensor._typed_storage()._resize_(size.numel())


def _free_storage(tensor: torch.Tensor) -> None:
    """
    Free the underlying storage of ``tensor``.

    This is a no-op if the storage is already freed (i.e. has size 0);
    otherwise, the storage is resized to 0 elements.
    """
    with torch.no_grad():
        if (
            not torch.distributed._functional_collectives.is_torchdynamo_compiling()
        ):
            already_freed = tensor._typed_storage()._size() == 0
            if not already_freed:
                _p_assert(
                    tensor.storage_offset() == 0,
                    "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
                    f"storage offset: {tensor.storage_offset()}\n"
                    f"storage size: {tensor._typed_storage()._size()}\n"
                    f"tensor shape: {tensor.shape}",
                )
                tensor._typed_storage()._resize_(0)
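
# Example (illustrative sketch, not part of this module): the free/alloc cycle
# that FSDP-style sharding relies on -- release a flat parameter's backing
# storage when it is not needed, then reallocate it before use. The tensor
# object itself stays alive; only its storage is resized.
#
#     flat_param = torch.empty(1024)
#     _free_storage(flat_param)                        # storage size becomes 0
#     _alloc_storage(flat_param, torch.Size([1024]))   # storage resized back (uninitialized)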
| Q = TypeVar("Q") | |
| R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any) | |
| def _apply_to_tensors(fn: Callable[[torch.Tensor], Q], container: torch.Tensor) -> Q: | |
| ... | |
| def _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R: | |
| ... | |


def _apply_to_tensors(fn, container):
    """Recursively apply ``fn`` to all tensors in different kinds of container types."""

    def apply(x):
        if isinstance(x, torch.Tensor):
            return fn(x)
        elif hasattr(x, "__dataclass_fields__"):
            dc = dataclasses.replace(x)
            for f in dataclasses.fields(dc):
                name = f.name
                setattr(dc, name, apply(getattr(dc, name)))
            return dc
        elif isinstance(x, OrderedDict):
            od = x.__class__()
            for key, value in x.items():
                od[key] = apply(value)
            return od
        elif isinstance(x, PackedSequence):
            apply(x.data)
            return x
        elif isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        elif _is_namedtuple(x):
            res = (apply(el) for el in x)
            return type(x)(*res)
        elif isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        else:
            return x

    return apply(container)
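
# Example (illustrative sketch, not part of this module): applying a function to
# every tensor in a mixed container while preserving the container structure.
#
#     data = {"a": torch.ones(2), "b": [torch.zeros(3), "keep-me"]}
#     detached = _apply_to_tensors(lambda t: t.detach().clone(), data)
#     # detached has the same dict/list structure; the tensors are detached
#     # copies, and the non-tensor string is returned unchanged.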


def _to_kwargs(
    inputs: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    target_device: torch.device,
    use_side_stream_for_tensor_copies: bool,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
    moved_inputs = (
        _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies)
        if inputs
        else []
    )
    moved_kwargs = (
        _recursive_to(kwargs, target_device, use_side_stream_for_tensor_copies)
        if kwargs
        else []
    )
    if len(moved_inputs) < len(moved_kwargs):
        moved_inputs.extend([() for _ in range(len(moved_kwargs) - len(moved_inputs))])
    elif len(moved_kwargs) < len(moved_inputs):
        moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))])
    return tuple(moved_inputs), tuple(moved_kwargs)
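
# Example (illustrative sketch, not part of this module): how DDP-style forward
# dispatch can move positional and keyword arguments to the module's device in
# one call. The device "cuda:0" is assumed to exist, and ``module`` stands for
# any ``nn.Module`` already placed on that device.
#
#     args = (torch.randn(2, 3),)
#     kwargs = {"mask": torch.ones(2, 3, dtype=torch.bool)}
#     moved_args, moved_kwargs = _to_kwargs(args, kwargs, torch.device("cuda:0"), False)
#     output = module(*moved_args[0], **moved_kwargs[0])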


def _verify_param_shape_across_processes(
    process_group: dist.ProcessGroup,
    tensors: List[torch.Tensor],
    logger: Optional[dist.Logger] = None,
):
    return dist._verify_params_across_processes(process_group, tensors, logger)


def _sync_module_states(
    module: nn.Module,
    process_group: dist.ProcessGroup,
    broadcast_bucket_size: int,
    src: int,
    params_and_buffers_to_ignore: Container[str],
    broadcast_buffers: bool = True,
) -> None:
    """
    Sync ``module``'s parameters and buffers state.

    Syncs ``module``'s parameters and buffers so that all ranks contain the
    same module state. Note that this API assumes that all parameter shapes
    are consistent before running the synchronization. This can be checked
    with ``_verify_param_shape_across_processes``.
    """
    module_states: List[torch.Tensor] = []
    for name, param in module.named_parameters():
        if name not in params_and_buffers_to_ignore:
            module_states.append(param.detach())

    if broadcast_buffers:
        for name, buffer in module.named_buffers():
            if name not in params_and_buffers_to_ignore:
                module_states.append(buffer.detach())

    _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)
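
# Example (illustrative sketch, not part of this module): verifying shapes and
# then broadcasting the source rank's parameters and buffers across the default
# process group at wrapper-construction time, similar to what DDP does. It
# assumes ``dist.init_process_group`` has already been called; ``model`` is any
# ``nn.Module``, and the 250 MB bucket size is only illustrative.
#
#     pg = dist.group.WORLD
#     _verify_param_shape_across_processes(pg, list(model.parameters()))
#     _sync_module_states(
#         model,
#         process_group=pg,
#         broadcast_bucket_size=int(250 * 1024 * 1024),
#         src=0,
#         params_and_buffers_to_ignore=set(),
#     )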


def _sync_params_and_buffers(
    process_group: dist.ProcessGroup,
    module_states: List[torch.Tensor],
    broadcast_bucket_size: int,
    src: int,
) -> None:
    """Synchronize ``module_states`` (list of tensors) across all processes by broadcasting them from rank ``src``."""
    if len(module_states) > 0:
        dist._broadcast_coalesced(
            process_group, module_states, broadcast_bucket_size, src
        )


def _replace_by_prefix(
    state_dict: Dict[str, Any],
    old_prefix: str,
    new_prefix: str,
) -> None:
    """
    Replace all keys that match a given old_prefix with a new_prefix (in-place).

    Usage::

        state_dict = {"layer.xyz": torch.tensor(1)}
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
    """
    if old_prefix == new_prefix:
        raise ValueError("old_prefix and new_prefix must be distinct")
    for key in list(state_dict.keys()):
        if not key.startswith(old_prefix):
            continue
        new_key = new_prefix + key[len(old_prefix) :]
        state_dict[new_key] = state_dict[key]
        del state_dict[key]


def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
    return tensor.untyped_storage().data_ptr() > 0