Spaces:
Paused
Paused
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # reference python implementations for C ops | |
| import torch | |
| from functorch._C import dim as _C | |
| from . import op_properties | |
| from .batch_tensor import _enable_layers | |
| from .tree_map import tree_flatten, tree_map | |
| DimList = _C.DimList | |
| import operator | |
| from functools import reduce | |
# set of the ops listed in op_properties.pointwise; __torch_function__ tests
# membership in it to choose the pointwise code path
pointwise = set(op_properties.pointwise)
def prod(x):
    """Return the product of all elements of ``x`` (``1`` if ``x`` is empty)."""
    result = 1
    for factor in x:
        # rebinding (not ``*=``) so no element is ever multiplied in place
        result = result * factor
    return result
def _wrap_dim(d, N, keepdim):
    """Normalise a single dimension argument.

    A ``Dim`` object is returned unchanged (asserting ``keepdim`` is false).
    A non-negative int is shifted to ``d - N``; a negative int is returned
    as-is.
    """
    from . import Dim

    if not isinstance(d, Dim):
        return d - N if d >= 0 else d
    assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
    return d
def _dims(d, N, keepdim, single_dim):
    """Return an ``ltuple`` of wrapped dimensions for ``d``.

    ``d`` may be one ``Dim``/int or an iterable of them; an iterable is
    rejected (by assertion) when ``single_dim`` is true.
    """
    from . import Dim

    if not isinstance(d, (Dim, int)):
        assert not single_dim, f"expected a single dimension or int but found: {d}"
        return ltuple(_wrap_dim(x, N, keepdim) for x in d)
    wrapped = _wrap_dim(d, N, keepdim)
    return ltuple((wrapped,))
def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
    """Check that the dims in ``rhs`` multiply out to ``lhs_size``.

    If exactly one dim in ``rhs`` is unbound, its size is inferred as
    ``lhs_size`` divided by the product of the bound sizes and assigned to it.

    Args:
        lhs_size: size of the dimension being matched.
        rhs: sequence of dims exposing ``is_bound`` and ``size``; it is
            iterated more than once.
        lhs_debug: description of the left-hand side, used in error text only.

    Raises:
        DimensionMismatchError: more than one dim is unbound, the inferred
            size is not an exact divisor (including a zero product of bound
            sizes), or the fully-bound sizes do not multiply to ``lhs_size``.
    """
    from . import DimensionMismatchError

    def sizes_for_message():
        # unbound dims are shown as "?" in diagnostics
        return tuple(str(r.size) if r.is_bound else "?" for r in rhs)

    not_bound = [r for r in rhs if not r.is_bound]
    if len(not_bound) > 1:
        raise DimensionMismatchError(
            f"cannot infer the size of two dimensions at once: {rhs} with sizes {sizes_for_message()}"
        )

    if len(not_bound) == 1:
        (d,) = not_bound
        rhs_so_far = prod(r.size for r in rhs if r.is_bound)
        # a zero product cannot divide anything; report it as a mismatch
        # rather than letting ``%`` raise ZeroDivisionError
        if rhs_so_far == 0 or lhs_size % rhs_so_far != 0:
            raise DimensionMismatchError(
                f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {sizes_for_message()}"
            )
        d.size = lhs_size // rhs_so_far
        return

    rhs_size = prod(r.size for r in rhs)
    if lhs_size != rhs_size:
        raise DimensionMismatchError(
            f"Dimension sizes do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
        )
def _tensor_levels(inp):
    """Return ``(tensor, levels, has_device)`` for ``inp``.

    For a ``_Tensor`` these come from its ``_tensor``, ``_levels`` and
    ``_has_device`` attributes. Anything else is returned as-is with levels
    ``-ndim .. -1`` and ``has_device`` set to ``True``.
    """
    from . import _Tensor

    if not isinstance(inp, _Tensor):
        positional = range(-inp.ndim, 0)
        return inp, llist(positional), True
    return inp._tensor, llist(inp._levels), inp._has_device
