Spaces:
Sleeping
Sleeping
import torch | |
from torch import Tensor | |
aten = torch.ops.aten | |
import inspect | |
import warnings | |
from typing import Dict, List, Optional, Set | |
from torch.types import Number | |
# Registry of scripted decompositions keyed by str(aten_op._schema); it is the
# default target of register_decomposition when no registry is passed.
decomposition_table: Dict[str, torch.jit.ScriptFunction] = dict()
# __name__ of every function passed through register_decomposition; kept unique
# because JIT function serialization needs unique names.
function_name_set: Set[str] = set()
def check_decomposition_has_type_annotations(f):
    """Verify that every parameter of ``f`` and its return value are annotated.

    Args:
        f: the callable to inspect; it must be accepted by ``inspect.signature``.

    Raises:
        AssertionError: if a parameter or the return value has no annotation.
            The message names the offending parameter and the function.
    """
    inspect_empty = inspect._empty  # type: ignore[attr-defined]
    # Plain Python functions expose ``__name__``, not ``.name``. Reading
    # ``f.name`` inside the message would raise AttributeError and mask the
    # real problem, so resolve a printable name up front.
    func_name = getattr(f, "__name__", None) or getattr(f, "name", repr(f))
    sig = inspect.signature(f)
    for param in sig.parameters.values():
        # ``inspect._empty`` is a sentinel: compare by identity.
        if param.annotation is inspect_empty:
            # Raise explicitly rather than ``assert`` so the check survives ``python -O``.
            raise AssertionError(
                f"No signature on param {param.name} for function {func_name}"
            )
    if sig.return_annotation is inspect_empty:
        raise AssertionError(f"No return annotation for function {func_name}")
def signatures_match(decomposition_sig, torch_op_sig):
    """Return whether a decomposition's signature is compatible with a torch op's.

    Args:
        decomposition_sig: ``inspect.Signature`` of the Python decomposition.
        torch_op_sig: ``inspect.Signature`` describing the torch operator.

    Returns:
        True when all of the following hold, otherwise False:
        - both signatures have the same number of parameters;
        - each positional pair agrees on name and annotation;
        - any default that is present on *both* sides is equal;
        - the return annotations are equal.

    Warns:
        UserWarning: once for each decomposition parameter named ``self``.
    """
    decomp_params = decomposition_sig.parameters
    op_params = torch_op_sig.parameters
    if len(decomp_params) != len(op_params):
        return False

    # Sentinel inspect uses for "no annotation" / "no default". Compare by
    # identity so arbitrary default objects never get their __ne__ invoked
    # against it (e.g. array-like defaults whose != is elementwise).
    inspect_empty = inspect._empty  # type: ignore[attr-defined]

    for decomp_param, op_param in zip(decomp_params.values(), op_params.values()):
        # Full Parameter equality can't be checked yet, because not every field
        # (such as the default value) is correctly deduced for torch_op_sig.
        # 'kind' can't be compared either, because kwarg-only values with
        # defaults are not yet supported in TorchScript.
        if decomp_param.name == "self":
            warnings.warn("PyTorch uses 'input' instead of 'self' on public api")
        if decomp_param.name != op_param.name:
            return False
        if decomp_param.annotation != op_param.annotation:
            return False

        decomp_default = decomp_param.default
        op_default = op_param.default
        # The torch schema does not always report a default that exists, but
        # when both sides specify one they must agree.
        if decomp_default is not inspect_empty and op_default is not inspect_empty:
            if decomp_default != op_default:
                return False

    return decomposition_sig.return_annotation == torch_op_sig.return_annotation
