Spaces:
Running
Running
File size: 9,634 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 |
import abc
import os
from dataclasses import dataclass
from typing import Any, List, Union
from torch.futures import Future
from .metadata import Metadata, MetadataIndex
from .planner import LoadPlan, LoadPlanner, SavePlan, SavePlanner
# Public API of this module: the write-result record and the two abstract
# storage interfaces (writer for saving, reader for loading).
__all__ = ["WriteResult", "StorageWriter", "StorageReader"]
@dataclass(frozen=True)
class WriteResult:
    """
    Immutable record describing the outcome of writing one item.

    A list of these is what ``StorageWriter.write_data`` completes its future
    with, and ``StorageWriter.finish`` receives one such list per rank.
    """

    # Index of the written item (a ``MetadataIndex``); presumably identifies
    # which state-dict entry this result belongs to -- confirm with planner.
    index: MetadataIndex
    # Size of the written payload in bytes (as reported by the storage layer).
    size_in_bytes: int
    # Storage-specific data; its meaning is defined by the StorageWriter
    # implementation that produced it.
    storage_data: Any
class StorageWriter(abc.ABC):
    """
    Interface used by ``save_state_dict`` to write to storage.

    One StorageWriter instance acts as both the coordinator and the follower
    in a distributed checkpoint. As part of initialization, each instance
    is told its role.

    A subclass should expect the following sequence of calls.

    0) (all ranks) set checkpoint_id (via ``reset``) if users pass a valid
       checkpoint_id.
    1) (all ranks) set_up_storage_writer()
    2) (all ranks) prepare_local_plan()
    3) (coordinator) prepare_global_plan()
    4) (all ranks) write_data()
    5) (coordinator) finish()
    """

    @abc.abstractmethod
    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        """
        Called to indicate that a brand new checkpoint write is going to happen.

        A checkpoint_id may be present if users set the checkpoint_id for
        this checkpoint write. The meaning of the checkpoint_id is
        storage-dependent. It can be a path to a folder/file or a key for
        a key-value storage.

        Args:
            checkpoint_id (Union[str, os.PathLike, None]):
                The ID of this checkpoint instance. The meaning of the checkpoint_id
                depends on the storage. It can be a path to a folder or to a file.
                It can also be a key if the storage is a key-value store.
                (Default: ``None``)
        """
        ...

    @abc.abstractmethod
    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        """
        Initialize this instance.

        Args:
            is_coordinator (bool): Whether this instance is responsible for coordinating
              the checkpoint.
        """
        pass

    @abc.abstractmethod
    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        """
        Perform storage-specific local planning.

        While this method can produce a completely different plan, the recommended
        way is to store storage-specific data in SavePlan::storage_data.

        Args:
            plan (SavePlan): The local plan from the ``SavePlanner`` in use.

        Returns:
            A transformed ``SavePlan`` after storage local planning.
        """
        pass

    @abc.abstractmethod
    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        """
        Perform centralized planning of storage.

        This method is only called on the coordinator instance.

        While this method can produce a completely different plan, the preferred
        way is to store storage-specific data in SavePlan::storage_data.

        Args:
            plans: A list of ``SavePlan`` instances, one for each rank.

        Returns:
            A list of transformed ``SavePlan`` after storage global planning.
        """
        pass

    @abc.abstractmethod
    def write_data(
        self, plan: SavePlan, planner: SavePlanner
    ) -> Future[List[WriteResult]]:
        """
        Write all items from ``plan`` using ``planner`` to resolve the data.

        A subclass should call ``SavePlanner::resolve_data`` on each item
        from the plan to get access to the underlying object to write.

        Subclasses should lazily call ``resolve_data`` as it can allocate memory.
        In the case of tensors, make the following assumptions:

        - They might be on any device, including one not matching ``WriteItem::tensor_data``.
        - They might be views or not contiguous. Only the projection needs to be saved.

        Args:
            plan (SavePlan): The save plan to execute.
            planner (SavePlanner): Planner object to be used to resolve items to data.

        Returns:
            A future that completes to a list of WriteResult.
        """
        pass

    @abc.abstractmethod
    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        """
        Write the metadata and mark the current checkpoint as successful.

        The actual format/schema used for serializing ``metadata`` is an
        implementation detail. The only requirement is that it's recoverable
        into the same object graph.

        Args:
            metadata (Metadata): Metadata for the new checkpoint.
            results: A list of WriteResults from all ranks.

        Returns:
            None
        """
        pass

    @classmethod
    @abc.abstractmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """
        Check if the given checkpoint_id is supported by the storage.

        This allows us to enable automatic storage selection.

        Args:
            checkpoint_id (Union[str, os.PathLike]): The checkpoint ID to check.

        Returns:
            ``True`` if this storage supports ``checkpoint_id``, else ``False``.
        """
        ...
