Spaces:
Running
Running
import inspect | |
import textwrap | |
import torch.jit | |
from torch.jit._builtins import _find_builtin | |
# this file is for generating documentation using sphinx autodoc
# > help(torch.jit.supported_ops) will also give a nice list of the
# supported ops programmatically
def _hidden(name):
    """Return True for single-underscore ("private") names such as ``_tan``.

    Double-underscore names such as ``__and__`` are not considered hidden.
    """
    if name.startswith("__"):
        return False
    return name.startswith("_")
def _emit_type(type):
    """Return the textual form of a schema type, i.e. its ``str()`` conversion."""
    rendered = str(type)
    return rendered
def _emit_arg(indent, i, arg):
    """Render one schema argument as ``name : type`` with an optional ``=default``.

    Args:
        indent: Number of spaces placed before every argument except the first.
        i: Zero-based position of the argument in the rendered list.
        arg: Object exposing ``name``, ``type`` and ``default_value``.

    Returns:
        The rendered argument. Arguments after the first are prefixed with a
        newline and ``indent`` spaces so each one lands on its own line.
    """
    rendered = f"{arg.name} : {_emit_type(arg.type)}"
    fallback = arg.default_value
    if fallback is not None:
        rendered = f"{rendered}={fallback!s}"
    if i <= 0:
        # The first argument stays on the same line as the opening parenthesis.
        return rendered
    return "\n" + " " * indent + rendered
def _emit_args(indent, arguments):
    """Render every argument with ``_emit_arg`` and join the pieces with commas."""
    pieces = [
        _emit_arg(indent, position, argument)
        for position, argument in enumerate(arguments)
    ]
    return ",".join(pieces)
def _emit_ret(ret):
    """Render a single return value by emitting its ``type`` attribute."""
    return_type = ret.type
    return _emit_type(return_type)
def _emit_rets(returns):
    """Render the return annotation of a schema.

    A single return value is rendered on its own; any other number of return
    values is wrapped as ``Tuple[...]`` with the members separated by ", ".
    """
    if len(returns) != 1:
        members = ", ".join(_emit_ret(r) for r in returns)
        return f"Tuple[{members}]"
    return _emit_ret(returns[0])
def _emit_schema(mod, name, schema, arg_start=0, padding=4):
    """Render ``schema`` as a Python-like declaration ``mod.name(args) -> rets``.

    Args:
        mod: Module name used as a prefix, or None to emit ``name`` unqualified.
        name: Function or method name to display.
        schema: Object exposing ``arguments`` and ``returns``.
        arg_start: Index of the first argument to render (earlier ones are skipped).
        padding: Extra spaces added to the indentation of wrapped arguments.

    Returns:
        The rendered declaration string.
    """
    qualified_name = name if mod is None else f"{mod}.{name}"
    # Wrapped arguments are indented past "<qualified_name>(" plus the padding.
    hanging_indent = len(qualified_name) + 1 + padding
    args_str = _emit_args(hanging_indent, schema.arguments[arg_start:])
    rets_str = _emit_rets(schema.returns)
    return f"{qualified_name}({args_str}) -> {rets_str}"
def _get_tensor_ops():
    """Collect declarations for ``aten::`` operators that match ``torch.Tensor`` attributes.

    Returns:
        A ``(section title, declarations)`` pair where ``declarations`` is a
        list of strings produced by ``_emit_schema`` with the leading ``self``
        argument omitted.
    """

    def is_tensor_method(schema):
        # A method schema has a first argument named "self" whose type is a
        # subtype of the Tensor type.
        if not schema.arguments:
            return False
        receiver = schema.arguments[0]
        if receiver.name != "self":
            return False
        return bool(receiver.type.isSubtypeOf(torch._C.TensorType.get()))

    methods = []
    # discover methods
    for attr_name in dir(torch.Tensor):
        if _hidden(attr_name):
            continue
        overloads = torch._C._jit_get_schemas_for_operator(f"aten::{attr_name}")
        methods.extend(
            _emit_schema("Tensor", attr_name, overload, arg_start=1)
            for overload in overloads
            if is_tensor_method(overload)
        )

    return "Supported Tensor Methods", methods
