Spaces:
Runtime error
Runtime error
File size: 8,055 Bytes
8a6cf24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from types import MethodType
from typing import Any, Dict, List, Optional, Tuple, Union
from .state import PartialState
from .utils import (
calculate_maximum_sizes,
convert_bytes,
copy_tensor_to_devices,
ignorant_find_batch_size,
infer_auto_device_map,
is_pippy_available,
pad_input_tensors,
send_to_device,
)
def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
    """
    Calculates the device map for `model` with an offset for PiPPy.

    With a single process this defers straight to `infer_auto_device_map`. Otherwise, when no `max_memory` is given,
    every process index `0 .. num_processes - 1` receives the same budget: `(model_size + shared[0]) / num_processes`
    converted by `convert_bytes`, rounded up to a whole unit and multiplied by 1.1 for headroom.

    Args:
        model (`torch.nn.Module`):
            The model to map onto devices.
        num_processes (`int`, defaults to 1):
            Number of processes / devices to spread the model across.
        no_split_module_classes (`List[str]`, *optional*):
            Class names of layers that must not be split across devices.
        max_memory (`dict`, *optional*):
            Explicit per-device memory budget. When given, it is passed through unchanged.

    Returns:
        The device map produced by `infer_auto_device_map` (called with `clean_result=False`).
    """
    if num_processes == 1:
        return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)

    if max_memory is None:
        total_size, shared = calculate_maximum_sizes(model)
        # Equal share per process of the model size plus `shared[0]` (what `calculate_maximum_sizes` reports as its
        # second value -- presumably the largest layer; TODO confirm), rendered as "<amount> <unit>".
        amount, unit = convert_bytes((total_size + shared[0]) / num_processes).split(" ")
        # Round up and add 10% to deal with potential extra shared memory instances
        per_process_budget = f"{math.ceil(float(amount)) * 1.1} {unit}"
        max_memory = dict.fromkeys(range(num_processes), per_process_budget)

    return infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
        clean_result=False,
    )
def find_pippy_batch_size(args, kwargs):
    """
    Returns the first batch size that `ignorant_find_batch_size` reports for the given inputs.

    Positional arguments are inspected first, in order. Keyword-argument values are only inspected if no positional
    argument yielded a batch size. Either container may be `None`.

    Returns:
        The first non-`None` result of `ignorant_find_batch_size`, or `None` if nothing matched.
    """

    def _candidates():
        # Lazily chain positional args, then keyword values, so `kwargs.values()` is only reached when needed.
        if args is not None:
            yield from args
        if kwargs is not None:
            yield from kwargs.values()

    for candidate in _candidates():
        batch_size = ignorant_find_batch_size(candidate)
        if batch_size is not None:
            return batch_size
    return None
def build_pipeline(model, split_points, args, kwargs, num_chunks):
    """
    Attaches `split_points` to `model` and generates the `PipelineStage` for the current process. Requires passing in
    the example `args` and `kwargs` the model needs, on the CPU, so that PiPPy can trace it.

    Args:
        model (`torch.nn.Module`):
            The model to split. It is annotated in place via `annotate_split_points`.
        split_points (iterable of `str`):
            Names of the submodules at whose beginning a new pipeline stage starts.
        args (`tuple` or `None`):
            Example positional inputs used for tracing.
        kwargs (`dict` or `None`):
            Example keyword inputs used for tracing.
        num_chunks (`int`):
            Passed to `Pipe.from_tracing`. The example inputs are padded with `pad_input_tensors` when their detected
            batch size differs from it.

    Returns:
        The `PipelineStage` built for `PartialState().local_process_index` on `PartialState().device`.
    """
    # Note: We import here to reduce import time from general modules, and isolate outside dependencies
    from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
    from pippy.PipelineStage import PipelineStage

    state = PartialState()
    # We need to annotate the split points in the model for PiPPy: each one starts a new stage.
    annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})

    found_batch_size = find_pippy_batch_size(args, kwargs)
    # Without a detectable batch size there is nothing to size the padding from, so trace the inputs as given
    # instead of handing `None` to `pad_input_tensors`.
    if found_batch_size is not None and found_batch_size != num_chunks:
        if args is not None:
            args = pad_input_tensors(args, found_batch_size, num_chunks)
        if kwargs is not None:
            kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)

    pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
    stage = PipelineStage(pipe, state.local_process_index, device=state.device)
    return stage
