Spaces:
Sleeping
Sleeping
from dataclasses import dataclass, field | |
from enum import Enum | |
from typing import Any, Dict, List, Optional, Sequence, Union | |
import torch | |
from torch.distributed.checkpoint.stateful import StatefulT | |
# Names exported by a star-import of this module; everything else
# (type aliases, the private memory-format enum) must be imported explicitly.
__all__ = [
    "ChunkStorageMetadata",
    "TensorStorageMetadata",
    "BytesStorageMetadata",
    "Metadata",
    "MetadataIndex",
    "TensorProperties",
]
@dataclass
class ChunkStorageMetadata:
    """
    Each chunk is expected to have the same properties of the TensorStorageMetadata
    that includes it.

    Declared as a dataclass so that the two annotated fields below become
    constructor arguments and take part in equality and ``repr``.
    """

    # Presumably the position of this chunk inside the full tensor, one entry
    # per dimension -- confirm against the code that builds the chunks.
    offsets: torch.Size
    # Presumably the extent of this chunk along each dimension -- confirm.
    sizes: torch.Size
class _MEM_FORMAT_ENCODING(Enum):
    """Integer codes naming the supported tensor memory formats."""

    # One member per memory format; the integer values are the encoded form
    # and must stay stable.
    TORCH_CONTIGUOUS_FORMAT = 0
    TORCH_CHANNELS_LAST = 1
    TORCH_PRESERVE_FORMAT = 2
@dataclass
class TensorProperties:
    """Properties used to create :class:`Tensor`.

    This must be a dataclass: the ``field(...)`` defaults below and the
    keyword construction in :meth:`create_from_tensor` both rely on the
    generated ``__init__``.

    Pickling is customised through ``__getstate__`` / ``__setstate__`` because
    ``torch.memory_format`` values cannot be pickled; they are replaced by
    ``_MEM_FORMAT_ENCODING`` members in the pickled state.
    """

    # Regular tensor fields
    dtype: torch.dtype = field(default_factory=torch.get_default_dtype)
    # This field is deprecated.
    layout: torch.layout = field(default=torch.strided)
    # This field is deprecated.
    requires_grad: bool = False
    # This field is deprecated.
    memory_format: torch.memory_format = field(default=torch.contiguous_format)
    # This field is deprecated.
    pin_memory: bool = False

    def __getstate__(self):
        """Return the picklable state as a 5-tuple.

        The tuple is ``(dtype, layout, requires_grad, encoded memory format,
        pin_memory)``; the order must match the unpacking in ``__setstate__``.

        Raises:
            RuntimeError: if ``memory_format`` is not one of contiguous,
                channels-last or preserve format.
        """
        # Since torch.memory_format cannot be pickled!
        memory_format = self.memory_format
        if memory_format == torch.contiguous_format:
            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT
        elif memory_format == torch.channels_last:
            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST
        elif memory_format == torch.preserve_format:
            mem_format_encoding = _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT
        else:
            raise RuntimeError(f"Invalid torch.memory_format: {memory_format}")

        return (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        )

    def __setstate__(
        self,
        state,
    ):
        """Restore the instance from the 5-tuple produced by ``__getstate__``.

        Raises:
            RuntimeError: if the encoded memory format is not a known
                ``_MEM_FORMAT_ENCODING`` member.
        """
        (
            self.dtype,
            self.layout,
            self.requires_grad,
            mem_format_encoding,
            self.pin_memory,
        ) = state

        # Decode the enum member back into the real torch.memory_format.
        if mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT:
            memory_format = torch.contiguous_format
        elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST:
            memory_format = torch.channels_last
        elif mem_format_encoding == _MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT:
            memory_format = torch.preserve_format
        else:
            raise RuntimeError(
                f"Invalid torch.memory_format encoding: {mem_format_encoding}"
            )

        self.memory_format = memory_format

    @staticmethod
    def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties":
        """Build a ``TensorProperties`` describing ``tensor``.

        ``dtype``, ``layout``, ``requires_grad`` and the pinned-memory flag are
        read from the tensor. ``memory_format`` is always recorded as
        ``torch.contiguous_format``, regardless of the tensor's actual layout
        in memory.
        """
        return TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        )
@dataclass
class TensorStorageMetadata:
    """Checkpoint metadata recorded for one tensor entry.

    Declared as a dataclass so the three annotated fields become constructor
    arguments and take part in equality and ``repr``.
    """

    # Creation properties (dtype, layout, ...) of the tensor.
    properties: TensorProperties
    # Presumably the size of the full (unsharded) tensor -- confirm.
    size: torch.Size
    # The chunks that make up the tensor.
    chunks: List[ChunkStorageMetadata]
@dataclass
class BytesStorageMetadata:
    """Field-less metadata type; the non-tensor member of ``STORAGE_TYPES``.

    It is a dataclass for consistency with ``TensorStorageMetadata``: having
    no fields, any two instances compare equal.
    """
# Metadata stored for a single state-dict key: tensor metadata or the
# bytes marker type.
STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]

# A state dict maps string keys to stateful objects or arbitrary values.
STATE_DICT_TYPE = Dict[str, Union[StatefulT, Any]]
@dataclass
class Metadata:
    """This class represents the metadata of the checkpoint.

    Declared as a dataclass so that ``state_dict_metadata`` is a required
    constructor argument while ``planner_data`` and ``storage_data`` default
    to ``None``.
    """

    # Keys are the same from the `state_dict` used.
    state_dict_metadata: Dict[str, STORAGE_TYPES]
    # It is the responsibility of the planner and storage plugins to ensure
    # backward compatibility of the planner_data and storage_data. DCP will
    # also ensure the backward compatibility of the metadata in this file and
    # the metadata of the built-in planner and storage plugins.
    planner_data: Any = None
    storage_data: Any = None
@dataclass(frozen=True)
class MetadataIndex:
    """This class represents a lookup key for items in a state dict or Metadata.

    Instances are immutable (``frozen=True``). Equality and hashing use
    ``fqn`` and ``offset`` only; ``index`` is excluded from both because it is
    merely a lookup hint.
    """

    fqn: str
    """Fully Qualified Name of the object"""

    offset: Optional[torch.Size] = None
    """If the object is a tensor, offset into the tensor we're looking for"""

    index: Optional[int] = field(hash=False, compare=False, default=None)
    """
    Index hint when searching for tensor chunk to speedup lookups (optional)

    A common representation of a sharded tensor is as a list of chunks so to
    find the index in such a list you need to linear search it.

    When constructing an instance of MetadataIndex that points to that list,
    one can provide the index as a hint and it will be probed first before
    the linear search and thus making it significantly faster.
    """

    def __init__(
        self,
        fqn: str,
        offset: Optional[Sequence[int]] = None,
        index: Optional[int] = None,
    ):
        """Create a lookup key.

        Args:
            fqn: Fully qualified name of the object.
            offset: Optional sequence of ints; when given it is converted to
                a ``torch.Size``.
            index: Optional position hint, ignored for equality and hashing.
        """
        # We must use object.__setattr__ due to frozen=True
        object.__setattr__(self, "fqn", fqn)
        object.__setattr__(self, "index", index)
        # When no offset is given, the attribute is left unset on the instance
        # and reads fall back to the class-level default of None.
        if offset is not None:
            object.__setattr__(self, "offset", torch.Size(offset))