# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Manipulation of micro-batches."""
import typing
from typing import Any, Callable, List, Union, cast, Sequence

import torch
from torch import Tensor
import torch.cuda.comm

__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"]


Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]
Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]]


class NoChunk:
    """

    Wrapper for a Tensor in :meth:`Pipe.forward` indicating that the tensor

    should not be chunked on the batch dimension and instead be replicated

    as-is across all micro-batches. This is useful for tensors which might

    not have any 'batch' semantics for the model.

    """
    def __init__(self, inp: Tensor):
        if not torch.is_tensor(inp):
            raise TypeError(f'NoChunk only supported for tensors, found: {inp}')
        self._tensor = inp

    @property
    def tensor(self):
        return self._tensor
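
# Illustrative usage sketch (not part of the original module; shapes and names
# are assumptions). Wrapping a tensor in NoChunk makes ``scatter`` replicate it
# to every micro-batch instead of splitting it along dim 0:
#
#   mask = torch.ones(10)                               # no batch dimension
#   batches = scatter(torch.randn(8, 4), NoChunk(mask), chunks=4)
#   # each Batch holds a (2, 4) chunk of the input plus the full ``mask``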


class Batch:
    """

    An abstraction representing a microbatch in the pipeline.

    """

    def __init__(self, values: Union[List[Any], Tensor]) -> None:
        self._values = values
        self.atomic = torch.is_tensor(values)

        # Verify there is at least one tensor
        if not self.atomic:
            if not any(torch.is_tensor(value) for value in self._values):
                raise TypeError(f'No tensors found in batch: {self._values}')

    @property
    def tensor(self) -> Tensor:
        """Retrieves the underlying tensor."""
        if not self.atomic:
            raise AttributeError("not atomic batch")
        return cast(Tensor, self._values)

    @property
    def values(self):
        """Retrieves the underlying values for the batch"""
        return self._values

    def find_tensor_idx(self):
        """

        Retrieves the index of first tensor found.

        """
        if self.atomic:
            return 0
        for i, value in enumerate(self._values):
            if torch.is_tensor(value):
                return i

        raise TypeError("No tensor found!")

    def get_device(self):
        """

        Retrieves the device for this microbatch.

        """
        if self.atomic:
            return self._values.device  # type: ignore[union-attr]

        for value in self._values:
            if torch.is_tensor(value):
                return value.device

    def call(self, function: Function) -> "Batch":
        """Calls a function on the microbatch. It also wraps

        the output with :class:`Batch`.

        """
        if self.atomic:
            return Batch(function(self._values))
        else:
            return Batch(function(*self._values))

    def __repr__(self) -> str:
        return f"Batch[atomic={self.atomic!r}]({self._values!r})"

    def __iter__(self):
        if self.atomic:
            yield self._values
        else:
            yield from self._values

    def __len__(self) -> int:
        return 1 if self.atomic else len(self._values)

    def __getitem__(self, index: int):
        if not self.atomic:
            return self._values[index]

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        return self._values

    # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
    @typing.overload
    def __setitem__(self, index: int, value: Tensor) -> None:
        ...

    @typing.overload
    def __setitem__(self, index: slice, value: Tensors) -> None:
        ...

    def __setitem__(self, index: Union[int, slice], value) -> None:
        if isinstance(index, int):
            self._setitem_by_index(index, value)
        else:
            self._setitem_by_slice(index, value)

    def _setitem_by_index(self, index: int, value) -> None:
        if not self.atomic:
            i = index
            self._values = self._values[:i] + (value,) + self._values[i + 1 :]  # type: ignore[operator]
            return

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        self._values = value

    def _setitem_by_slice(self, index: slice, value) -> None:
        if not (index.start is index.stop is index.step is None):  # noqa: E714
            raise NotImplementedError("only slice [:] supported")

        if not self.atomic:
            self._values = value
            return

        if len(value) != 1:
            raise IndexError("atomic batch cannot be replaced with multiple tensors")

        self._values = value[0]
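

# Illustrative sketch of Batch semantics (assumed example values, not from the
# original module): an atomic Batch wraps a single tensor, while a non-atomic
# Batch wraps a list that may mix tensors and plain Python values.
#
#   b = Batch(torch.zeros(2, 3))       # atomic: len(b) == 1, b.tensor is valid
#   b = Batch([torch.zeros(2, 3), 7])  # non-atomic: len(b) == 2, b[1] == 7
#   b = b.call(lambda t, n: [t + n])   # call() unpacks the values and rewraps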


def check(first_device, *inputs) -> None:
    """

    Checks whether the input contains at least one tensor and each tensor is

    on the same device as the first partition.



    Raises:

        ValueError: input does not contain at least one tensor



    """

    if not any(torch.is_tensor(input) for input in inputs):
        raise TypeError(f'inputs do not have any tensors: {inputs}')
    if any(torch.is_tensor(input) and input.device != first_device for input in inputs):
        raise ValueError('All inputs should be on the same device as the first partition')
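
# Illustrative usage (the device string and tensor names are assumptions, not
# from the original module):
#
#   check(torch.device('cuda:0'), x, y)   # TypeError if no tensors,
#                                         # ValueError on a device mismatch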


def scatter(*inputs, chunks: int) -> List[Batch]:
    """Splits an input mini-batch into multiple micro-batches."""
    if len(inputs) == 1 and isinstance(inputs[0], Tensor):
        return [Batch(x) for x in inputs[0].chunk(chunks)]

    batches: List[Any] = [[] for _ in range(chunks)]
    # Actual number of chunks produced
    num_chunks = -1
    for input in inputs:
        if torch.is_tensor(input):
            # Chunk only tensors.
            tensors = input.chunk(chunks)

            # Validate number of chunks equal across all inputs.
            if num_chunks != -1 and num_chunks != len(tensors):
                raise RuntimeError(f'Found different number of chunks produced for inputs: {num_chunks} and {len(tensors)}')
            num_chunks = len(tensors)

            for i, tensor in enumerate(tensors):
                batches[i].append(tensor)
        else:
            # Replicate non-tensors or tensors wrapped with 'NoChunk'.
            for i in range(chunks):
                if isinstance(input, NoChunk):
                    # Extract the tensor out.
                    batches[i].append(input.tensor)
                else:
                    batches[i].append(input)

    # Truncate to actual number of chunks
    batches = batches[:num_chunks]

    return [Batch(x) for x in batches]
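
# Illustrative usage sketch (shapes are assumptions): splitting a mini-batch of
# 8 rows into 4 micro-batches of 2 rows each.
#
#   batches = scatter(torch.randn(8, 4), chunks=4)
#   assert len(batches) == 4 and batches[0].tensor.shape == (2, 4)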


def gather(outputs: List[Batch]):
    """Concatenates output micro-batches into a mini-batch."""
    output: Any

    if outputs[0].atomic:
        tensors = tuple(b.tensor for b in outputs)
        output = torch.cat(tensors)
    else:
        output_buf: List[Any] = []
        for i in range(len(outputs[0])):
            output_type = type(outputs[0][i])
            current_outputs = []
            for batch in outputs:
                if output_type != type(batch[i]):
                    raise TypeError(f'Types for microbatch outputs do not match, found: {output_type} and {type(batch[i])}')
                current_outputs.append(batch[i])

            if torch.is_tensor(outputs[0][i]):
                output_buf.append(torch.cat(current_outputs))
            else:
                output_buf.append(current_outputs)

        output = tuple(output_buf)

    return output
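
# Illustrative round trip (assumed shapes): ``gather`` undoes ``scatter`` by
# concatenating the per-micro-batch tensors back along dim 0.
#
#   x = torch.arange(8.0).reshape(8, 1)
#   assert torch.equal(gather(scatter(x, chunks=4)), x)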