File size: 8,055 Bytes
8a6cf24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from types import MethodType
from typing import Any, Dict, List, Optional, Tuple, Union

from .state import PartialState
from .utils import (
    calculate_maximum_sizes,
    convert_bytes,
    copy_tensor_to_devices,
    ignorant_find_batch_size,
    infer_auto_device_map,
    is_pippy_available,
    pad_input_tensors,
    send_to_device,
)


def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
    """
    Calculates the device map for `model` with an offset for PiPPy.

    Args:
        model (`torch.nn.Module`):
            The model to place across devices.
        num_processes (`int`, defaults to 1):
            The number of devices to split the model across.
        no_split_module_classes (`List[str]`, *optional*):
            Class names of layers that must not be split across devices.
        max_memory (`dict`, *optional*):
            Mapping of device index to maximum memory. It is always forwarded to `infer_auto_device_map` when given.
            When omitted and `num_processes > 1`, an even per-device budget is derived from the sizes reported by
            `calculate_maximum_sizes`, rounded up and padded by 10%.

    Returns:
        The device map returned by `infer_auto_device_map` (called with `clean_result=False`).
    """
    if num_processes == 1:
        # Nothing to balance across devices, but a caller-supplied `max_memory` must still be honored
        # (previously it was silently dropped on this path).
        return infer_auto_device_map(
            model,
            max_memory=max_memory,
            no_split_module_classes=no_split_module_classes,
            clean_result=False,
        )
    if max_memory is None:
        model_size, shared = calculate_maximum_sizes(model)

        # Split into `n` chunks, one per device
        memory = (model_size + shared[0]) / num_processes
        memory = convert_bytes(memory)
        # `convert_bytes` yields "<value> <unit>"; round the value up in that same unit
        value, ending = memory.split(" ")

        # Add a chunk to deal with potential extra shared memory instances
        memory = math.ceil(float(value)) * 1.1
        memory = f"{memory} {ending}"
        max_memory = {i: memory for i in range(num_processes)}
    device_map = infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
        clean_result=False,
    )
    return device_map


def find_pippy_batch_size(args, kwargs):
    """
    Returns the first batch size that can be inferred from the given inputs, or `None` if there is none.

    Positional inputs in `args` are inspected first, in order, followed by the values of `kwargs`. Inspection stops
    at the first input for which `ignorant_find_batch_size` reports a batch size. Either argument may be `None`.
    """

    def _candidates():
        # Lazily yield positional inputs, then keyword inputs, so nothing past the first hit is inspected.
        if args is not None:
            yield from args
        if kwargs is not None:
            yield from kwargs.values()

    for candidate in _candidates():
        batch_size = ignorant_find_batch_size(candidate)
        if batch_size is not None:
            return batch_size
    return None


def build_pipeline(model, split_points, args, kwargs, num_chunks):
    """
    Annotates `model` at `split_points` and traces it into a PiPPy `PipelineStage` for the current process.

    The example `args` / `kwargs` are the inputs used for tracing; callers are expected to have moved them to the
    CPU already. When the batch size found in them differs from `num_chunks`, they are padded with
    `pad_input_tensors` before tracing.

    Args:
        model (`torch.nn.Module`):
            The model to split.
        split_points (iterable of `str`):
            Names of the submodules at which a new pipeline stage begins.
        args (`tuple`, *optional*):
            Example positional inputs, or `None`.
        kwargs (`dict`, *optional*):
            Example keyword inputs, or `None`.
        num_chunks (`int`):
            Forwarded to `Pipe.from_tracing`; also the target the example batch is padded against.

    Returns:
        The `PipelineStage` for `PartialState().local_process_index`, placed on `PartialState().device`.
    """
    # Imported lazily to keep general import time low and isolate the optional `pippy` dependency.
    from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
    from pippy.PipelineStage import PipelineStage

    state = PartialState()

    # Every split point marks the BEGINNING of a new stage.
    split_spec = dict.fromkeys(split_points, PipeSplitWrapper.SplitPoint.BEGINNING)
    annotate_split_points(model, split_spec)

    batch_size = find_pippy_batch_size(args, kwargs)
    if batch_size != num_chunks:
        # Pad whichever example inputs were provided so their batch dimension fits `num_chunks`.
        args = args if args is None else pad_input_tensors(args, batch_size, num_chunks)
        kwargs = kwargs if kwargs is None else pad_input_tensors(kwargs, batch_size, num_chunks)

    pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
    return PipelineStage(pipe, state.local_process_index, device=state.device)


