# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""Manipulation of micro-batches."""
import typing
from typing import Any, Callable, List, Sequence, Union, cast

import torch
import torch.cuda.comm
from torch import Tensor

__all__: List[str] = ["NoChunk", "Batch", "check", "scatter", "gather"]

Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]
Function = Callable[[TensorOrTensors], Union[List[Any], Tensor]]


class NoChunk:
    """
    Wrapper for a Tensor in :meth:`Pipe.forward` indicating that the tensor
    should not be chunked on the batch dimension and instead be replicated
    as-is across all micro-batches. This is useful for tensors which might
    not have any 'batch' semantics for the model.
    """
    def __init__(self, inp: Tensor):
        if not torch.is_tensor(inp):
            raise TypeError(f'NoChunk only supported for tensors, found: {inp}')
        self._tensor = inp

    # Exposed as a property so callers can read the wrapped tensor as
    # `nochunk.tensor` (see `scatter` below).
    @property
    def tensor(self):
        return self._tensor
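

# A minimal usage sketch (illustrative; not part of the original module):
# wrapping a tensor with NoChunk makes `scatter` (defined below) replicate it
# across every micro-batch instead of splitting it along the batch dimension.
#
#     pos_mask = torch.ones(10)                        # no batch dimension
#     batches = scatter(torch.randn(8, 10), NoChunk(pos_mask), chunks=4)
#     assert all(batch[1] is pos_mask for batch in batches)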


class Batch:
    """
    An abstraction representing a microbatch in the pipeline.
    """
    def __init__(self, values: Union[List[Any], Tensor]) -> None:
        self._values = values
        self.atomic = torch.is_tensor(values)

        # Verify there is at least one tensor
        if not self.atomic:
            if not any(torch.is_tensor(value) for value in self._values):
                raise TypeError(f'No tensors found in batch: {self._values}')

    @property
    def tensor(self) -> Tensor:
        """Retrieves the underlying tensor of an atomic batch."""
        if not self.atomic:
            raise AttributeError("not atomic batch")
        return cast(Tensor, self._values)

    @property
    def values(self):
        """Retrieves the underlying values for the batch."""
        return self._values

    def find_tensor_idx(self):
        """
        Retrieves the index of the first tensor found.
        """
        if self.atomic:
            return 0
        for i, value in enumerate(self._values):
            if torch.is_tensor(value):
                return i

        raise TypeError("No tensor found!")

    def get_device(self):
        """
        Retrieves the device for this microbatch.
        """
        if self.atomic:
            return self._values.device  # type: ignore[union-attr]

        for value in self._values:
            if torch.is_tensor(value):
                return value.device

    def call(self, function: Function) -> "Batch":
        """Calls a function on the microbatch. It also wraps
        the output with :class:`Batch`.
        """
        if self.atomic:
            return Batch(function(self._values))
        else:
            return Batch(function(*self._values))

    def __repr__(self) -> str:
        return f"Batch[atomic={self.atomic!r}]({self._values!r})"

    def __iter__(self):
        if self.atomic:
            yield self._values
        else:
            yield from self._values

    def __len__(self) -> int:
        return 1 if self.atomic else len(self._values)

    def __getitem__(self, index: int):
        if not self.atomic:
            return self._values[index]

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")

        return self._values

    # NOTE(sublee): pyflakes can't detect "overload" instead of "typing.overload".
    @typing.overload
    def __setitem__(self, index: int, value: Tensor) -> None:
        ...

    @typing.overload
    def __setitem__(self, index: slice, value: Tensors) -> None:
        ...

    def __setitem__(self, index: Union[int, slice], value) -> None:
        if isinstance(index, int):
            self._setitem_by_index(index, value)
        else:
            self._setitem_by_slice(index, value)

    def _setitem_by_index(self, index: int, value) -> None:
        if not self.atomic:
            i = index
            self._values = self._values[:i] + (value,) + self._values[i + 1 :]  # type: ignore[operator]
            return

        if index != 0:
            raise IndexError("atomic batch allows index 0 only")
        self._values = value

    def _setitem_by_slice(self, index: slice, value) -> None:
        if not (index.start is index.stop is index.step is None):  # noqa: E714
            raise NotImplementedError("only slice [:] supported")

        if not self.atomic:
            self._values = value
            return

        if len(value) != 1:
            raise IndexError("atomic batch cannot be replaced with multiple tensors")
        self._values = value[0]
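

# Illustrative sketch (assumed usage; not from the original module): a
# non-atomic Batch behaves like a small sequence whose tensors can be
# rewritten in place, while an atomic Batch wraps exactly one tensor. Note
# the tuple payload: `_setitem_by_index` rebuilds `_values` by tuple
# concatenation.
#
#     b = Batch((torch.zeros(2), 'meta'))
#     assert not b.atomic and len(b) == 2 and b.find_tensor_idx() == 0
#     b[0] = torch.ones(2)               # single-index rewrite
#     b[:] = (torch.ones(2), 'meta')     # full-slice rewrite
#
#     a = Batch(torch.zeros(2))
#     assert a.atomic and len(a) == 1 and a[0] is a.tensor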


def check(first_device, *inputs) -> None:
    """
    Checks whether the input contains at least one tensor and each tensor is
    on the same device as the first partition.

    Raises:
        TypeError: input does not contain at least one tensor
        ValueError: a tensor is not on the same device as the first partition
    """
    if not any(torch.is_tensor(input) for input in inputs):
        raise TypeError(f'inputs do not have any tensors: {inputs}')
    if any(torch.is_tensor(input) and input.device != first_device for input in inputs):
        raise ValueError('All inputs should be on the same device as the first partition')
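

# Hedged example of the guard above (illustrative): all tensor arguments must
# live on the device of the first partition, here assumed to be CPU.
#
#     check(torch.device('cpu'), torch.zeros(2), 'meta')    # passes
#     check(torch.device('cpu'), 'meta')                    # raises TypeError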


def scatter(*inputs, chunks: int) -> List[Batch]:
    """Splits an input mini-batch into multiple micro-batches."""
    if len(inputs) == 1 and isinstance(inputs[0], Tensor):
        return [Batch(x) for x in inputs[0].chunk(chunks)]

    batches: List[Any] = [[] for _ in range(chunks)]
    # Actual number of chunks produced
    num_chunks = -1
    for input in inputs:
        if torch.is_tensor(input):
            # Chunk only tensors.
            tensors = input.chunk(chunks)

            # Validate number of chunks equal across all inputs.
            if num_chunks != -1 and num_chunks != len(tensors):
                raise RuntimeError(
                    f'Found different number of chunks produced for inputs: '
                    f'{num_chunks} and {len(tensors)}')
            num_chunks = len(tensors)

            for i, tensor in enumerate(tensors):
                batches[i].append(tensor)
        else:
            # Replicate non-tensors or tensors wrapped with 'NoChunk'.
            for i in range(chunks):
                if isinstance(input, NoChunk):
                    # Extract the tensor out.
                    batches[i].append(input.tensor)
                else:
                    batches[i].append(input)

    # Truncate to actual number of chunks
    batches = batches[:num_chunks]
    return [Batch(x) for x in batches]
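

# Brief sketch of the scatter contract (illustrative): tensor arguments are
# split along dim 0, everything else is replicated, and each micro-batch
# preserves the original argument order.
#
#     batches = scatter(torch.arange(8).reshape(4, 2), 'tag', chunks=2)
#     assert len(batches) == 2
#     assert batches[0][0].shape == (2, 2) and batches[0][1] == 'tag'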


def gather(outputs: List[Batch]):
    """Concatenates output micro-batches into a mini-batch."""
    output: Any

    if outputs[0].atomic:
        tensors = tuple(b.tensor for b in outputs)
        output = torch.cat(tensors)
    else:
        output_buf: List[Any] = []
        for i in range(len(outputs[0])):
            output_type = type(outputs[0][i])
            current_outputs = []
            for batch in outputs:
                if output_type != type(batch[i]):
                    raise TypeError(
                        f'Types for microbatch outputs do not match, found: '
                        f'{output_type} and {type(batch[i])}')
                current_outputs.append(batch[i])

            if torch.is_tensor(outputs[0][i]):
                output_buf.append(torch.cat(current_outputs))
            else:
                output_buf.append(current_outputs)

        output = tuple(output_buf)

    return output
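

# Round-trip sketch (illustrative): for tensor inputs, `gather` undoes
# `scatter` by concatenating the micro-batches along dim 0.
#
#     x = torch.randn(8, 4)
#     assert torch.equal(gather(scatter(x, chunks=4)), x)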