def _get_nn_functional_ops():
    """Collect declarations for scriptable ``torch.nn.functional`` functions and builtins.

    Two sources are scanned:

    1. Every function in ``torch.nn.functional`` whose name does not start with
       an underscore and whose defining module name contains
       ``"torch.nn.functional"``. Each is compiled with ``torch.jit.script`` and
       its schema is emitted; functions that fail to compile are skipped.
    2. Every non-hidden attribute of the modules listed in
       ``torch.jit._builtins._modules_containing_builtins`` that
       ``_find_builtin`` maps to an operator name.

    Returns:
        A ``(section title, declarations)`` pair where ``declarations`` is a
        list of strings produced by ``_emit_schema``.

    Raises:
        RuntimeError: If the defining module of a candidate function cannot be
            determined.
    """
    functions = []

    # Iterate over torch.nn.functional
    mod = torch.nn.functional
    name = mod.__name__
    for elem in dir(mod):
        attr = getattr(mod, elem)
        if not inspect.isfunction(attr) or elem.startswith("_"):
            # Ignore non-functions and internal methods
            continue

        attr_module = inspect.getmodule(attr)
        if not attr_module:
            raise RuntimeError(f"Module for {attr} not found")

        if "torch.nn.functional" not in attr_module.__name__:
            # Ignore functions from outside torch.nn.functional
            continue

        try:
            # compile fn, get schema
            scripted = torch.jit.script(attr)
            scripted_schema = scripted.schema
            functions.append(_emit_schema(name, elem, scripted_schema))
        except Exception:
            # Skip interpolate / boolean dispatched things. Only Exception is
            # caught so KeyboardInterrupt / SystemExit still propagate.
            continue

    # Iterate over modules that we know contain a lot of builtins
    for mod in torch.jit._builtins._modules_containing_builtins:
        name = mod.__name__
        for elem in dir(mod):
            # remove _tan but not __and__
            if _hidden(elem):
                continue
            builtin = _find_builtin(getattr(mod, elem))
            if builtin is None:
                continue
            functions.extend(
                _emit_schema(name, elem, schema)
                for schema in torch._C._jit_get_schemas_for_operator(builtin)
            )

    return "Supported PyTorch Functions", functions
def _get_builtins_helper():
    """Return the ``(fn, builtin_name)`` entries of the builtin-op table worth documenting.

    An entry from ``torch.jit._builtins._builtin_ops`` is dropped when ``fn``
    has no ``__name__``, when its module cannot be determined, when its name,
    qualified name or module name is hidden, or when its module name contains
    ``"torch._C"``.
    """

    def is_documentable(fn):
        owner = inspect.getmodule(fn)
        if not hasattr(fn, "__name__"):
            # typing classes
            return False
        if not owner:
            return False
        # skip internal-only methods
        if _hidden(fn.__name__):
            return False
        if _hidden(fn.__qualname__):
            return False
        if _hidden(owner.__name__):
            return False
        return "torch._C" not in owner.__name__

    return [
        (fn, _builtin_name)
        for fn, _builtin_name in torch.jit._builtins._builtin_ops
        if is_documentable(fn)
    ]
def _is_math_fn(fn):
    """Return True when ``fn`` is defined in the module named ``math``.

    Raises:
        RuntimeError: If ``inspect.getmodule`` cannot find a module for ``fn``.
    """
    owner = inspect.getmodule(fn)
    if owner:
        return owner.__name__ == "math"
    raise RuntimeError(f"Module for {fn} not found")
def _get_torchscript_builtins():
    """Collect declarations for the registered builtins that are not ``math`` functions.

    Returns:
        A ``(section title, declarations)`` pair where ``declarations`` is a
        list of strings produced by ``_emit_schema``.

    Raises:
        RuntimeError: If the module of a builtin cannot be determined.
    """
    # Filter eagerly so that every entry is classified before any lookup runs.
    candidates = [
        entry for entry in _get_builtins_helper() if not _is_math_fn(entry[0])
    ]

    functions = []
    # Iterate over the specially added builtins
    for fn, _builtin_name in candidates:
        owner = inspect.getmodule(fn)
        if not owner:
            raise RuntimeError(f"Module for {fn} not found")
        op_name = _find_builtin(fn)
        if op_name is None:
            continue
        functions.extend(
            _emit_schema(owner.__name__, fn.__name__, overload)
            for overload in torch._C._jit_get_schemas_for_operator(op_name)
        )

    return "TorchScript Builtin Functions", functions
def _get_math_builtins():
    """Collect declarations for the registered builtins that come from ``math``.

    Overloads whose rendered declaration mentions ``Tensor`` are skipped.

    Returns:
        A ``(section title, declarations)`` pair where ``declarations`` is a
        list of strings produced by ``_emit_schema``.

    Raises:
        RuntimeError: If the module of a builtin cannot be determined.
    """
    functions = []
    builtins_list = [
        entry for entry in _get_builtins_helper() if _is_math_fn(entry[0])
    ]

    # Iterate over the specially added builtins
    for fn, _builtin_name in builtins_list:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f"Module for {fn} not found")
        builtin = _find_builtin(fn)
        if builtin is None:
            continue
        for schema in torch._C._jit_get_schemas_for_operator(builtin):
            schema_str = _emit_schema(mod.__name__, fn.__name__, schema)
            if "Tensor" in schema_str:
                # Skip Tensor ops that have the same name as math functions
                # (they will show up in the tensor methods section)
                continue
            # Append the rendered declaration, not the raw schema object, so
            # this section is formatted like every other section.
            functions.append(schema_str)

    return "``math`` Module", functions