def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
    """
    Runs one pipeline-parallel forward call for the current process.

    - Single process: `forward(*args, **kwargs)` is called and its result kept.
    - Local main process: the inputs are padded with `pad_input_tensors` when their batch size differs from
      `num_chunks`, then fed to `forward`; the result is discarded.
    - Last process: `forward()` is called without inputs and its result kept.
    - Any other process: `forward()` is called without inputs and its result discarded.

    Args:
        forward (`Callable`):
            The pipeline stage's forward function.
        num_chunks (`int`):
            Batch size the inputs are compared against when deciding whether to pad.
        gather_output (`bool`):
            If `True`, the result is passed through `copy_tensor_to_devices` before being returned.
        *args, **kwargs:
            Model inputs, only used on the single-process path and on the local main process.

    Returns:
        The kept result (or `None` where it is discarded), optionally passed through `copy_tensor_to_devices`.

    Raises:
        ValueError: If the local main process cannot find a batch size in `args` / `kwargs`.
    """
    state = PartialState()

    if state.num_processes == 1:
        result = forward(*args, **kwargs)
    elif state.is_local_main_process:
        batch_size = find_pippy_batch_size(args, kwargs)
        if batch_size is None:
            raise ValueError("Could not find batch size from args or kwargs")
        needs_padding = batch_size != num_chunks
        if needs_padding:
            args = pad_input_tensors(args, batch_size, num_chunks)
            kwargs = pad_input_tensors(kwargs, batch_size, num_chunks)
        forward(*args, **kwargs)
        result = None
    elif state.is_last_process:
        result = forward()
    else:
        forward()
        result = None

    if gather_output:
        # Each node will get a copy of the full output which is only on the last GPU
        result = copy_tensor_to_devices(result)
    return result
def prepare_pippy(
    model,
    split_points: Optional[Union[str, List[str]]] = "auto",
    no_split_module_classes: Optional[List[str]] = None,
    example_args: Optional[Tuple[Any]] = (),
    example_kwargs: Optional[Dict[str, Any]] = None,
    num_chunks: Optional[int] = None,
    gather_output: Optional[bool] = False,
):
    """
    Wraps `model` for pipeline parallel inference.

    Args:
        model (`torch.nn.Module`):
            A model we want to split for pipeline-parallel inference
        split_points (`str` or `List[str]`, defaults to 'auto'):
            How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
            split given any model. Otherwise, a list of layer names in the model to split by (a single layer name
            given as a string is treated as a one-element list).
        no_split_module_classes (`List[str]`):
            A list of class names for layers we don't want to be split.
        example_args (tuple of model inputs):
            The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
        example_kwargs (dict of model inputs):
            The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
            that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
            is true for all cases.
        num_chunks (`int`, defaults to the number of processes):
            The number of different stages the Pipeline will have. By default it will assign one chunk per process
            (`PartialState().num_processes`), but this can be tuned and played with. In general one should have
            num_chunks >= num_gpus.
        gather_output (`bool`, defaults to `False`):
            If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.

    Returns:
        `model`, with `forward` replaced by a pipeline-parallel wrapper and the attributes `_original_forward`,
        `_original_call`, `pippy_stage` and `hf_split_points` set.

    Raises:
        ImportError: If `pippy` is not available.
        ValueError: If `split_points="auto"` and some stage index receives no module in the generated device map.
    """
    if not is_pippy_available():
        raise ImportError(
            "`pippy` was not found to be installed on your system. Please "
            "install using `pip install torchpippy` or ensure you have at least version 0.2.0"
        )
    state = PartialState()
    # Tracing happens on the CPU, so move the example inputs there first.
    example_args = send_to_device(example_args, "cpu")
    example_kwargs = send_to_device(example_kwargs, "cpu")
    if num_chunks is None:
        num_chunks = state.num_processes
    if split_points == "auto":
        device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
        split_points = []
        # Stage 0 starts at the model input; every later stage starts at the first module mapped to its index.
        for i in range(1, num_chunks):
            split_point = next((k for k, v in device_map.items() if v == i), None)
            if split_point is None:
                # A bare `next(...)` would leak an opaque StopIteration here.
                raise ValueError(
                    f"Could not find any module assigned to device {i} while generating split points automatically. "
                    "The model may be too small to be split into `num_chunks` stages; try passing `split_points` "
                    "explicitly or lowering `num_chunks`."
                )
            split_points.append(split_point)
    elif isinstance(split_points, str):
        # A single layer name: wrap it so it is not iterated character by character when annotating split points.
        split_points = [split_points]
    model.hf_split_points = split_points
    stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
    model._original_forward = model.forward
    model._original_call = model.__call__
    model.pippy_stage = stage

    def forward(*args, **kwargs):
        return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs)

    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
    # Note: creates an infinite recursion loop with `generate`
    model_forward = MethodType(forward, model)
    forward.__wrapped__ = model_forward
    model.forward = forward
    return model
|