def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
    """
    Runs one pipeline-parallel forward pass through `forward` on the current process.

    Behavior depends on the process role reported by `PartialState`:

    - single process: `forward(*args, **kwargs)` is called and its result kept;
    - local main process: the inputs (padded against `num_chunks` when the batch size differs) are fed to
      `forward`, and nothing is kept;
    - last process: `forward()` is called without inputs and its result kept;
    - any other process: `forward()` is called without inputs and nothing is kept.

    If `gather_output` is true, the kept result (or `None`) is passed through `copy_tensor_to_devices` on every
    process before being returned.

    Raises:
        ValueError: on the local main process, when no batch size can be found in the inputs.
    """
    state = PartialState()

    if state.num_processes == 1:
        result = forward(*args, **kwargs)
    elif state.is_local_main_process:
        batch_size = find_pippy_batch_size(args, kwargs)
        if batch_size is None:
            raise ValueError("Could not find batch size from args or kwargs")
        if batch_size != num_chunks:
            args = pad_input_tensors(args, batch_size, num_chunks)
            kwargs = pad_input_tensors(kwargs, batch_size, num_chunks)
        forward(*args, **kwargs)
        result = None
    else:
        # Non-first stages are invoked without inputs; only the last stage's return value is kept.
        keep_output = state.is_last_process
        stage_output = forward()
        result = stage_output if keep_output else None

    if gather_output:
        # Each process gets a copy of the full output, which otherwise only lives on the last process.
        result = copy_tensor_to_devices(result)
    return result


def prepare_pippy(
    model,
    split_points: Optional[Union[str, List[str]]] = "auto",
    no_split_module_classes: Optional[List[str]] = None,
    example_args: Optional[Tuple[Any]] = (),
    example_kwargs: Optional[Dict[str, Any]] = None,
    num_chunks: Optional[int] = None,
    gather_output: Optional[bool] = False,
):
    """
    Wraps `model` for pipeline parallel inference.

    Args:
        model (`torch.nn.Module`):
            A model we want to split for pipeline-parallel inference.
        split_points (`str` or `List[str]`, defaults to 'auto'):
            How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
            split given any model. Otherwise, a single layer name or a list of layer names in the model to split by.
        no_split_module_classes (`List[str]`):
            A list of class names for layers we don't want to be split.
        example_args (tuple of model inputs):
            The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
        example_kwargs (dict of model inputs):
            The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
            that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
            is true for all cases.
        num_chunks (`int`, defaults to the number of available GPUs):
            The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
            this can be tuned and played with. In general one should have num_chunks >= num_gpus. With
            `split_points='auto'`, `num_chunks - 1` split points are generated.
        gather_output (`bool`, defaults to `False`):
            If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.

    Returns:
        `model`, with `forward` replaced by a pipeline-parallel version. The original `forward` / `__call__` are kept
        on `model._original_forward` / `model._original_call`, the stage on `model.pippy_stage`, and the split points
        on `model.hf_split_points`.

    Raises:
        ImportError: if `pippy` is not installed.
        ValueError: if `split_points='auto'` and the generated device map assigns no module to one of the devices.
    """
    if not is_pippy_available():
        raise ImportError(
            "`pippy` was not found to be installed on your system. Please "
            "install using `pip install torchpippy` or ensure you have at least version 0.2.0"
        )
    state = PartialState()
    example_args = send_to_device(example_args, "cpu")
    example_kwargs = send_to_device(example_kwargs, "cpu")
    if num_chunks is None:
        num_chunks = state.num_processes
    if split_points == "auto":
        device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
        split_points = []
        for i in range(1, num_chunks):
            # The first entry of the device map assigned to device `i` is where stage `i` begins.
            first_module = next((name for name, device in device_map.items() if device == i), None)
            if first_module is None:
                # Previously a bare `StopIteration` escaped here, which is hard to diagnose.
                raise ValueError(
                    f"Could not find any module assigned to device {i} while generating split points for "
                    f"{num_chunks} chunks. The model may be too small to split this way; pass explicit "
                    "`split_points` or use a smaller `num_chunks`."
                )
            split_points.append(first_module)
    elif isinstance(split_points, str):
        # A single layer name: wrap it, otherwise it would be iterated character by character downstream.
        split_points = [split_points]
    model.hf_split_points = split_points
    stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
    model._original_forward = model.forward
    model._original_call = model.__call__
    model.pippy_stage = stage

    def forward(*args, **kwargs):
        return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs)

    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
    # Note: creates an infinite recursion loop with `generate`
    model_forward = MethodType(forward, model)
    forward.__wrapped__ = model_forward
    model.forward = forward
    return model