def _match_levels(v, from_levels, to_levels):
    """Permute/view ``v`` so its axes follow ``to_levels``.

    ``from_levels`` names the current axes of ``v``. Every entry of
    ``to_levels`` missing from ``from_levels`` becomes a size-1 axis.
    ``v`` itself is returned when neither a permute nor a view is needed.
    """
    shape = v.size()
    order = []
    target_shape = []
    needs_view = False
    for level in to_levels:
        try:
            src = from_levels.index(level)
        except ValueError:
            # level not present in the source: add a broadcastable axis
            target_shape.append(1)
            needs_view = True
        else:
            order.append(src)
            target_shape.append(shape[src])
    identity = list(range(len(order)))
    if order != identity:
        v = v.permute(*order)
    return v.view(*target_shape) if needs_view else v
| # make a single dimension positional but do not permute it, | |
| # used to do multi-tensor operators where the dim being acted on | |
| # should not physically move if possible | |
def _positional_no_permute(self, dim, expand_dim=False):
    """Turn ``dim`` into a positional dimension without moving it.

    Returns ``(tensor, index)`` where ``index`` counts the int (positional)
    levels found before ``dim``. If ``dim`` is not among this tensor's levels
    the ``ValueError`` from the lookup propagates, unless ``expand_dim`` is
    true, in which case the data is expanded with a new leading axis of size
    ``dim.size``.
    """
    from . import Tensor

    ptensor = self._tensor
    levels = llist(self._levels)
    try:
        idx = levels.index(dim)
    except ValueError:
        if not expand_dim:
            raise
        # dim is absent: materialise it as a new leading axis
        idx = 0
        ptensor = ptensor.expand(dim.size, *ptensor.size())
        levels.insert(0, 0)

    idx_batched = 0
    for pos, level in enumerate(levels[:idx]):
        if isinstance(level, int):
            # int levels ahead of the new positional dim shift by one
            levels[pos] = level - 1
            idx_batched += 1
    levels[idx] = -(idx_batched + 1)
    return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
def seq(a, b):
    """Compare two level entries: ``Dim`` objects by identity, others by ``==``.

    Returns ``False`` when exactly one of the two is a ``Dim``.
    """
    from . import Dim

    a_is_dim = isinstance(a, Dim)
    if a_is_dim != isinstance(b, Dim):
        return False
    return a is b if a_is_dim else a == b
class isin:
    """Mixin giving a sequence ``in`` and ``index`` semantics based on ``seq``."""

    def __contains__(self, item):
        """Return True if any element matches ``item`` according to ``seq``."""
        return any(seq(item, x) for x in self)

    def index(self, item):
        """Return the position of the first element matching ``item``.

        Raises:
            ValueError: no element matches.
        """
        for position, x in enumerate(self):
            if seq(item, x):
                return position
        raise ValueError
class llist(isin, list):
    """A ``list`` whose ``in`` and ``index`` come from the ``isin`` mixin."""
class ltuple(isin, tuple):
    """A ``tuple`` whose ``in`` and ``index`` come from the ``isin`` mixin."""
