import abc
import io
from dataclasses import dataclass
from enum import auto, Enum
from functools import reduce
from typing import Any, List, Optional, Tuple, Union

import torch

from .metadata import (
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    TensorProperties,
)


__all__ = [
    "WriteItemType",
    "LoadItemType",
    "TensorWriteData",
    "WriteItem",
    "ReadItem",
    "SavePlan",
    "LoadPlan",
    "SavePlanner",
    "LoadPlanner",
]


class WriteItemType(Enum):
    TENSOR = auto()
    SHARD = auto()
    BYTE_IO = auto()


class LoadItemType(Enum):
    TENSOR = auto()
    BYTE_IO = auto()


@dataclass(frozen=True)
class TensorWriteData:
    chunk: ChunkStorageMetadata
    properties: TensorProperties
    size: torch.Size


@dataclass(frozen=True)
class WriteItem:
    """Dataclass which holds information about what needs to be written to storage."""

    index: MetadataIndex
    type: WriteItemType

    # Value present if it's a tensor write
    tensor_data: Optional[TensorWriteData] = None

    def tensor_storage_size(self) -> Optional[int]:
        """

        Calculates the storage size of the underlying tensor, or None if this is not a tensor write.



        Returns:

            Optional[int] storage size, in bytes of underlying tensor if any.
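
        A worked example with assumed values (illustrative only, not produced by this module):

        >>> # xdoctest: +SKIP("illustrative values only")
        >>> # For a float32 tensor of size (4, 1024): numel * element size
        >>> 4 * 1024 * torch._utils._element_size(torch.float32)
        16384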

        """
        if self.tensor_data is None:
            return None

        numels = reduce(lambda x, y: x * y, self.tensor_data.size, 1)
        dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype)
        return numels * dtype_size


@dataclass(frozen=True)
class ReadItem:
    # Read Item
    type: LoadItemType

    # Index into the state_dict
    dest_index: MetadataIndex
    # Offsets into destination tensor
    dest_offsets: torch.Size

    # Index into the checkpoint
    storage_index: MetadataIndex
    # Offset into the checkpoint data
    storage_offsets: torch.Size

    # Size of the hypercube to copy
    lengths: torch.Size
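
    # Illustration with hypothetical values (not tied to any real checkpoint): a
    # ReadItem with storage_offsets=torch.Size([0, 512]), dest_offsets=torch.Size([0, 0])
    # and lengths=torch.Size([128, 512]) asks the storage layer to copy a 128x512
    # block starting at column 512 of the stored chunk into the top-left corner of
    # the destination tensor.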


@dataclass(frozen=True)
class SavePlan:
    items: List[WriteItem]
    storage_data: Any = None
    planner_data: Any = None


@dataclass
class LoadPlan:
    items: List[ReadItem]
    storage_data: Any = None
    planner_data: Any = None


class SavePlanner(abc.ABC):
    """

    Abstract class defining the protocol used by save_state_dict to plan the save process.



    SavePlanners are stateful objects that can be used to customize the whole save process.



    SavePlanner acts as an access proxy to the state_dict, so any transformation done to it

    will be visible to the whole process.



    A planner subclass can expect the following sequence of calls during save_state_dict:



    1) set_up_planner - called on all ranks.

        Signals the start of a checkpoint save.



    2) create_local_plan - called on all ranks.

        Process the state_dict and produces a `SavePlan` that will be sent for global planning.



    3) create_global_plan - called on the coordinator rank only.

        Takes the SavePlan from all ranks and make any global decision.



    4) finish_plan - called on all ranks.

        This gives each rank a chance to adjust to global planning decisions.



    5) resolve_data - called multiple times on each rank

        Lookups a value on the `state_dict` for the storage layer to write.



    Users are recommended to extend DefaultSavePlanner instead of this interface directly as

    most changes can be expressed by changes in a single method.



    There are 3 usual patterns of extension:



    Rewriting state_dict. This is the simplest way to extend the save process as it

    doesn't requite understanding the intrincacies of how SavePlan works:



    >>> # xdoctest: +SKIP("undefined vars")

    >>> class RenamePlanner(DefaultSavePlanner):

    >>>     def set_up_planner(self, state_dict, is_coordinator):

    >>>         # prefix all keys with `foo_``

    >>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)



    Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted



    >>> # xdoctest: +SKIP("undefined vars")

    >>> class FP16Planner(DefaultSavePlanner):

    >>>     def create_local_plan(self):

    >>>         plan = super().create_local_plan()

    >>>         for p in plan:

    >>>             if p.tensor_data is not None:

    >>>                 p.tensor_data.properties.dtype = torch.float16

    >>>         return plan

    >>>

    >>>     def resolve_data(self, write_item):

    >>>         item = super().resolve_data(write_item)

    >>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)



    Using the global planning step to make central decisions that can't be made individually by each rank



    >>> # xdoctest: +SKIP("undefined vars")

    >>> from itertools import islice

    >>> from dataclasses import replace

    >>> class DDPLoadBalancingPlanner(DefaultSavePlanner):

    >>>     # This uses the default local plan behavior of having all non-sharded writes in rank 0

    >>>     # This sample doesn't handle ShardedTensors

    >>>     def create_global_plan(self, all_plans):

    >>>         def chunk(it, size):

    >>>             it = iter(it)

    >>>         return list(iter(lambda: tuple(islice(it, size)), ()))

    >>>         all_plans = [

    >>>             replace(plan, items=items) for plan, items in

    >>>                 zip(all_plans, chunk(all_plans[0].items, len(all_plans)))

    >>>         ]

    >>>         return super().create_global_plan(all_plans)



    Finally, some planners need to save additional metadata in the checkpoint, this is

    accomplished by having each rank contribute their data items in the local plan and

    the global planner aggregate them:



    >>> # xdoctest: +SKIP("undefined vars")

    >>> class SaveExtraDataPlanner(DefaultSavePlanner):

    >>>     def create_local_plan(self) -> SavePlan:

    >>>         plan = super().create_local_plan()

    >>>         return replace(plan, planner_data="per-rank-data")

    >>>

    >>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:

    >>>         global_plan, metadata = super().create_global_plan(all_plans)

    >>>         merged_data = [p.planner_data for p in global_plan]

    >>>         metadata = replace(metadata, planner_data=merged_data)

    >>>         return global_plan, metadata

    """

    @abc.abstractmethod
    def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
        """

        Initialize this planner to save ``state_dict``.



        Implementations should save those values as they won't be provided lated in the save process.



        This is called on all ranks.

        """
        pass

    @abc.abstractmethod
    def create_local_plan(self) -> SavePlan:
        """

        Compute the save plan for the current rank.



        This will be aggregated and passed to create_global_plan.

        Planner specific data can be passed through SavePlan::planner_data.



        This is called on all ranks.

        """
        pass

    @abc.abstractmethod
    def create_global_plan(
        self, all_plans: List[SavePlan]
    ) -> Tuple[List[SavePlan], Metadata]:
        """
        Compute the global checkpoint plan and return the local plan of each rank.

        This is called on the coordinator rank only.
        """
        pass

    @abc.abstractmethod
    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
        """

        Merge the plan created by `create_local_plan` and the result of `create_global_plan`.



        This is called on all ranks.

        """
        pass

    @abc.abstractmethod
    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
        """

        Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety.



        Lookup the object associated with ``write_item`` in ``state_dict`` and apply any

        transformation (such as serialization) prior to the storage layer consuming it.



        Called on each rank multiple times, at least once per WriteItem in the final SavePlan.



        This method should be idempotent and thread-save. StorageWriter implementations

        are free to call it as frequently as they need.



        Any transformation that allocates memory should be lazily done when his method

        is called in order to reduce peak memory required by checkpointing.



        When returning tensors, they can be on any device or format, they can be views too.

        It's the storage layer responsibility to figure out how to save them.
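
        A minimal sketch of the lazy-transformation idea (illustrative only; it assumes
        a ``DefaultSavePlanner`` subclass that keeps the state_dict as ``self.state_dict``
        and that non-tensor values should be pickled with ``torch.save``):

        >>> # xdoctest: +SKIP("undefined vars")
        >>> class LazyBytesPlanner(DefaultSavePlanner):
        >>>     def resolve_data(self, write_item):
        >>>         if write_item.type == WriteItemType.BYTE_IO:
        >>>             # Serialize only when the storage layer asks for this item,
        >>>             # keeping peak memory low.
        >>>             buffer = io.BytesIO()
        >>>             torch.save(self.state_dict[write_item.index.fqn], buffer)
        >>>             return buffer
        >>>         return super().resolve_data(write_item)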

        """
        pass


class LoadPlanner(abc.ABC):
    """

    Abstract class defining the protocol used by load_state_dict to plan the load process.



    LoadPlanner are stateful objects that can be used to customize the whole load process.



    LoadPlanner acts as an access proxy to the state_dict, so any transformation done to it

    will be visible to the whole process.



    A planner subclass can expect the following sequence of calls during load_state_dict:



    1) set_up_planner - called on all ranks.

        Signals the start of loading a checkpoint.



    2) create_local_plan - called on all ranks.

        Process the state_dict and produces a `LoadPlan` that will be sent for global planning.



    3) create_global_plan - called on the coordinator rank only.

        Takes the LoadPlan from all ranks and make any global decision.



    4) load_bytes - called multiple times on each rank

        This is called once per non-tensor value in state_dict.



    5) resolve_tensor and commit_tensor - called multiple times on each rank

        They are called in pair for each Tensor value in state_dict.



    Users are recommended to extend DefaultLoadPlanner instead of this interface directly as

    most changes can be expressed by changes in a single method.



    There are two usual patterns of extension:



    Rewriting state_dict. This is the simplest way to extend the load process as it

    doesn't requite understanding the intrincacies of how LoadPlan works. We need

    to keep a reference to the original state_dict as load happens in place so

    we need to be able to perform it in place



    >>> # xdoctest: +SKIP("undefined vars")

    >>> class RenamePlanner(DefaultLoadPlanner):

    >>>     def set_up_planner(self, state_dict, metadata, is_coordinator):

    >>>         self.original_state_dict = state_dict

    >>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}

    >>>

    >>>         if self.flatten_sharded_tensors:

    >>>             state_dict = _flatten_sharded_tensors(state_dict)

    >>>

    >>>         if self.flatten_state_dict:

    >>>             state_dict, self.mappings = flatten_state_dict(state_dict)

    >>>

    >>>         self.state_dict = state_dict

    >>>         self.metadata = metadata

    >>>         self.is_coordinator = is_coordinator

    >>>

    >>>     def load_bytes(self, read_item, value):

    >>>         # Remove the "foo_" prefix

    >>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)





    Modifying resolve_tensor and commit_tensor to handle load time transformation.



    >>> # xdoctest: +SKIP("undefined vars")

    >>> class MetaModelMaterialize(DefaultSavePlanner):

    >>>     def resolve_tensor(self, read_item):

    >>>         tensor = super().resolve_tensor(read_item)

    >>>         return torch.empty_like(tensor, device="cpu")

    >>>

    >>>     def commit_tensor(self, read_item, tensor):

    >>>         self.state_dict[read_item.dest_index.fqn] = tensor

    """

    @abc.abstractmethod
    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        """
        Initialize this instance to load data into ``state_dict``.

        N.B. This is called on every rank.
        """
        pass

    @abc.abstractmethod
    def create_local_plan(self) -> LoadPlan:
        """

        Create a LoadPlan based on state_dict and metadata provided by set_up_planner.



        . N.B. This is called on every rank.

        """
        pass

    @abc.abstractmethod
    def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        """

        Compute the global load plan and return plans for each rank.



        . N.B. This is called on the coordinator rank only

        """
        pass

    @abc.abstractmethod
    def finish_plan(self, central_plan: LoadPlan) -> LoadPlan:
        """Accept the plan from coordinator and return final LoadPlan."""
        pass

    @abc.abstractmethod
    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
        """

        Load the item described by ``read_item``and ``value``.



        This method is expected to modify in-place the underlying state_dict.



        The contents of ``value`` are defined by the SavePlanner used to produce

        the checkpoint being loaded.
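
        A minimal sketch (illustrative only; it assumes the checkpoint's non-tensor
        values were serialized with ``torch.save`` and that ``self.state_dict`` was
        stored in ``set_up_planner``, mirroring the RenamePlanner example above):

        >>> # xdoctest: +SKIP("undefined vars")
        >>> def load_bytes(self, read_item, value):
        >>>     # Deserialize and write the object back into the state_dict in place.
        >>>     self.state_dict[read_item.dest_index.fqn] = torch.load(value)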

        """
        pass

    @abc.abstractmethod
    def resolve_tensor(self, read_item: ReadItem) -> torch.Tensor:
        """

        Return the tensor described by ``read_item`` to be used by the StorageReader to load `read_item`.



        The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents.

        If, for any reason, that's not possible, the planner can use the ``commit_tensor`` method to copy the data

        back to the one in state_dict.

        """
        pass

    @abc.abstractmethod
    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
        """

        Call once the StorageReader finished loading data into ``tensor``.



        The provided tensor is the same one returned by the call to ``resolve_tensor``.

        This method is only needed if this LoadPlanner needs to post process ``tensor`` prior to

        copying it back to the one in the state_dict.



        The contents of tensor will follow its device synchronization model.

        """
        pass