Spaces:
Running
Running
import dataclasses | |
import inspect | |
import sys | |
from typing import Any, Callable, Tuple | |
import torch | |
@dataclasses.dataclass
class Kernel:
    """Models a (function, source location).

    Declared as a dataclass so that ``Kernel(func, source)`` actually
    stores both fields; without the generated ``__init__`` the annotated
    fields are never set and ``__call__`` fails on ``self.func``.

    Attributes:
        func: The callable that is invoked when the kernel is called.
        source: A string describing where the kernel came from
            (presumably a "/path/to/file.py:lineno" string such as the
            ones ``get_source`` builds -- confirm against callers).
    """

    func: Callable
    source: str

    def __call__(self, *args, **kwargs):
        """Forward all positional and keyword arguments to ``func``."""
        return self.func(*args, **kwargs)
class RegistrationHandle:
    """Handle that runs a stored callback whenever ``destroy()`` is called."""

    def __init__(self, on_destroy: Callable):
        # Keep the callback around; it is only invoked from destroy().
        self._on_destroy = on_destroy

    def destroy(self) -> None:
        """Invoke the callback given at construction; its result is discarded."""
        callback = self._on_destroy
        callback()
def get_source(stacklevel: int) -> str:
    """Return a "filename:lineno" string describing a frame on the call stack.

    Example: "/path/to/foo.py:42"

    ``stacklevel`` counts frames upward from this function:
    stacklevel=1 is the caller, stacklevel=2 is the caller's caller, etc.
    """
    # sys._getframe must be called directly from this function so that
    # `stacklevel` is measured relative to get_source's own frame.
    target_frame = sys._getframe(stacklevel)
    info = inspect.getframeinfo(target_frame)
    return f"{info.filename}:{info.lineno}"
def parse_namespace(qualname: str) -> Tuple[str, str]:
    """Split a "namespace::name" qualified name into (namespace, name).

    Raises:
        ValueError: if splitting ``qualname`` on "::" does not yield
            exactly two pieces.
    """
    pieces = qualname.split("::")
    if len(pieces) == 2:
        namespace, name = pieces
        return namespace, name

    message = (
        "Expected `qualname` to be of the form "
        f'"namespace::name", but got {qualname}. '
        "The qualname passed to the torch.library APIs must consist "
        "of a namespace and a name, e.g. aten::sin"
    )
    raise ValueError(message)
def lookup_op(qualname: str) -> torch._ops.OpOverload:
    """Look up an operator overload under ``torch.ops`` by qualified name.

    ``qualname`` has the form "namespace::name" or "namespace::name.overload".
    When no overload is given, the "default" overload is looked up.

    Note: the value returned is the attribute fetched from the packet
    (i.e. the overload), not the packet itself, hence the return annotation.

    Raises:
        ValueError: if ``qualname`` is not of the form "namespace::name"
            (raised by ``parse_namespace``), or if the name part contains
            more than one "." (the two-way unpacking below fails).
        Any error raised by the attribute lookups on ``torch.ops`` when the
        namespace, operator or overload does not exist is propagated.
    """
    namespace, name = parse_namespace(qualname)
    if "." in name:
        # "name.overload" -> operator name plus explicit overload name
        name, overload = name.split(".")
    else:
        overload = "default"
    ns = getattr(torch.ops, namespace)
    packet = getattr(ns, name)
    return getattr(packet, overload)
def is_builtin(op: torch._ops.OpOverload) -> bool:
    """Return True if ``op`` lives in the "aten", "prim" or "prims" namespace.

    ``op`` must be a ``torch._ops.OpOverload``; this is enforced with an assert.
    """
    assert isinstance(op, torch._ops.OpOverload)
    builtin_namespaces = ("aten", "prim", "prims")
    return op.namespace in builtin_namespaces
def is_functional_schema(schema: Any) -> bool:
    """Decide whether a schema describes a functional operator.

    An operator is functional if:
    - it does not mutate any of its inputs
    - it does not return a view on any of its inputs
    - it has at least one return

    ``schema`` may be a schema string or a torchgen ``FunctionSchema``;
    strings are parsed with ``FunctionSchema.parse`` first.
    """
    # Imported lazily because not all PyTorch builds have torchgen
    from torchgen.model import FunctionSchema, SchemaKind

    assert isinstance(schema, (str, FunctionSchema))
    parsed = FunctionSchema.parse(schema) if isinstance(schema, str) else schema

    if parsed.kind() != SchemaKind.functional:
        return False

    returns = parsed.returns
    for ret in returns:
        # A return carrying a non-write alias annotation is treated as a
        # (non-mutating) view of an input, which is not functional.
        if ret.annotation is not None and not ret.annotation.is_write:
            return False

    # Functional ops must have at least one return.
    return bool(returns)
def mutates_and_returns_first_arg(op: torch._ops.OpOverload):
    """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.

    TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
    but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
    Figure this out.

    Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a)
    """
    if op.namespace != "aten":
        return False

    schema = op._schema

    # Exactly one return, and it must alias exactly one location.
    returns = schema.returns
    if len(returns) != 1:
        return False
    ret_alias = returns[0].alias_info
    if ret_alias is None:
        return False
    ret_locations = ret_alias.after_set
    if len(ret_locations) != 1:
        return False
    (returned_loc,) = ret_locations

    # The first argument must be written to and alias that same single location.
    arguments = schema.arguments
    if len(arguments) < 1:
        return False
    first_alias = arguments[0].alias_info
    if first_alias is None or not first_alias.is_write:
        return False
    first_locations = first_alias.after_set
    if len(first_locations) != 1 or next(iter(first_locations)) != returned_loc:
        return False

    # No other argument may carry any alias annotation.
    return all(arg.alias_info is None for arg in arguments[1:])
def zip_schema(schema, args, kwargs):
    """Yield (schema argument, value) pairs for the values that were passed.

    Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
    that is, kwargs must be keyword-only arguments and default values may be omitted.

    Schema arguments for which no value was supplied are skipped.
    """
    assert len(schema.arguments) >= len(args) + len(kwargs)
    num_positional = len(args)
    for index, info in enumerate(schema.arguments):
        if info.kwarg_only:
            # Keyword-only arguments are matched by name; absent ones are skipped.
            if info.name in kwargs:
                yield info, kwargs[info.name]
        elif index < num_positional:
            yield info, args[index]
        # else: args that are equal to their default values are not populated
        # if they are followed by args that are equal to their defaults.
        # Skip these.