# shared default value for the ``kwargs`` parameter of __torch_function__
empty_dict = {}
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
    """Dispatch ``orig`` over arguments that may carry first-class dims.

    Three paths are taken, in order:

    * ``torch.Tensor.__mul__`` of two 0-dim ``_Tensor`` operands returns a
      ``DelayedMulTensor`` without calling ``orig``.
    * ops contained in ``pointwise`` are run on the underlying tensors after
      aligning every tensor argument to the union of all levels.
    * everything else is run on the ``_batchtensor`` of each argument inside
      ``_enable_layers`` for all dims seen.

    Tensor-like results are re-wrapped; other results are returned as-is.
    """
    from . import _Tensor, Tensor, TensorLike
    from .delayed_mul_tensor import DelayedMulTensor

    if orig is torch.Tensor.__mul__:
        lhs, rhs = args
        both_zero_dim = (
            isinstance(lhs, _Tensor)
            and isinstance(rhs, _Tensor)
            and lhs.ndim == 0
            and rhs.ndim == 0
        )
        if both_zero_dim:
            return DelayedMulTensor(lhs, rhs)

    flat_args, unflatten = tree_flatten((args, kwargs))

    # collect every dim in use and remember the last argument that has a device
    all_dims = llist()
    device_holding_tensor = None
    for f in flat_args:
        if not isinstance(f, _Tensor):
            continue
        if f._has_device:
            device_holding_tensor = f._batchtensor
        for d in f.dims:
            if d not in all_dims:
                all_dims.append(d)
    result_has_device = device_holding_tensor is not None

    if orig in pointwise:
        result_levels = llist()
        to_expand = []
        for i, f in enumerate(flat_args):
            if not isinstance(f, TensorLike):
                continue
            ptensor, levels, _ = _tensor_levels(f)
            needs_move = (
                isinstance(f, _Tensor)
                and not f._has_device
                and device_holding_tensor is not None
            )
            if needs_move:
                ptensor = ptensor.to(device=device_holding_tensor.device)
            flat_args[i] = ptensor
            for l in levels:
                if l not in result_levels:
                    result_levels.append(l)
            to_expand.append((i, levels))

        # second pass: every tensor argument is aligned to the full level set
        for i, levels in to_expand:
            flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
        args, kwargs = unflatten(flat_args)
        result = orig(*args, **kwargs)

        def wrap_positional(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, result_levels, result_has_device)
            return t

        return tree_map(wrap_positional, result)

    def unwrap(t):
        if not isinstance(t, _Tensor):
            return t
        r = t._batchtensor
        if device_holding_tensor is not None and not t._has_device:
            r = r.to(device=device_holding_tensor.device)
        return r

    def wrap_batched(t):
        if isinstance(t, TensorLike):
            return Tensor.from_batched(t, result_has_device)
        return t

    with _enable_layers(all_dims):
        # NOTE(review): this writes to stdout on every non-pointwise call;
        # looks like leftover tracing output - confirm before removing.
        print(f"batch_tensor for {orig}")
        args, kwargs = unflatten(unwrap(f) for f in flat_args)
        result = orig(*args, **kwargs)
        return tree_map(wrap_batched, result)
def positional(self, *dims):
    """Move the given dims to the front of the tensor as positional dims.

    Each argument may be a ``DimList``, a ``Dim``, an int, or an iterable of
    dims; an iterable is collapsed into one positional dim whose size is the
    product of its members' sizes.

    Raises:
        DimensionBindError: a requested dim is not among this tensor's levels.
    """
    from . import Dim, DimensionBindError, Tensor

    ptensor = self._tensor
    levels = llist(self._levels)
    flat_dims = llist()
    view = []
    needs_view = False
    ndim = self.ndim
    for d in dims:
        if isinstance(d, DimList):
            flat_dims.extend(d)
            view.extend(e.size for e in d)
        elif isinstance(d, Dim):
            flat_dims.append(d)
            view.append(d.size)
        elif isinstance(d, int):
            wrapped = _wrap_dim(d, ndim, False)
            flat_dims.append(wrapped)
            view.append(ptensor.size(wrapped))
        else:
            # a pack of dims: flattened into a single output dimension
            flat_dims.extend(d)
            view.append(prod(e.size for e in d))
            needs_view = True

    permute = list(range(len(levels)))
    for target, d in enumerate(flat_dims):
        try:
            source = levels.index(d)
        except ValueError as e:
            raise DimensionBindError(
                f"tensor of dimensions {self.dims} does not contain dim {d}"
            ) from e
        # move the entry to slot ``target``; 0 marks it as positional for now
        levels.pop(source)
        levels.insert(target, 0)
        permute.insert(target, permute.pop(source))
    ptensor = ptensor.permute(*permute)

    # renumber int levels from the back: -1, -2, ...
    seen = 0
    for i in reversed(range(len(levels))):
        if isinstance(levels[i], int):
            seen += 1
            levels[i] = -seen

    result = Tensor.from_positional(ptensor, levels, self._has_device)
    if needs_view:
        trailing = result.size()[len(flat_dims) :]
        result = result.reshape(*view, *trailing)
    return result
