# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""The pipeline parallelism of Pipe."""
from queue import Queue
from types import TracebackType
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence

import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function

from .checkpoint import Checkpointing
from .copy import Copy, Wait
from .dependency import fork, join
from .microbatch import Batch
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .stream import AbstractStream, current_stream, use_device
from .worker import Task, create_workers

__all__: List[str] = ["Pipeline"]


Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]

# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
    InQueue = Queue[Optional["Task"]]
    OutQueue = Queue[Tuple[bool, Union[Tuple["Task", Batch], ExcInfo, None]]]
else:
    InQueue = Queue
    OutQueue = Queue


def _depend(fork_from: Batch, join_to: Batch) -> None:
    fork_from_idx = fork_from.find_tensor_idx()
    join_to_idx = join_to.find_tensor_idx()

    fork_from[fork_from_idx], phony = fork(fork_from[fork_from_idx])
    join_to[join_to_idx] = join(join_to[join_to_idx], phony)
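
# For example (illustrative): Pipeline.fence() below calls
# _depend(batches[i - 1], batches[i]) so that, during backpropagation, the
# gradient computation for micro-batch i completes before that of micro-batch
# i - 1; fork() and join() thread a "phony" tensor between the two batches to
# record this ordering without changing any forward values.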


def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    batch[:] = Copy.apply(prev_stream, next_stream, *batch)
    # Gradients are only supported for float Tensors.
    batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch])


def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    batch[:] = Wait.apply(prev_stream, next_stream, *batch)
    # Gradients are only supported for float Tensors.
    batch[:] = tuple([x.detach() if torch.is_tensor(x) and not x.is_floating_point() else x for x in batch])
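
# Note that Copy and Wait move every value of the micro-batch across streams,
# but only floating-point tensors keep their autograd history; anything else
# (for example, integer token ids) is detached, since PyTorch only supports
# gradients for floating-point tensors.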


def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
    """Generate schedules for each clock cycle."""
    # m: number of micro-batches
    # n: number of partitions
    # i: index of micro-batch
    # j: index of partition
    # k: clock number
    #
    # k (i,j) (i,j) (i,j)
    # - ----- ----- -----
    # 0 (0,0)
    # 1 (1,0) (0,1)
    # 2 (2,0) (1,1) (0,2)
    # 3       (2,1) (1,2)
    # 4             (2,2)
    for k in range(m + n - 1):
        yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))]
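
# For illustration (a sketch, not part of the public API): with m = 3
# micro-batches and n = 3 partitions, each clock cycle yields the
# (micro-batch, partition) pairs along one diagonal of the table above:
#
#     >>> list(_clock_cycles(3, 3))
#     [[(0, 0)], [(1, 0), (0, 1)], [(2, 0), (1, 1), (0, 2)], [(2, 1), (1, 2)], [(2, 2)]]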


class Pipeline:
    """The pipeline parallelism for Pipe."""

    def __init__(
        self,
        partitions: List[nn.Sequential],
        devices: List[torch.device],
        copy_streams: List[List[AbstractStream]],
        skip_layout: SkipLayout,
        checkpoint_stop: int,
    ) -> None:
        self.partitions = partitions
        self.devices = devices
        self.copy_streams = copy_streams
        self.skip_layout = skip_layout
        self.checkpoint_stop = checkpoint_stop
        (self.in_queues, self.out_queues) = create_workers(devices)
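
    # Worker queues: create_workers() sets up one worker per device;
    # in_queues[j] feeds Task objects to the worker for partition j, and
    # out_queues[j] yields either (True, (task, batch)) on success or
    # (False, exc_info) on failure, which is the protocol compute() relies on
    # below.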

    def run(self, batches: List[Batch]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.
        """
        partitions = self.partitions
        devices = self.devices
        skip_layout = self.skip_layout

        m = len(batches)
        n = len(partitions)
        skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]

        for schedule in _clock_cycles(m, n):
            self.fence(batches, schedule, skip_trackers)
            self.compute(batches, schedule, skip_trackers)
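
    # For example, with m = 8 micro-batches over n = 4 partitions, run() steps
    # through m + n - 1 = 11 clock cycles; each cycle first calls fence() to
    # copy the micro-batches that just finished a partition over to the next
    # device (recording explicit backward-order dependencies on the way), and
    # then compute() to launch that cycle's (micro-batch, partition) tasks on
    # the per-device workers.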

    def fence(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Copy micro-batches after computation for the previous micro-batches."""
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

        for i, j in schedule:
            # Ensure that batches[i-1] is executed after batches[i] in
            # backpropagation by an explicit dependency.
            if i != 0 and j != 0:
                _depend(batches[i - 1], batches[i])

            next_stream = copy_streams[j][i]

            for prev_j, ns, name in skip_layout.copy_policy(j):
                prev_stream = copy_streams[prev_j][i]
                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)

            if j != 0:
                prev_stream = copy_streams[j - 1][i]
                _copy(batches[i], prev_stream, next_stream)
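
    # For instance, in the clock cycle [(2, 0), (1, 1), (0, 2)] from the
    # _clock_cycles() example above, fence() copies the inputs of (1, 1) and
    # (0, 2) from copy_streams[j - 1][i] to copy_streams[j][i] because j != 0,
    # while (2, 0) stays on the first device; only (1, 1) has both i != 0 and
    # j != 0, so only it gets the explicit backward-order dependency via
    # _depend().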

    def compute(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Run tasks with synchronization to copy streams."""
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop

        # Disable checkpointing if in eval mode.
        if not self.partitions[0].training:
            checkpoint_stop = 0

        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None

        # With checkpointing, the autograd graph looks like this diagram:
        # ┌─────┸──────┐
        # │    Copy    │
        # └─────┰──────┘   (fence)
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        #       ┃          (compute)
        # ┌─────┸──────┐
        # │    Wait    │ [1] Synchronize the current stream with the copy stream.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │ Checkpoint │ [2] Compute a partition within checkpointing.
        # └─────┰──────┘
        # ┌─────┸──────┐
        # │    Wait    │ [3] Synchronize the copy stream with the current stream.
        # └─────┰──────┘
        #       ┠ ─ ─ ─ ┐
        #       ┃ ┌─────┴─────┐
        #       ┃ │ Recompute │ [4] Schedule the recomputation at backpropagation.
        #       ┃ └─────┬─────┘
        #       ┠ ─ ─ ─ ┘
        #       ┃
        # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─
        # ┌─────┸──────┐   (fence)
        # │    Copy    │
        # └─────┰──────┘
        for i, j in schedule:
            batch = batches[i]
            partition = partitions[j]

            # Synchronize with the copied input. ([1] in the diagram)
            if j != 0:
                _wait(batch, copy_streams[j][i], streams[j])

            # Determine whether checkpointing or not.
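            # Micro-batches with index below checkpoint_stop run under
            # activation checkpointing (e.g. Pipe's default 'except_last'
            # mode sets it to the number of chunks minus one); it was forced
            # to 0 above when the model is in eval mode.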
            checkpoint = i < checkpoint_stop
            if checkpoint:

                def function(
                    *inputs,
                    partition: nn.Module = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> TensorOrTensors:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return partition(*inputs)

                chk = Checkpointing(function, batch)  # type: ignore[arg-type]
                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
                del function, chk

            else:

                def compute(
                    batch: Batch = batch,
                    partition: nn.Module = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> Batch:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return batch.call(partition)

                task = Task(streams[j], compute=compute, finalize=None)
                del compute

            # Compute tasks in parallel. ([2] in the diagram)
            self.in_queues[j].put(task)

        for i, j in schedule:
            ok, payload = self.out_queues[j].get()

            # Hold the first exception.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue

            task, batch = cast(Tuple[Task, Batch], payload)

            # The copy stream synchronizes to copy the output. ([3] in the
            # diagram)
            if j != n - 1:
                _wait(batch, streams[j], copy_streams[j][i])

            # Finalize tasks. If checkpointing is enabled, here the
            # recomputation is scheduled at backpropagation. ([4] in the
            # diagram)
            with use_device(devices[j]):
                task.finalize(batch)

            batches[i] = batch

        # Fail at the first exception.
        if exc_info is not None:
            raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
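

# Note on compute()'s two phases (illustrative summary): the first loop over
# the schedule only builds Task objects and enqueues them, so all partitions
# scheduled for the current clock cycle run in parallel on their worker
# threads; the second loop drains out_queues in the same order, synchronizes
# the output copy streams, finalizes each task (scheduling recomputation when
# checkpointing is on), and re-raises the first exception, if any, only after
# every result for the cycle has been collected.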