def register_decomposition(aten_op, registry=None):
    """Build a decorator that scripts a decomposition and registers it for ``aten_op``.

    When applied to a function ``f``, the decorator does the following:
    - compiles ``f`` with ``torch.jit.script``;
    - inlines its graph;
    - runs two rounds of the peephole and constant-propagation passes;
    - stores the scripted function in the registry under ``str(aten_op._schema)``.

    The original (unscripted) ``f`` is returned unchanged.

    Args:
        aten_op: the operator overload being decomposed. It must be a
            ``torch._ops.OpOverload``.
        registry: mapping to store the scripted function in. Defaults to the
            module-level ``decomposition_table``.

    Returns:
        The decorator described above.

    Raises:
        AssertionError: when the decorator is applied, if ``aten_op`` is not an
            ``OpOverload`` or a function with the same ``__name__`` has already
            been registered.
    """

    def decomposition_decorator(f):
        # Resolve the target per call instead of rebinding the enclosing
        # ``registry`` via ``nonlocal``.
        target = decomposition_table if registry is None else registry
        if not isinstance(aten_op, torch._ops.OpOverload):
            raise AssertionError(
                f"register_decomposition expects a torch._ops.OpOverload, got {type(aten_op)}"
            )
        # Need unique name for jit function serialization.
        if f.__name__ in function_name_set:
            raise AssertionError(f"Duplicated function name {f.__name__}")

        scripted_func = torch.jit.script(f)
        torch._C._jit_pass_inline(scripted_func.graph)
        # Two rounds, presumably so that each pass can expose further
        # simplifications for the other. TODO confirm the pass interplay.
        for _ in range(2):
            torch._C._jit_pass_peephole(scripted_func.graph)
            torch._C._jit_pass_constant_propagation(scripted_func.graph)

        # Reserve the name only once scripting has succeeded. Otherwise a
        # function that fails to script could not be fixed and re-registered
        # under the same name: it would be rejected as a duplicate.
        function_name_set.add(f.__name__)
        target[str(aten_op._schema)] = scripted_func
        return f

    return decomposition_decorator
# TODO: replace torch.sigmoid -> aten.sigmoid | |
def var_decomposition(
    input: Tensor,
    dim: Optional[List[int]] = None,
    correction: Optional[Number] = None,
    keepdim: bool = False,
) -> Tensor:
    """Compute the variance of ``input`` using ATen ``mean`` and ``sum``.

    The body sticks to constructs that look chosen for TorchScript (an
    explicitly typed empty list, ``isinstance`` refinement). Presumably this
    function is meant to be scripted via ``register_decomposition``; confirm
    at the registration site.

    Args:
        input: tensor to reduce.
        dim: dimensions to reduce over. ``None`` is treated as an empty list,
            in which case the element count ``n`` is ``input.numel()``.
            Otherwise ``n`` is the product of the sizes of the listed
            dimensions.
        correction: subtracted from ``n`` to form the divisor. ``None`` means
            ``n - 1``. Must be an int or a float otherwise.
        keepdim: forwarded to ``aten.sum`` to keep the reduced dimensions.

    Returns:
        The sum over ``dim`` of ``(input - mean) ** 2``, divided by
        ``max(0, n - correction)``.

    Raises:
        RuntimeError: if ``correction`` is neither an int nor a float.
    """
    if dim is None:
        # Annotated temporary, presumably so the empty list is typed
        # List[int] under TorchScript rather than inferred from ``[]``.
        empty_dims: List[int] = []
        dim = empty_dims

    if isinstance(dim, (tuple, list)) and len(dim) == 0:
        # NOTE(review): relies on aten.mean / aten.sum treating an empty dim
        # list as a reduction over every element. Confirm against ATen.
        n = input.numel()
    else:
        n = 1
        for d in dim:
            n *= input.shape[d]

    # keepdim=True so the mean broadcasts against ``input``.
    mean = aten.mean(input, dim, True)
    centered = input - mean
    squared = centered * centered
    sum_sq = aten.sum(squared, dim, keepdim)

    if correction is None:
        denom = float(n - 1)
    else:
        if isinstance(correction, int):
            denom = float(n - correction)
        elif isinstance(correction, float):
            denom = float(n) - correction
        else:
            raise RuntimeError("correction must be int or float")
    # Clamp at 0: when n <= correction this divides by zero (inf/nan) instead
    # of producing a negative variance.
    return sum_sq / max(0, denom)
def var(input: Tensor, unbiased: bool = True) -> Tensor:
    """Variance of ``input`` with ``dim`` left at ``var_decomposition``'s default.

    Args:
        input: tensor to reduce.
        unbiased: if True, pass the integer correction 1 (divide by ``n - 1``).
            Otherwise pass 0 (divide by ``n``).

    Returns:
        The result of ``var_decomposition(input, correction=...)``.
    """
    correction = 0
    if unbiased:
        correction = 1
    return var_decomposition(input, correction=correction)