Spaces:
Paused
Paused
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| # pointwise operators can go through a faster pathway | |
# NOTE(review): only 'add' is a real method name here; the empty string looks
# like an unfinished placeholder. The value is kept unchanged because nothing
# in this section shows how (or whether) the list is consumed -- confirm
# before removing or extending it.
tensor_magic_methods = ['add', '']

# Binary operators that also get a reflected variant: each name ``m`` below
# contributes both ``m`` and ``'r' + m`` (e.g. 'add' and 'radd') to
# ``pointwise_magic_methods``.
pointwise_magic_methods_with_reverse = (
    'add', 'sub', 'mul', 'floordiv', 'div', 'truediv', 'mod',
    'pow', 'lshift', 'rshift', 'and', 'or', 'xor',
)

# Bare (un-dundered) names of the magic methods treated as pointwise.
# Every name is listed exactly once; the previous revision repeated 'gt',
# which produced a duplicate '__gt__' entry in ``pointwise_methods``.
pointwise_magic_methods = (
    # forward/reflected pairs, interleaved: add, radd, sub, rsub, ...
    *(
        name
        for op in pointwise_magic_methods_with_reverse
        for name in (op, 'r' + op)
    ),
    # comparisons
    'eq', 'gt', 'le', 'lt', 'ge', 'ne',
    # unary operators
    'neg', 'pos', 'abs', 'invert',
    # in-place operators
    'iadd', 'isub', 'imul', 'ifloordiv', 'idiv',
    'itruediv', 'imod', 'ipow', 'ilshift', 'irshift', 'iand',
    'ior', 'ixor',
    # scalar conversions
    'int', 'long', 'float', 'complex',
)