def _contains_dim(input):
    """Return True if any element of ``input`` is a ``Dim``, else False.

    The result is always a ``bool``; previously the no-match case fell off
    the end of the function and returned ``None``.
    """
    from . import Dim

    return any(isinstance(i, Dim) for i in input)
def expand(self, *sizes):
    """Expand the tensor, accepting ``Dim`` objects as well as plain sizes.

    Without any ``Dim`` in ``sizes`` this defers to ``torch.Tensor.expand``
    through ``__torch_function__``. Otherwise the tensor is expanded with one
    new leading axis per dim (existing axes kept via ``-1``) and the new axes
    are then indexed with those dims.
    """
    if _contains_dim(sizes):
        dims = sizes
        plain_sizes = [d.size for d in dims]
        plain_sizes.extend([-1] * self.ndim)
        expanded = self.expand(*plain_sizes)
        return expanded[dims]
    return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
# unique sentinel passed as the ``default`` to _getarg so that _wrap can tell
# "argument not supplied" apart from any real value
_not_present = object()
def _getarg(name, offset, args, kwargs, default):
    """Fetch an argument given positionally at ``offset`` or by keyword ``name``.

    The positional slot wins when ``args`` is long enough; otherwise
    ``kwargs[name]`` is used, falling back to ``default``.
    """
    return args[offset] if offset < len(args) else kwargs.get(name, default)
def _patcharg(name, offset, args, kwargs, value):
    """Overwrite an argument in place, positionally if possible, else by keyword.

    ``args`` must be mutable (a list) when the positional slot exists.
    """
    if offset < len(args):
        args[offset] = value
        return
    kwargs[name] = value
def _wrap(
    orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
):
    """Build a method around ``orig`` whose dim argument may be a ``Dim``.

    Args:
        orig: the function being wrapped.
        dim_offset: positional slot (after ``self``) of the dim argument.
        keepdim_offset: positional slot of ``keepdim``; read only when
            ``reduce`` is true.
        dim_name: keyword name of the dim argument.
        single_dim: the op accepts exactly one dim; a non-``Dim`` value then
            takes the fallback path.
        reduce: the op removes the given dims from its result unless
            ``keepdim`` is set.

    Returns:
        The wrapping function ``fn(self, *args, **kwargs)``.
    """
    from . import Dim, Tensor, TensorLike

    def fn(self, *args, **kwargs):
        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
        use_fallback = dim is _not_present or (
            single_dim and not isinstance(dim, Dim)
        )
        if use_fallback:
            with _enable_layers(self.dims):
                # NOTE(review): writes to stdout on every call - confirm intent
                print(f"dim fallback batch_tensor for {orig}")
                batched = orig(self._batchtensor, *args, **kwargs)
                return Tensor.from_batched(batched, self._has_device)

        keepdim = False
        if reduce:
            keepdim = _getarg("keepdim", keepdim_offset, args, kwargs, False)

        data = self._tensor
        levels = llist(self._levels)
        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
        dim_indices = tuple(levels.index(d) for d in dims)

        new_levels = levels
        if reduce and not keepdim:
            # reduced dims disappear from the result's levels
            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]

        if len(dim_indices) == 1:
            # so that dims that really only take a single argument work...
            dim_indices = dim_indices[0]
        args = list(args)
        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)

        def rewrap(value):
            if isinstance(value, TensorLike):
                return Tensor.from_positional(value, new_levels, self._has_device)
            return value

        with _enable_layers(new_levels):
            # NOTE(review): writes to stdout on every call - confirm intent
            print(f"dim used batch_tensor for {orig}")
            r = orig(data, *args, **kwargs)
            return tree_map(rewrap, r)

    return fn