class StorageReader(abc.ABC):
    """
    Interface used by ``load_state_dict`` to read from storage.

    One StorageReader instance acts as both the coordinator and the follower
    in a distributed checkpoint. As part of initialization, each instance
    is told its role.

    A subclass should expect the following sequence of calls by ``load_state_dict``:

    0) (all ranks) set checkpoint_id (via ``reset``) if users pass a valid
       checkpoint_id.
    1) (all ranks) read_metadata()
    2) (all ranks) set_up_storage_reader()
    3) (all ranks) prepare_local_plan()
    4) (coordinator) prepare_global_plan()
    5) (all ranks) read_data()
    """

    @abc.abstractmethod
    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        """
        Called to indicate that a brand new checkpoint read is going to happen.

        A checkpoint_id may be present if users set the checkpoint_id for
        this checkpoint read. The meaning of the checkpoint_id is
        storage-dependent. It can be a path to a folder/file or a key for
        a key-value storage.

        Args:
            checkpoint_id (Union[str, os.PathLike, None]):
                The ID of this checkpoint instance. The meaning of the checkpoint_id
                depends on the storage. It can be a path to a folder or to a file.
                It can also be a key if the storage is more like a key-value store.
                (Default: ``None``)
        """
        ...

    @abc.abstractmethod
    def read_metadata(self) -> Metadata:
        """
        Read the checkpoint metadata.

        Returns:
            The metadata object associated with the checkpoint being loaded.
        """
        pass

    @abc.abstractmethod
    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """
        Initialize this instance.

        Args:
            metadata (Metadata): The metadata schema to use.
            is_coordinator (bool): Whether this instance is responsible for coordinating
              the checkpoint.
        """
        pass

    @abc.abstractmethod
    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Perform storage-specific local planning.

        While this method can produce a completely different plan, the recommended
        way is to store storage-specific data in LoadPlan::storage_data.

        Args:
            plan (LoadPlan): The local plan from the ``LoadPlanner`` in use.

        Returns:
            A transformed ``LoadPlan`` after storage local planning.
        """
        pass

    @abc.abstractmethod
    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        """
        Perform centralized planning of storage loading.

        This method is only called on the coordinator instance.

        While this method can produce a completely different plan, the preferred
        way is to store storage-specific data in LoadPlan::storage_data.

        Args:
            plans: A list of ``LoadPlan`` instances, one for each rank.

        Returns:
            A list of transformed ``LoadPlan`` after storage global planning.
        """
        pass

    @abc.abstractmethod
    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Read all items from ``plan`` using ``planner`` to resolve the data.

        A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO
        object into the right place.

        A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the
        tensors that it should load data into.

        It is the storage layer's responsibility to properly schedule any
        cross-device copies required.

        Args:
            plan (LoadPlan): The local plan to execute on.
            planner (LoadPlanner): The planner object to use to resolve items.

        Returns:
            A future that completes once all reads are finished.
        """
        pass

    @classmethod
    @abc.abstractmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """
        Check if the given checkpoint_id is supported by the storage.

        This allows us to enable automatic storage selection.

        Args:
            checkpoint_id (Union[str, os.PathLike]): The checkpoint ID to check.

        Returns:
            ``True`` if this storage supports ``checkpoint_id``, else ``False``.
        """
        ...
|