# The same names in double-underscore form, e.g. 'add' -> '__add__'.
pointwise_methods = tuple(f'__{name}__' for name in pointwise_magic_methods)
# Callables flagged as pointwise operators (the comment at the top of this
# section says such operators can go through a faster pathway).
#
# Layout: the ``torch.Tensor`` magic methods named in ``pointwise_methods``
# come first, followed by an explicit list that is kept roughly alphabetical
# by operator name; where they exist, the ``torch.Tensor`` method, the
# ``torch`` function and the ``torch.nn.functional`` variant are adjacent.
#
# ``torch.nn.functional.dropout`` used to be listed twice (once at the head
# of the list and once in its alphabetical slot); only the alphabetical entry
# is kept.
pointwise = (
    # Resolved at import time; raises AttributeError if torch.Tensor lacks
    # any of the named magic methods.
    *(getattr(torch.Tensor, method_name) for method_name in pointwise_methods),
    torch.where,
    torch.Tensor.abs,
    torch.abs,
    torch.Tensor.acos,
    torch.acos,
    torch.Tensor.acosh,
    torch.acosh,
    torch.Tensor.add,
    torch.add,
    torch.Tensor.addcdiv,
    torch.addcdiv,
    torch.Tensor.addcmul,
    torch.addcmul,
    torch.Tensor.addr,
    torch.addr,
    torch.Tensor.angle,
    torch.angle,
    torch.Tensor.asin,
    torch.asin,
    torch.Tensor.asinh,
    torch.asinh,
    torch.Tensor.atan,
    torch.atan,
    torch.Tensor.atan2,
    torch.atan2,
    torch.Tensor.atanh,
    torch.atanh,
    torch.Tensor.bitwise_and,
    torch.bitwise_and,
    torch.Tensor.bitwise_left_shift,
    torch.bitwise_left_shift,
    torch.Tensor.bitwise_not,
    torch.bitwise_not,
    torch.Tensor.bitwise_or,
    torch.bitwise_or,
    torch.Tensor.bitwise_right_shift,
    torch.bitwise_right_shift,
    torch.Tensor.bitwise_xor,
    torch.bitwise_xor,
    torch.Tensor.ceil,
    torch.ceil,
    torch.celu,
    torch.nn.functional.celu,
    torch.Tensor.clamp,
    torch.clamp,
    torch.Tensor.clamp_max,
    torch.clamp_max,
    torch.Tensor.clamp_min,
    torch.clamp_min,
    torch.Tensor.copysign,
    torch.copysign,
    torch.Tensor.cos,
    torch.cos,
    torch.Tensor.cosh,
    torch.cosh,
    torch.Tensor.deg2rad,
    torch.deg2rad,
    torch.Tensor.digamma,
    torch.digamma,
    torch.Tensor.div,
    torch.div,
    torch.dropout,
    torch.nn.functional.dropout,
    torch.nn.functional.elu,
    torch.Tensor.eq,
    torch.eq,
    torch.Tensor.erf,
    torch.erf,
    torch.Tensor.erfc,
    torch.erfc,
    torch.Tensor.erfinv,
    torch.erfinv,
    torch.Tensor.exp,
    torch.exp,
    torch.Tensor.exp2,
    torch.exp2,
    torch.Tensor.expm1,
    torch.expm1,
    torch.feature_dropout,
    torch.Tensor.float_power,
    torch.float_power,
    torch.Tensor.floor,
    torch.floor,
    torch.Tensor.floor_divide,
    torch.floor_divide,
    torch.Tensor.fmod,
    torch.fmod,
    torch.Tensor.frac,
    torch.frac,
    torch.Tensor.frexp,
    torch.frexp,
    torch.Tensor.gcd,
    torch.gcd,
    torch.Tensor.ge,
    torch.ge,
    torch.nn.functional.gelu,
    torch.nn.functional.glu,
    torch.Tensor.gt,
    torch.gt,
    torch.Tensor.hardshrink,
    torch.hardshrink,
    torch.nn.functional.hardshrink,
    torch.nn.functional.hardsigmoid,
    torch.nn.functional.hardswish,
    torch.nn.functional.hardtanh,
    torch.Tensor.heaviside,
    torch.heaviside,
    torch.Tensor.hypot,
    torch.hypot,
    torch.Tensor.i0,
    torch.i0,
    torch.Tensor.igamma,
    torch.igamma,
    torch.Tensor.igammac,
    torch.igammac,
    torch.Tensor.isclose,
    torch.isclose,
    torch.Tensor.isfinite,
    torch.isfinite,
    torch.Tensor.isinf,
    torch.isinf,
    torch.Tensor.isnan,
    torch.isnan,
    torch.Tensor.isneginf,
    torch.isneginf,
    torch.Tensor.isposinf,
    torch.isposinf,
    torch.Tensor.isreal,
    torch.isreal,
    torch.Tensor.kron,
    torch.kron,
    torch.Tensor.lcm,
    torch.lcm,
    torch.Tensor.ldexp,
    torch.ldexp,
    torch.Tensor.le,
    torch.le,
    torch.nn.functional.leaky_relu,
    torch.Tensor.lerp,
    torch.lerp,
    torch.Tensor.lgamma,
    torch.lgamma,
    torch.Tensor.log,
    torch.log,
    torch.Tensor.log10,
    torch.log10,
    torch.Tensor.log1p,
    torch.log1p,
    torch.Tensor.log2,
    torch.log2,
    torch.nn.functional.logsigmoid,
    torch.Tensor.logical_and,
    torch.logical_and,
    torch.Tensor.logical_not,
    torch.logical_not,
    torch.Tensor.logical_or,
    torch.logical_or,
    torch.Tensor.logical_xor,
    torch.logical_xor,
    torch.Tensor.logit,
    torch.logit,
    torch.Tensor.lt,
    torch.lt,
    torch.Tensor.maximum,
    torch.maximum,
    torch.Tensor.minimum,
    torch.minimum,
    torch.nn.functional.mish,
    torch.Tensor.mvlgamma,
    torch.mvlgamma,
    torch.Tensor.nan_to_num,
    torch.nan_to_num,
    torch.Tensor.ne,
    torch.ne,
    torch.Tensor.neg,
    torch.neg,
    torch.Tensor.nextafter,
    torch.nextafter,
    torch.Tensor.outer,
    torch.outer,
    torch.polar,
    torch.Tensor.polygamma,
    torch.polygamma,
    torch.Tensor.positive,
    torch.positive,
    torch.Tensor.pow,
    torch.pow,
    torch.Tensor.prelu,
    torch.prelu,
    torch.nn.functional.prelu,
    torch.Tensor.rad2deg,
    torch.rad2deg,
    torch.Tensor.reciprocal,
    torch.reciprocal,
    torch.Tensor.relu,
    torch.relu,
    torch.nn.functional.relu,
    torch.nn.functional.relu6,
    torch.Tensor.remainder,
    torch.remainder,
    torch.Tensor.round,
    torch.round,
    torch.rrelu,
    torch.nn.functional.rrelu,
    torch.Tensor.rsqrt,
    torch.rsqrt,
    torch.rsub,
    torch.selu,
    torch.nn.functional.selu,
    torch.Tensor.sgn,
    torch.sgn,
    torch.Tensor.sigmoid,
    torch.sigmoid,
    torch.nn.functional.sigmoid,
    torch.Tensor.sign,
    torch.sign,
    torch.Tensor.signbit,
    torch.signbit,
    torch.nn.functional.silu,
    torch.Tensor.sin,
    torch.sin,
    torch.Tensor.sinc,
    torch.sinc,
    torch.Tensor.sinh,
    torch.sinh,
    torch.nn.functional.softplus,
    torch.nn.functional.softshrink,
    torch.Tensor.sqrt,
    torch.sqrt,
    torch.Tensor.square,
    torch.square,
    torch.Tensor.sub,
    torch.sub,
    torch.Tensor.tan,
    torch.tan,
    torch.Tensor.tanh,
    torch.tanh,
    torch.nn.functional.tanh,
    torch.threshold,
    torch.nn.functional.threshold,
    torch.trapz,
    torch.Tensor.true_divide,
    torch.true_divide,
    torch.Tensor.trunc,
    torch.trunc,
    torch.Tensor.xlogy,
    torch.xlogy,
    torch.rand_like,
)