def _def(name, *args, **kwargs):
    """Install on ``_Tensor`` a ``_wrap``-ped version of ``torch.Tensor.<name>``.

    Extra positional and keyword arguments are forwarded to ``_wrap``.
    """
    from . import _Tensor

    wrapped = _wrap(getattr(torch.Tensor, name), *args, **kwargs)
    setattr(_Tensor, name, wrapped)
# the full slice (``:``); t__getitem__ substitutes it for ``None``, ``...`` and
# single-use dims in an index list
no_slice = slice(None)
# the original torch.Tensor.__getitem__, captured so t__getitem__ can delegate to it
_orig_getitem = torch.Tensor.__getitem__
class dim_tracker:
    """Counts how many times each dim has been recorded.

    Dims are matched with ``llist`` semantics. ``tracker[d]`` returns the
    number of ``record(d)`` calls made for ``d``.
    """

    def __init__(self):
        # parallel sequences: count[i] is the number of times dims[i] was seen
        self.dims = llist()
        self.count = []

    def record(self, d):
        """Note one more use of ``d``.

        Previously a repeated dim was ignored, so every count stayed at 1 and
        callers could not tell single-use dims from reused ones.
        """
        try:
            position = self.dims.index(d)
        except ValueError:
            self.dims.append(d)
            self.count.append(1)
        else:
            self.count[position] += 1

    def __getitem__(self, d):
        """Return the recorded use count of ``d``.

        Raises:
            ValueError: ``d`` was never recorded.
        """
        return self.count[self.dims.index(d)]