def _get_global_builtins():
    """Build the reStructuredText section describing supported Python builtins.

    Each name in ``supported_builtins`` is mapped to an operator name
    (``aten::<name>`` unless overridden in ``op_renames``). Builtins that have
    registered schemas are listed in a literal block; the others are listed in
    a CSV table together with a short explanation. A second CSV table lists the
    magic method each builtin corresponds to.

    Returns:
        A ``(section title, section body)`` pair; the body is a single string.

    Raises:
        KeyError: If a builtin has no schemas and no entry in
            ``schemaless_op_explanations``.
    """
    # Taken from the 'globals' map in torch/csrc/jit/frontend/ir_emitter.cpp
    supported_builtins = [
        "print",
        "tuple",
        "float",
        "complex",
        "int",
        "bool",
        "str",
        "getattr",
        "hasattr",
        "isinstance",
        "len",
        "hex",
        "oct",
        "round",
        "hash",
        "min",
        "max",
        "abs",
        "all",
        "divmod",
        "list",
        "ord",
        "chr",
        "bin",
        "range",
        "zip",
        "enumerate",
        "sorted",
    ]

    # Builtins whose operator name is not simply "aten::<name>".
    op_renames = {
        "bool": "aten::Bool",
        "int": "aten::Int",
        "float": "aten::Float",
        "complex": "aten::Complex",
        "abs": "prim::abs",
        "max": "prim::max",
        "min": "prim::min",
        "range": "fake::does_not_exist",
    }

    schemaless_op_explanations = {
        "print": "Print any value",
        "tuple": "Lists cannot be converted to tuples with this method since their size is not statically known",
        "getattr": "Attribute name must be a literal string",
        "hasattr": "Attribute name must be a literal string",
        "isinstance": "Result is static",
        "zip": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
        "enumerate": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
        "range": "Can only be used as an iterator in a for loop",
    }

    magic_methods = [
        ("complex", "__complex__"),
        ("float", "__float__"),
        ("int", "__int__"),
        ("bool", "__bool__"),
        ("str", "__str__"),
        ("len", "__len__"),
        ("hex", "__hex__"),
        ("oct", "__oct__"),
    ]

    magic_methods_rows = [
        f'"{fn}", "``{magic_method}``"' for fn, magic_method in magic_methods
    ]

    schematized_ops = []
    schemaless_ops = []

    for fn in supported_builtins:
        op_name = op_renames.get(fn, f"aten::{fn}")
        schemas = torch._C._jit_get_schemas_for_operator(op_name)
        if schemas:
            schematized_ops.extend(
                _emit_schema(None, fn, s, padding=0) for s in schemas
            )
            # Empty entry: leaves a blank line between groups of overloads.
            schematized_ops.append("")
        else:
            # Direct indexing on purpose: a missing explanation raises KeyError
            # instead of silently producing an empty table cell.
            table_row = f'":any:`{fn}`", "{schemaless_op_explanations[fn]}"'
            schemaless_ops.append(table_row)

    schematized_ops_str = textwrap.indent("\n".join(schematized_ops), "\t")
    schemaless_ops_str = textwrap.indent("\n".join(schemaless_ops), "\t")
    magic_methods_rows_str = textwrap.indent("\n".join(magic_methods_rows), "\t")

    # reStructuredText needs a blank line before each directive, between a
    # directive's options and its content, and around the "::" literal-block
    # marker; directive options must be indented under the directive.
    section = f"""
The functions in the following table are supported but do not have a static schema

.. csv-table::
    :header: "Function", "Note"

{schemaless_ops_str}

The following functions will use the corresponding magic method on :any:`TorchScript classes`

.. csv-table::
    :header: "Function", "Magic Method"

{magic_methods_rows_str}

These built-in functions use the schema

.. rst-class:: codeblock-height-limiter

::

{schematized_ops_str}
"""

    return "Python Built-in Functions", section
def _list_supported_ops():
    """Assemble the full reStructuredText listing of supported operators.

    Every gathering function returns a ``(header, items)`` pair. ``items`` is
    either a ready-made section body (``str``) or a list of declarations that
    is rendered as a literal block. Each section is preceded by a link target
    derived from its header.

    Returns:
        The concatenated sections as a single string.
    """

    def emit_block(decls):
        # Each declaration is indented by four spaces. This matches the default
        # ``padding=4`` that ``_emit_schema`` adds to wrapped arguments, so the
        # continuation lines line up under the first argument.
        return "\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{}\n".format(
            "".join(f"    {d}\n\n" for d in decls)
        )

    op_gathering_fns = (
        _get_tensor_ops,
        _get_nn_functional_ops,
        _get_torchscript_builtins,
        _get_global_builtins,
        _get_math_builtins,
    )

    sections = []
    for fn in op_gathering_fns:
        header, items = fn()
        # Link target: header without backticks and hyphens, lower-cased, with
        # spaces turned into hyphens.
        link_target = header.replace("`", "").replace("-", "").lower().replace(" ", "-")
        underline = "~" * len(header)
        if isinstance(items, str):
            section = f"{header}\n{underline}\n{items}\n"
        else:
            section = f"{header}\n{underline}\n{emit_block(items)}"
        sections.append(f".. _{link_target}:\n\n{section}")

    return "".join(sections)
# Build the listing once at import time and install it as this module's
# docstring, which is what sphinx autodoc and help() display.
__doc__ = _list_supported_ops()