# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from types import MethodType
from typing import Any, Dict, List, Optional, Tuple, Union

from .state import PartialState
from .utils import (
    calculate_maximum_sizes,
    convert_bytes,
    copy_tensor_to_devices,
    ignorant_find_batch_size,
    infer_auto_device_map,
    is_pippy_available,
    pad_input_tensors,
    send_to_device,
)


def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
    """
    Calculates the device map for `model` with an offset for PiPPy
    """
    if num_processes == 1:
        return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
    if max_memory is None:
        model_size, shared = calculate_maximum_sizes(model)

        # Split into `n` chunks for each GPU
        memory = (model_size + shared[0]) / num_processes
        memory = convert_bytes(memory)
        value, ending = memory.split(" ")

        # Add a chunk to deal with potential extra shared memory instances
        memory = math.ceil(float(value)) * 1.1
        memory = f"{memory} {ending}"
        max_memory = {i: memory for i in range(num_processes)}
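        # Illustrative example (rough numbers): a ~14 GB model split across 2 processes works out to
        # "7.0 GB" per process, padded above to roughly "7.7 GB" so extra copies of shared weights
        # (e.g. tied embeddings) still fit on each device.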
    device_map = infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
        clean_result=False,
    )
    return device_map


def find_pippy_batch_size(args, kwargs):
    found_batch_size = None
    if args is not None:
        for arg in args:
            found_batch_size = ignorant_find_batch_size(arg)
            if found_batch_size is not None:
                break
    if kwargs is not None and found_batch_size is None:
        for kwarg in kwargs.values():
            found_batch_size = ignorant_find_batch_size(kwarg)
            if found_batch_size is not None:
                break
    return found_batch_size


def build_pipeline(model, split_points, args, kwargs, num_chunks):
    """
    Attaches the split points to the model based on `split_points` and generates a `PipelineStage`. Requires passing
    in the expected `args` and `kwargs` the model needs, located on the CPU.

    Users can pass in a custom `num_chunks` as an optional hyper-parameter. By default it will use
    `AcceleratorState.num_processes`.
    """
    # Note: We import here to reduce import time from general modules, and isolate outside dependencies
    from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
    from pippy.PipelineStage import PipelineStage

    # We need to annotate the split points in the model for PiPPy
    state = PartialState()
    annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
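    # Illustrative example (hypothetical module names): for a 3-stage pipeline, `split_points` might be
    # ["model.layers.8", "model.layers.16"], each marking the first module of a new pipeline stage.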
    found_batch_size = find_pippy_batch_size(args, kwargs)
    if found_batch_size != num_chunks:
        if args is not None:
            args = pad_input_tensors(args, found_batch_size, num_chunks)
        if kwargs is not None:
            kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
    pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
    stage = PipelineStage(pipe, state.local_process_index, device=state.device)

    return stage


def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
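    """
    Runs `forward` through the pipeline. The local main process pads `args`/`kwargs` so they split evenly
    into `num_chunks` microbatches and feeds them in, the last process receives the final output, and
    `gather_output=True` copies that output to every device. With a single process, `forward` is simply
    called directly.
    """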
    state = PartialState()
    output = None
    if state.num_processes == 1:
        output = forward(*args, **kwargs)
    elif state.is_local_main_process:
        found_batch_size = find_pippy_batch_size(args, kwargs)
        if found_batch_size is None:
            raise ValueError("Could not find batch size from args or kwargs")
        else:
            if found_batch_size != num_chunks:
                args = pad_input_tensors(args, found_batch_size, num_chunks)
                kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
        forward(*args, **kwargs)
    elif state.is_last_process:
        output = forward()
    else:
        forward()
    if gather_output:
        # Each node will get a copy of the full output which is only on the last GPU
        output = copy_tensor_to_devices(output)
    return output


def prepare_pippy(
    model,
    split_points: Optional[Union[str, List[str]]] = "auto",
    no_split_module_classes: Optional[List[str]] = None,
    example_args: Optional[Tuple[Any]] = (),
    example_kwargs: Optional[Dict[str, Any]] = None,
    num_chunks: Optional[int] = None,
    gather_output: Optional[bool] = False,
):
    """
    Wraps `model` for pipeline parallel inference.
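
    Example (an illustrative sketch, assuming a 2-process `accelerate launch` run and a hypothetical `MyModel`):

    ```python
    >>> import torch

    >>> # assumes this module is importable as `accelerate.inference`
    >>> from accelerate.inference import prepare_pippy

    >>> model = MyModel()  # hypothetical `torch.nn.Module`
    >>> # a batch of 2 so that each of the 2 processes receives one microbatch
    >>> example_inputs = (torch.randint(0, 1000, (2, 128)),)
    >>> model = prepare_pippy(model, example_args=example_inputs, gather_output=True)
    >>> with torch.no_grad():
    ...     output = model(*example_inputs)
    ```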

    Args:
        model (`torch.nn.Module`):
            A model we want to split for pipeline-parallel inference.
        split_points (`str` or `List[str]`, defaults to 'auto'):
            How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
            split given any model. Otherwise it should be a list of layer names in the model to split by.
        no_split_module_classes (`List[str]`):
            A list of class names for layers we don't want to be split.
        example_args (tuple of model inputs):
            The expected inputs for the model that uses order-based inputs. It is recommended to use this method if
            possible.
        example_kwargs (dict of model inputs):
            The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
            that requires the same keys to be present at *all* inference calls. Not recommended unless the prior
            condition is true for all cases.
        num_chunks (`int`, defaults to the number of available GPUs):
            The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
            this can be tuned and played with. In general one should have `num_chunks >= num_gpus`.
        gather_output (`bool`, defaults to `False`):
            If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
    """
    if not is_pippy_available():
        raise ImportError(
            "`pippy` was not found to be installed on your system. Please "
            "install using `pip install torchpippy` or ensure you have at least version 0.2.0"
        )
    state = PartialState()
    example_args = send_to_device(example_args, "cpu")
    example_kwargs = send_to_device(example_kwargs, "cpu")
    if num_chunks is None:
        num_chunks = state.num_processes
    if split_points == "auto":
        device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
        split_points = []
        for i in range(1, num_chunks):
            split_points.append(next(k for k, v in device_map.items() if v == i))
    model.hf_split_points = split_points
    stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
    model._original_forward = model.forward
    model._original_call = model.__call__
    model.pippy_stage = stage
    model.hf_split_points = split_points

    def forward(*args, **kwargs):
        return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs)

    # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
    # Note: creates an infinite recursion loop with `generate`
    model_forward = MethodType(forward, model)
    forward.__wrapped__ = model_forward
    model.forward = forward
    return model