def t__getitem__(self, input):
    """Index ``self`` with ints/slices/tensors and first-class dims.

    ``input`` may contain ``Dim`` objects, ``DimList`` objects, tuples/lists
    of dims ("dim packs", which split one dimension into several), ``None``,
    ``...`` and anything ``torch.Tensor.__getitem__`` accepts.

    Returns:
        The result of ``Tensor.from_positional`` for the indexed data, or the
        plain result of the original ``__getitem__`` for simple indices.

    Raises:
        DimensionBindError: more than one ``...``/unbound ``DimList`` is given.
        IndexError: more indices are supplied than the tensor has dimensions.
    """
    from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike

    # * bail to original example if we have a single non-Dim tensor, or a non-tensor
    # * locate ... or an unbound tensor list, and determine its size, bind dim list
    #   (remember that None does not count to the total dim count)
    # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
    #   produce the re-view if needed
    # * for each single-use dim index, replace with no_slice and mark that it will be added
    #   (keep track of whether we have to call super)
    # * call super if needed
    # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)

    # this handles bool indexing handling, as well as some other simple cases.
    is_simple = (
        not isinstance(input, Dim)
        and not isinstance(input, (tuple, list))
        and
        # WAR for functorch bug where zero time tensors in getitem are not handled correctly.
        not (isinstance(input, TensorLike) and input.ndim == 0)
    )

    if is_simple:
        if isinstance(self, _Tensor):
            return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
        else:
            return _orig_getitem(self, input)

    # can further optimize this case
    if not isinstance(input, tuple):
        input = [input]
    else:
        input = list(input)

    # first pass: count indexed dimensions and find the single expanding entry
    dims_indexed = 0
    expanding_object = None
    for i, s in enumerate(input):
        if s is ... or isinstance(s, DimList) and not s.is_bound:
            if expanding_object is not None:
                msg = (
                    "at most one ... or unbound dimension list can exist in indexing list but"
                    f" found 2 at offsets {i} and {expanding_object}"
                )
                raise DimensionBindError(msg)
            expanding_object = i

        if isinstance(s, DimList):
            dims_indexed += len(s) if s.is_bound else 0
        elif s is not None and s is not ...:
            dims_indexed += 1

    ndim = self.ndim
    if dims_indexed > ndim:
        raise IndexError(
            f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
        )
    if expanding_object is not None:
        expanding_ndims = ndim - dims_indexed
        obj = input[expanding_object]
        if obj is ...:
            input[expanding_object : expanding_object + 1] = [
                no_slice
            ] * expanding_ndims
        else:
            obj.bind_len(expanding_ndims)
    # flatten the dimslists into the indexing.
    # Positions are looked up now rather than remembered from the first pass:
    # expanding ``...`` above changes the length of ``input`` and would leave
    # earlier-recorded offsets pointing at the wrong entries.
    for i in reversed(range(len(input))):
        if isinstance(input[i], DimList):
            input[i : i + 1] = input[i]

    # second pass: bind dims to sizes, count dim uses, build the view shape
    dims_indexed = 0
    requires_view = False
    size = self.size()
    view_sizes = []
    dims_seen = dim_tracker()

    def add_dims(t):
        if not isinstance(t, _Tensor):
            return
        for d in t.dims:
            dims_seen.record(d)

    add_dims(self)
    dim_packs = []
    for i, idx in enumerate(input):
        if idx is None:
            input[i] = no_slice
            view_sizes.append(1)
            requires_view = True
        else:
            sz = size[dims_indexed]
            if isinstance(idx, Dim):
                idx.size = sz
                dims_seen.record(idx)
                view_sizes.append(sz)
            elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
                for d in idx:
                    # record each member of the pack (not the pack itself) so
                    # that the per-dim lookup below finds it
                    dims_seen.record(d)
                _bind_dims_to_size(sz, idx, f"offset {i}")
                view_sizes.extend(d.size for d in idx)
                requires_view = True
                dim_packs.append(i)
            else:
                add_dims(idx)
                view_sizes.append(sz)
            dims_indexed += 1
    if requires_view:
        self = self.view(*view_sizes)
    # splice each pack's dims into the index list; back-to-front keeps the
    # remaining offsets valid
    for i in reversed(dim_packs):
        input[i : i + 1] = input[i]

    # currently:
    # input is flat, containing either Dim, or Tensor, or something valid for standard indexing
    # self may have first-class dims as well.

    # to index:
    # drop the first class dims from self, they just become direct indices of their positions

    # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index.
    # these dimensions will appear and need to be bound at the first place tensor occurs

    if isinstance(self, _Tensor):
        ptensor_self, levels = self._tensor, list(self._levels)
        # indices to ptensor rather than self which has first-class dimensions
        input_it = iter(input)
        flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
        has_device = self._has_device
        to_pad = 0
    else:
        ptensor_self, flat_inputs = self, input
        to_pad = ptensor_self.ndim - len(flat_inputs)
        has_device = True

    result_levels = []
    index_levels = []
    tensor_insert_point = None
    to_expand = {}
    requires_getindex = False
    for i, inp in enumerate(flat_inputs):
        if isinstance(inp, Dim) and dims_seen[inp] == 1:
            flat_inputs[i] = no_slice
            result_levels.append(inp)
        elif isinstance(inp, TensorLike):
            requires_getindex = True
            if tensor_insert_point is None:
                tensor_insert_point = len(result_levels)
            ptensor, levels, _ = _tensor_levels(inp)
            to_expand[i] = levels
            flat_inputs[i] = ptensor
            for l in levels:
                if l not in index_levels:
                    index_levels.append(l)
        else:
            requires_getindex = True
            result_levels.append(0)

    if tensor_insert_point is not None:
        result_levels[tensor_insert_point:tensor_insert_point] = index_levels

    for i, levels in to_expand.items():
        flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)

    if requires_getindex:
        result = _orig_getitem(ptensor_self, flat_inputs)
    else:
        result = ptensor_self

    next_positional = -1
    if to_pad > 0:
        result_levels.extend([0] * to_pad)
    # renumber int placeholders from the back: -1, -2, ...
    for i, r in enumerate(reversed(result_levels)):
        if isinstance(r, int):
            result_levels[-1 - i] = next_positional
            next_positional -= 1

    return Tensor.from_positional(result, result_levels, has_device)
| # XXX - dim is optional and can be the outer-most dimension... | |
def stack(tensors, new_dim, dim=0, out=None):
    """Stack ``tensors`` along a new axis that is bound to ``new_dim``.

    Args:
        tensors: sequence of tensors to stack.
        new_dim: dim the newly created axis is bound to.
        dim: int position at which to stack, or a first-class dim that is
            made positional (without permuting) in every input first.
        out: optional output tensor, forwarded to ``torch.stack``.
    """
    if isinstance(dim, int):
        # ``out`` is keyword-only in torch.stack; passing it positionally
        # raises TypeError
        return torch.stack(tensors, dim, out=out).index(dim, new_dim)
    index = None
    if out is not None:
        out, index = _positional_no_permute(out, dim, expand_dim=True)
    ptensors = []
    for t in tensors:
        pt, pi = _positional_no_permute(t, dim, expand_dim=True)
        if index is not None and pi != index:
            # align with the position chosen by ``out`` / the first tensor
            # (the Tensor method is ``movedim``; ``move_dim`` does not exist)
            pt = pt.movedim(pi, index)
        else:
            index = pi
        ptensors.append(pt)
    pr = torch.stack(ptensors, index, out=out)
    return pr.index((index, index + 1), (new_dim, dim))
# the original torch.Tensor.split, captured so split() below can delegate to it
_orig_split = torch.Tensor.split
def split(self, split_size_or_sections, dim=0):
    """Split along ``dim``, optionally naming each piece with a dim.

    If ``split_size_or_sections`` is an int or contains an int, the call is
    forwarded to the original ``torch.Tensor.split`` (``dim`` must then not
    be a ``Dim``). Otherwise every entry is a dim: bound dims fix their
    piece's size, and the remaining length is shared among unbound dims in
    ceil-sized chunks, binding them as a side effect.

    Raises:
        ValueError: int split sizes are combined with a ``Dim`` for ``dim``.
        AssertionError: the bound sizes exceed, or (with nothing unbound) do
            not equal, the size of the dimension being split.
    """
    from . import _Tensor, Dim

    uses_int_sizes = isinstance(split_size_or_sections, int) or any(
        isinstance(t, int) for t in split_size_or_sections
    )
    if uses_int_sizes:
        if isinstance(dim, Dim):
            raise ValueError(
                "when dim is specified as a Dim object, split sizes must also be dimensions."
            )
        return _orig_split(self, split_size_or_sections, dim=dim)

    if isinstance(dim, Dim):
        assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
        self, dim = _positional_no_permute(self, dim)

    size = self.size(dim)
    sizes = []
    unbound = []
    total_bound_size = 0
    for position, d in enumerate(split_size_or_sections):
        if not d.is_bound:
            # placeholder, filled in once the leftover length is known
            sizes.append(0)
            unbound.append(position)
            continue
        sizes.append(d.size)
        total_bound_size += d.size

    if not unbound:
        assert (
            total_bound_size == size
        ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
    else:
        assert (
            total_bound_size <= size
        ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
        remaining_size = size - total_bound_size
        # ceil division: earlier unbound dims get the larger chunks
        chunk_size = -(-remaining_size // len(unbound))
        for u in unbound:
            sz = min(chunk_size, remaining_size)
            split_size_or_sections[u].size = sz
            sizes[u] = sz
            remaining_size -= sz

    pieces = _orig_split(self, sizes, dim=dim)
    return tuple(t.index(dim, d) for d, t in zip(split_size_or_sections, pieces))