Spaces:
Sleeping
Sleeping
""" | |
The weak_script annotation needs to be here instead of inside torch/jit/ so it | |
can be used in other places in torch/ (namely torch.nn) without running into | |
circular dependency problems | |
""" | |
import ast | |
import builtins | |
import collections | |
import contextlib | |
import enum | |
import inspect | |
import io | |
import pickle | |
import sys | |
import threading | |
import types | |
import typing | |
import warnings | |
import weakref | |
from textwrap import dedent | |
from typing import ( # noqa: F401 | |
Any, | |
Callable, | |
Dict, | |
Final, | |
ForwardRef, | |
Generic, | |
get_args, # new in 3.8 | |
get_origin, # new in 3.8 | |
List, | |
Optional, | |
Tuple, | |
Type, | |
TypeVar, | |
Union, | |
) | |
import torch | |
# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`. | |
# Explicitly ask to import `torch.distributed.__init__` first. | |
# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised. | |
import torch.distributed.rpc | |
import torch.package._mangling as package_mangling | |
from torch._awaits import _Await | |
from torch._C import _Await as CAwait, Future as CFuture | |
from torch._sources import fake_range, get_source_lines_and_file, parse_def | |
from torch.futures import Future | |
IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9) | |
IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10) | |
BuiltinUnionType: Union[Type, Tuple[Type, ...]] | |
if sys.version_info >= (3, 10): | |
# NOTE: IS_PY310_PLUS doesn't work with mypy. | |
# cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks | |
BuiltinUnionType = types.UnionType | |
else: | |
BuiltinUnionType = () # trick: this makes isinstance short circuit. | |
LockType: Type | |
try: | |
import _thread | |
LockType = _thread.LockType | |
except ImportError: | |
import _dummy_thread # type: ignore[import-not-found] | |
LockType = _dummy_thread.LockType | |
# Wrapper functions that can call either of 2 functions depending on a boolean | |
# argument | |
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = ( | |
weakref.WeakKeyDictionary() | |
) # noqa: T484 | |
FAKE_FILENAME_PREFIX = "__torch_jit_dataclass" | |
class SourceLoader: | |
def __init__(self): | |
self.content = {} | |
def cache(self, fn, source): | |
self.content[fn] = source | |
def get_source(self, fn): | |
return self.content.get(fn) | |
loader = SourceLoader() | |
def createResolutionCallbackFromEnv(lookup_base): | |
""" | |
Creates a resolution callback that will look up qualified names in an | |
environment, starting with `lookup_base` for the base of any qualified | |
names, then proceeding down the lookup chain with the resolved object. | |
You should not use this directly, it should only be used from the other | |
createResolutionCallbackFrom* functions. | |
""" | |
def lookupInModule(qualified_name, module): | |
if "." in qualified_name: | |
base, remaining_pieces = qualified_name.split(".", maxsplit=1) | |
module_value = getattr(module, base) | |
return lookupInModule(remaining_pieces, module_value) | |
else: | |
return getattr(module, qualified_name) | |
def parseNestedExpr(expr, module) -> Tuple[Any, int]: | |
i = 0 | |
while i < len(expr) and expr[i] not in (",", "[", "]"): | |
i += 1 | |
# Special case logic for the empty Tuple as a subscript (used | |
# in the type annotation `Tuple[()]`) | |
if expr[:i] == "()": | |
return (), i | |
base = lookupInModule(expr[:i].strip(), module) | |
assert base is not None, f"Unresolvable type {expr[:i]}" | |
if i == len(expr) or expr[i] != "[": | |
return base, i | |
assert expr[i] == "[" | |
parts = [] | |
while expr[i] != "]": | |
part_len = 0 | |
i += 1 | |
part, part_len = parseNestedExpr(expr[i:], module) | |
parts.append(part) | |
i += part_len | |
if len(parts) > 1: | |
return base[tuple(parts)], i + 1 | |
else: | |
return base[parts[0]], i + 1 | |
def parseExpr(expr, module): | |
try: | |
value, len_parsed = parseNestedExpr(expr, module) | |
assert len_parsed == len( | |
expr | |
), "whole expression was not parsed, falling back to c++ parser" | |
return value | |
except Exception: | |
""" | |
The python resolver fails in several cases in known unit tests, and is intended | |
to fall back gracefully to the c++ resolver in general. For example, python 2 style | |
annotations which are frequent in our unit tests often fail with types e.g. int not | |
resolvable from the calling frame. | |
""" | |
return None | |
return lambda expr: parseExpr(expr, lookup_base) | |
def createResolutionCallbackFromFrame(frames_up: int = 0): | |
""" | |
Creates a function which, given a string variable name, | |
returns the value of the variable in the scope of the caller of | |
the function which called createResolutionCallbackFromFrame (by default). | |
This is used to enable access in-scope Python variables inside | |
TorchScript fragments. | |
frames_up is number of additional frames to go up on the stack. | |
The default value is 0, which correspond to the frame of the caller | |
of createResolutionCallbackFromFrame. Also for example, if frames_up is set | |
to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame | |
will be taken. | |
For example, the following program prints 2:: | |
def bar(): | |
cb = createResolutionCallbackFromFrame(1) | |
print(cb("foo")) | |
def baz(): | |
foo = 2 | |
bar() | |
baz() | |
""" | |
frame = inspect.currentframe() | |
i = 0 | |
while i < frames_up + 1: | |
assert frame is not None | |
frame = frame.f_back | |
i += 1 | |
assert frame is not None | |
f_locals = frame.f_locals | |
f_globals = frame.f_globals | |
class env: | |
def __getattr__(self, key): | |
if key in f_locals: | |
return f_locals[key] | |
elif key in f_globals: | |
return f_globals[key] | |
elif key in dir(builtins): | |
return getattr(builtins, key) | |
return createResolutionCallbackFromEnv(env()) | |
def get_closure(fn): | |
""" | |
Get a dictionary of closed over variables from a function | |
""" | |
captures = {} | |
captures.update(fn.__globals__) | |
for index, captured_name in enumerate(fn.__code__.co_freevars): | |
captures[captured_name] = fn.__closure__[index].cell_contents | |
return captures | |
# [local resolution in python] | |
# Depending on where a variable is defined, and where it is used, we may | |
# or may not be able to recover its value when recursively compiling a | |
# script function. Remember in the general case, a module or function is | |
# first defined and then later scripted. This means we do not have a | |
# chance to capture the active frames when the function is defined. Hence any | |
# name resolution has to happen later on the created closure. The way | |
# python captures type annotations restricts what we can recover. The | |
# follow example illustrates the different cases: | |
# | |
# class MyGlobalClass: | |
# ... | |
# def my_local_scope(): | |
# @torch.jit.script | |
# class MyClass: | |
# ... | |
# @torch.jit.script | |
# class MyClassUsedAsVar: | |
# ... | |
# def eg(x: MyClass, y: MyGlobalClass): | |
# a_local_capture : Foo | |
# return MyClassUsedAsVar(x) | |
# | |
# MyGlobalClass is defined in the __globals__ dictionary of function | |
# 'eg', so it is always recoverable. my_local_scope introduces a new local | |
# variable scope in the function. Classes defined here are only visible as | |
# local variables. For the case of MyClassUsedAsVar, it is captured | |
# because it is used as a variable inside the body of the function, and we | |
# can resolve it using the captures returned from `get_closure`. However, | |
# the type annotations are not captured by the closure. In Python | |
# 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be available as | |
# annotations on `eg``, but starting in Python 4.0, they will represented as | |
# strings and no longer present. Furthermore, since the body of `eg` does | |
# not reference those names, they do not appear in the list of closed over | |
# variables. In Python 2.x, type annotations are in comments, leading to a | |
# similar situation where their definitions are not available. We anticipate | |
# that most users will not run into this issue because their modules and | |
# functions will be defined at a global scope like MyGlobalClass. In cases | |
# where they are not, it is possible to work around issues by declaring the | |
# values global in the function. | |
# In Python 3.9 declaring class as global will make it invisible to | |
# `inspect.getsource`, see https://bugs.python.org/issue42666 . | |
# This could be worked around by manualy adding it to `global()` dictionary. | |
def createResolutionCallbackFromClosure(fn): | |
""" | |
Create a resolutionCallback by introspecting the function instead of | |
looking up the stack for the enclosing scope | |
""" | |
closure = get_closure(fn) | |
class closure_lookup: | |
# This is a class since `closure` is a dict and it's easier in | |
# `env_helper` if everything just works with `getattr` calls | |
def __getattr__(self, key): | |
if key in closure: | |
return closure[key] | |
elif hasattr(typing, key): | |
return getattr(typing, key) | |
elif hasattr(builtins, key): | |
return getattr(builtins, key) | |
return None | |
return createResolutionCallbackFromEnv(closure_lookup()) | |
def can_compile_class(cls) -> bool: | |
# If any of the functions on a type don't have a code object, this type can't | |
# be compiled and is probably a builtin / bound from C | |
if is_ignored_fn(cls): | |
return False | |
# Ignore the following list of built-in classes. | |
ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception) | |
if issubclass(cls, ignored_builtin_classes): | |
return False | |
names = cls.__dict__ | |
fns = [ | |
getattr(cls, name) | |
for name in names | |
if inspect.isroutine(getattr(cls, name, None)) | |
] | |
has_code = [hasattr(fn, "__code__") for fn in fns] | |
return all(has_code) | |
def get_callable_argument_names(fn) -> List[str]: | |
""" | |
Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`. | |
Returns an empty list when other types of arguments are present. | |
This is used by `torch.jit.trace` to assign meaningful argument names to | |
traced functions and modules. | |
Args: | |
fn: A callable. | |
Returns: | |
Argument names: List[str] | |
""" | |
# inspect.signature may fail, give up in that case. | |
try: | |
callable_signature = inspect.signature(fn) | |
except Exception: | |
return [] | |
argument_names = [] | |
for name, param in callable_signature.parameters.items(): | |
# All four other types of arguments do not map to individual values | |
# with a keyword as name. | |
if not param.kind == param.POSITIONAL_OR_KEYWORD: | |
continue | |
argument_names.append(name) | |
return argument_names | |
def get_annotation_str(annotation): | |
""" | |
Convert an AST node containing a type annotation to the string present in the source | |
that represents the same annotation. | |
""" | |
if isinstance(annotation, ast.Name): | |
return annotation.id | |
elif isinstance(annotation, ast.Attribute): | |
return ".".join([get_annotation_str(annotation.value), annotation.attr]) | |
elif isinstance(annotation, ast.Subscript): | |
# In Python3.9+ subscript indicies are not wrapped in ast.Index | |
subscript_slice = annotation.slice if IS_PY39_PLUS else annotation.slice.value # type: ignore[attr-defined] | |
return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]" | |
elif isinstance(annotation, ast.Tuple): | |
return ",".join([get_annotation_str(elt) for elt in annotation.elts]) | |
elif isinstance(annotation, (ast.Constant, ast.NameConstant)): | |
return f"{annotation.value}" | |
# If an AST node is not handled here, it's probably handled in ScriptTypeParser. | |
return None | |
def get_type_hint_captures(fn): | |
""" | |
Get a dictionary containing type resolution mappings necessary to resolve types | |
for the literal annotations on 'fn'. These are not considered to be closed-over by fn | |
and must be obtained separately (e.g. using this function). | |
Args: | |
fn: A callable. | |
Returns: | |
A Dict[str, Any] containing a mapping from the literal annotations used on | |
fn to the Python objects they refer to. | |
""" | |
# First, try to get the source of the function. We'll need to parse it to find the actual string names | |
# that were used to annotate the types, since inspect.signature() will only return the class object that | |
# the annotation refers to, not the string name. If we can't get the source, simply return an empty dict. | |
# This may happen in cases where the function is synthesized dynamically at runtime. | |
src = loader.get_source(fn) | |
if src is None: | |
src = inspect.getsource(fn) | |
# Gather a dictionary of parameter name -> type, skipping any parameters whose annotated | |
# types are strings. These are only understood by TorchScript in the context of a type annotation | |
# that refers to a class in its own definition, but trying to include a mapping for this in the result | |
# function would cause infinite recursion because the class is currently being compiled. | |
# In addition, there is logic in ScriptTypeParser to handle this. | |
signature = inspect.signature(fn) | |
name_to_type = { | |
name: parameter.annotation | |
for name, parameter in signature.parameters.items() | |
if parameter.annotation is not inspect.Parameter.empty | |
and not isinstance(parameter.annotation, str) | |
} | |
# Then, get the literal type annotations from the function declaration | |
# by source inspection. This accounts for the case in which aliases are used | |
# to annotate the arguments (e.g device_t = torch.device, and then d: device_t). | |
# frontend.py cannot be used here because it includes _jit_internal, so use ast instead. | |
a = ast.parse(dedent(src)) | |
if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef): | |
raise RuntimeError(f"Expected {fn} to be a function") | |
f = a.body[0] | |
# Prepare a dictionary of source annotation -> type, which will be the final result of this function, | |
# by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping | |
# them to the type object corresponding to the annotation via name_to_type using the parameter name. | |
annotation_to_type = {} | |
for arg in f.args.args: | |
# Get the source type annotation string for this argument if possible. | |
arg_annotation_str = ( | |
get_annotation_str(arg.annotation) if arg.annotation else None | |
) | |
# If the argument has no annotation or get_annotation_str cannot convert it to a string, | |
# arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle | |
# this in the latter case. | |
if arg_annotation_str is None: | |
continue | |
# Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not | |
# be present in name_to_type is that the annotation itself is a string and not a type object | |
# (common for self-refential annotations in classes). Once again, let ScriptTypeParser handle this. | |
arg_name = arg.arg | |
if arg_name in name_to_type: | |
annotation_to_type[arg_annotation_str] = name_to_type[arg_name] | |
# If there is a valid return annotation, include it in annotation_to_type. As with argument annotations, | |
# the literal annotation has to be convertible to a string by get_annotation_str, and the actual type | |
# of the annotation cannot be a string. | |
literal_return_annotation = get_annotation_str(f.returns) | |
valid_literal_annotation = literal_return_annotation is not None | |
return_annotation = signature.return_annotation | |
valid_return_annotation_type = ( | |
return_annotation is not inspect.Parameter.empty | |
and not isinstance(return_annotation, str) | |
) | |
if valid_literal_annotation and valid_return_annotation_type: | |
annotation_to_type[literal_return_annotation] = return_annotation | |
return annotation_to_type | |
def createResolutionCallbackForClassMethods(cls): | |
""" | |
This looks at all the methods defined in a class and pulls their closed-over | |
variables into a dictionary and uses that to resolve variables. | |
""" | |
# cls is a type here, so `ismethod` is false since the methods on the type | |
# aren't bound to anything, so Python treats them as regular functions | |
fns = [ | |
getattr(cls, name) | |
for name in cls.__dict__ | |
if inspect.isroutine(getattr(cls, name)) | |
] | |
# Skip built-ins, as they do not have global scope nor type hints | |
# Needed to support `enum.Enum` derived classes in Python-3.11 | |
# That adds `_new_member_` property which is an alias to `__new__` | |
fns = [fn for fn in fns if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")] | |
captures = {} | |
for fn in fns: | |
captures.update(get_closure(fn)) | |
captures.update(get_type_hint_captures(fn)) | |
def lookup_in_class(key): | |
if key in captures: | |
return captures[key] | |
else: | |
return getattr(builtins, key, None) | |
return lookup_in_class | |
def boolean_dispatch( | |
arg_name, arg_index, default, if_true, if_false, module_name, func_name | |
): | |
""" | |
Dispatches to either of 2 script functions based on a boolean argument. | |
In TorchScript, the boolean argument must be constant so that the correct | |
function to use can be determined at compile time. | |
""" | |
def fn(*args, **kwargs): | |
dispatch_flag = default | |
if arg_name in kwargs: | |
dispatch_flag = kwargs[arg_name] | |
elif arg_index < len(args): | |
dispatch_flag = args[arg_index] | |
if dispatch_flag: | |
return if_true(*args, **kwargs) | |
else: | |
return if_false(*args, **kwargs) | |
if if_true.__doc__ is None and if_false.__doc__ is not None: | |
doc = if_false.__doc__ | |
if_true.__doc__ = doc | |
elif if_false.__doc__ is None and if_true.__doc__ is not None: | |
doc = if_true.__doc__ | |
if_false.__doc__ = doc | |
elif if_false.__doc__ is None and if_true.__doc__ is None: | |
# neither function has a docstring | |
doc = None | |
else: | |
raise RuntimeError("only one function can have a docstring") | |
fn.__doc__ = doc | |
if module_name is not None: | |
fn.__module__ = module_name | |
if func_name is not None: | |
fn.__name__ = func_name | |
boolean_dispatched[fn] = { | |
"if_true": if_true, | |
"if_false": if_false, | |
"index": arg_index, | |
"default": default, | |
"arg_name": arg_name, | |
} | |
return fn | |
class FunctionModifiers: | |
""" | |
Used to denote the behavior of a function in TorchScript. See export() and | |
ignore() for details. | |
""" | |
UNUSED = "unused (ignored and replaced with raising of an exception)" | |
IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)" | |
EXPORT = "export (compile this function even if nothing calls it)" | |
DEFAULT = "default (compile if called from a exported function / forward)" | |
COPY_TO_SCRIPT_WRAPPER = ( | |
"if this method is not scripted, copy the python method onto the scripted model" | |
) | |
_DROP = "_drop (function is fully ignored, declaration can be unscriptable)" | |
def export(fn): | |
""" | |
This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a | |
:class:`ScriptModule` and should be compiled. | |
``forward`` implicitly is assumed to be an entry point, so it does not need this decorator. | |
Functions and methods called from ``forward`` are compiled as they are seen | |
by the compiler, so they do not need this decorator either. | |
Example (using ``@torch.jit.export`` on a method): | |
.. testcode:: | |
import torch | |
import torch.nn as nn | |
class MyModule(nn.Module): | |
def implicitly_compiled_method(self, x): | |
return x + 99 | |
# `forward` is implicitly decorated with `@torch.jit.export`, | |
# so adding it here would have no effect | |
def forward(self, x): | |
return x + 10 | |
@torch.jit.export | |
def another_forward(self, x): | |
# When the compiler sees this call, it will compile | |
# `implicitly_compiled_method` | |
return self.implicitly_compiled_method(x) | |
def unused_method(self, x): | |
return x - 20 | |
# `m` will contain compiled methods: | |
# `forward` | |
# `another_forward` | |
# `implicitly_compiled_method` | |
# `unused_method` will not be compiled since it was not called from | |
# any compiled methods and wasn't decorated with `@torch.jit.export` | |
m = torch.jit.script(MyModule()) | |
""" | |
fn._torchscript_modifier = FunctionModifiers.EXPORT | |
return fn | |
def unused(fn): | |
""" | |
This decorator indicates to the compiler that a function or method should | |
be ignored and replaced with the raising of an exception. This allows you | |
to leave code in your model that is not yet TorchScript compatible and still | |
export your model. | |
Example (using ``@torch.jit.unused`` on a method):: | |
import torch | |
import torch.nn as nn | |
class MyModule(nn.Module): | |
def __init__(self, use_memory_efficient): | |
super().__init__() | |
self.use_memory_efficient = use_memory_efficient | |
@torch.jit.unused | |
def memory_efficient(self, x): | |
import pdb | |
pdb.set_trace() | |
return x + 10 | |
def forward(self, x): | |
# Use not-yet-scriptable memory efficient mode | |
if self.use_memory_efficient: | |
return self.memory_efficient(x) | |
else: | |
return x + 10 | |
m = torch.jit.script(MyModule(use_memory_efficient=False)) | |
m.save("m.pt") | |
m = torch.jit.script(MyModule(use_memory_efficient=True)) | |
# exception raised | |
m(torch.rand(100)) | |
""" | |
if isinstance(fn, property): | |
prop = fn | |
setattr( # noqa: B010 | |
prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED | |
) | |
if prop.fset: | |
setattr( # noqa: B010 | |
prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED | |
) | |
return prop | |
fn._torchscript_modifier = FunctionModifiers.UNUSED | |
return fn | |
# No op context manager from python side | |
class _IgnoreContextManager(contextlib.AbstractContextManager): | |
def __init__(self, **kwargs): | |
pass | |
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: | |
pass | |
def ignore(drop=False, **kwargs): | |
""" | |
This decorator indicates to the compiler that a function or method should | |
be ignored and left as a Python function. This allows you to leave code in | |
your model that is not yet TorchScript compatible. If called from TorchScript, | |
ignored functions will dispatch the call to the Python interpreter. Models with ignored | |
functions cannot be exported; use :func:`@torch.jit.unused <torch.jit.unused>` instead. | |
Example (using ``@torch.jit.ignore`` on a method):: | |
import torch | |
import torch.nn as nn | |
class MyModule(nn.Module): | |
@torch.jit.ignore | |
def debugger(self, x): | |
import pdb | |
pdb.set_trace() | |
def forward(self, x): | |
x += 10 | |
# The compiler would normally try to compile `debugger`, | |
# but since it is `@ignore`d, it will be left as a call | |
# to Python | |
self.debugger(x) | |
return x | |
m = torch.jit.script(MyModule()) | |
# Error! The call `debugger` cannot be saved since it calls into Python | |
m.save("m.pt") | |
Example (using ``@torch.jit.ignore(drop=True)`` on a method): | |
.. testcode:: | |
import torch | |
import torch.nn as nn | |
class MyModule(nn.Module): | |
@torch.jit.ignore(drop=True) | |
def training_method(self, x): | |
import pdb | |
pdb.set_trace() | |
def forward(self, x): | |
if self.training: | |
self.training_method(x) | |
return x | |
m = torch.jit.script(MyModule()) | |
# This is OK since `training_method` is not saved, the call is replaced | |
# with a `raise`. | |
m.save("m.pt") | |
.. testcleanup:: | |
import os | |
os.remove('m.pt') | |
""" | |
if callable(drop): | |
# used without any args, so drop is actually a function | |
# @torch.jit.ignore | |
# def fn(...): | |
fn = drop | |
fn._torchscript_modifier = FunctionModifiers.IGNORE | |
return fn | |
if not isinstance(drop, bool): | |
raise RuntimeError( | |
"Argument to @torch.jit.ignore must be a bool or " | |
f"a function but got {drop}" | |
) | |
# for backwards compat | |
drop_on_export = kwargs.pop("drop_on_export", None) | |
if drop_on_export: | |
warnings.warn( | |
"ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function " | |
"call on compilation. Use torch.jit.unused now. {}", | |
category=FutureWarning, | |
) | |
drop = drop_on_export | |
elif drop: | |
warnings.warn( | |
"ignore(True) has been deprecated. TorchScript will now drop the function " | |
"call on compilation. Use torch.jit.unused now. {}", | |
category=FutureWarning, | |
) | |
def decorator(fn): | |
if drop: | |
fn._torchscript_modifier = FunctionModifiers.UNUSED | |
else: | |
fn._torchscript_modifier = FunctionModifiers.IGNORE | |
return fn | |
return decorator | |
def _drop(fn): | |
fn._torchscript_modifier = FunctionModifiers._DROP | |
return fn | |
def _copy_to_script_wrapper(fn): | |
fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER | |
return fn | |
def module_has_exports(mod): | |
for name in dir(mod): | |
if hasattr(mod, name): | |
item = getattr(mod, name) | |
if callable(item): | |
if get_torchscript_modifier(item) is FunctionModifiers.EXPORT: | |
return True | |
return False | |
# WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you | |
# rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to | |
# allow JIT'd code to still be covered. | |
def should_drop(fn) -> bool: | |
attr = get_torchscript_modifier(fn) | |
if attr is None: | |
return False | |
return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP | |
def is_ignored_fn(fn) -> bool: | |
mod = get_torchscript_modifier(fn) | |
return ( | |
mod is FunctionModifiers.UNUSED | |
or mod is FunctionModifiers.IGNORE | |
or mod is FunctionModifiers._DROP | |
) | |
def _is_drop_fn(fn) -> bool: | |
mod = get_torchscript_modifier(fn) | |
return mod is FunctionModifiers._DROP | |
def is_static_fn(cls, fn) -> bool: | |
return isinstance(inspect.getattr_static(cls, fn, default=None), staticmethod) | |
def get_static_fn(cls, fn): | |
return inspect.getattr_static(cls, fn).__func__ | |
def get_torchscript_modifier(fn): | |
if not callable(fn): | |
return None | |
if hasattr(fn, "__func__"): | |
fn = fn.__func__ | |
return getattr(fn, "_torchscript_modifier", FunctionModifiers.DEFAULT) | |
def copy_torchscript_modifier(orig, new) -> None: | |
attr = get_torchscript_modifier(orig) | |
if attr is None: | |
return | |
new._torchscript_modifier = attr | |
# overloading registration | |
# overloads get registered in this file, and compiled in torch/jit/__init__.py | |
# so that they can be imported in nn/functional.py without an import cycle | |
# qualified_name => list[overload_functions] | |
_overloaded_fns: Dict[str, List[Callable]] = {} # noqa: T484 | |
_OVERLOAD_EXAMPLE = """ | |
Example usage of overload function: | |
@torch.jit._overload | |
def my_function(x: type0) -> type0: # decl 1 | |
pass | |
@torch.jit._overload | |
def my_function(x: type1) -> type1: # decl 2 | |
pass | |
def my_function(x): # implementation | |
if isinstance(x, type0): | |
return x | |
elif isinstance(x, type1): | |
return x | |
""" | |
def get_overload_no_implementation_error_message(kind, obj): | |
sourcelines, file_lineno, filename = get_source_lines_and_file(obj) | |
return ( | |
f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make ' | |
f"sure a definition is provided and defined after all overload declarations.\n" | |
f'File "{filename}", line {file_lineno}:\n' | |
+ "".join(sourcelines) | |
+ "\n" | |
+ _OVERLOAD_EXAMPLE | |
) | |
def _check_overload_body(func): | |
try: | |
parsed_def = parse_def(func) | |
except OSError as e: | |
# Parsing the function definition can raise an OSError if source is unavailable. | |
# Since this is just an initial check, just raise a warning if this is the case. | |
warnings.warn( | |
f"Unable to retrieve source for @torch.jit._overload function: {func}." | |
) | |
return | |
body = parsed_def.ast.body[0].body | |
def is_pass(x): | |
return isinstance(x, ast.Pass) | |
def is_ellipsis(x): | |
return isinstance(x, ast.Expr) and isinstance(x.value, ast.Ellipsis) | |
if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])): | |
msg = ( | |
"Only `pass` statement or `...` can be the body of overload declaration:\n" | |
) | |
msg += "\n".join(parsed_def.source.split("\n")[:3]) | |
msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE | |
raise RuntimeError(msg) | |
def _overload(func): | |
_check_overload_body(func) | |
qual_name = _qualified_name(func) | |
global _overloaded_fns | |
fn_overload_list = _overloaded_fns.get(qual_name) | |
if fn_overload_list is None: | |
fn_overload_list = [] | |
_overloaded_fns[qual_name] = fn_overload_list | |
fn_overload_list.append(func) | |
return func | |
def _get_fn_overloads(qual_name): | |
return _overloaded_fns.get(qual_name) | |
def _clear_fn_overloads(qual_name) -> None: | |
del _overloaded_fns[qual_name] | |
def get_class_name_lineno(method) -> Tuple[str, int]: | |
current_frame = inspect.currentframe() | |
# one for the get_class_name call, one for _overload_method call | |
for i in range(2): | |
assert ( | |
current_frame is not None | |
) # assert current frame is not an Optional[FrameType] | |
current_frame = current_frame.f_back | |
assert current_frame is not None # same here | |
class_name = current_frame.f_code.co_name | |
line_no = current_frame.f_code.co_firstlineno | |
return class_name, line_no | |
# At the point the decorator is applied to class methods the method | |
# has no reference to its owning class. _qualified_name would not include | |
# the class it is defined in, so any methods with the same name in the same file | |
# would have the same _qualified_name, even if they were defined in different | |
# classes. This problem only exists in python 2. | |
# We get around this problem by looking at the stack frame and identifying | |
# the class name, and throwing an error whenever overloads are used | |
# when modules of the same name are in the same file | |
# qualified_name => class name => list[overload_functions] | |
_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {} # noqa: T484 | |
# (qualified_name, class name) => class_fileno | |
_overloaded_method_class_fileno: Dict[Tuple[str, str], int] = {} | |
def _overload_method(func): | |
_check_overload_body(func) | |
qual_name = _qualified_name(func) | |
global _overloaded_methods | |
class_name_map = _overloaded_methods.get(qual_name, None) | |
if class_name_map is None: | |
class_name_map = {} | |
_overloaded_methods[qual_name] = class_name_map | |
class_name, line_no = get_class_name_lineno(func) | |
method_overloads = class_name_map.get(class_name, None) | |
if method_overloads is None: | |
method_overloads = [] | |
class_name_map[class_name] = method_overloads | |
_overloaded_method_class_fileno[(qual_name, class_name)] = line_no | |
else: | |
existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)] | |
if existing_lineno != line_no: | |
raise RuntimeError( | |
"Cannot currently overload the same method name in two different" | |
" classes with the same name in the same module" | |
) | |
method_overloads.append(func) | |
return func | |
def _get_overloaded_methods(method, mod_class): | |
# TODO: __name__ not set for submodules in recursive script | |
if not hasattr(method, "__name__"): | |
return None | |
qual_name = _qualified_name(method) | |
class_name_map = _overloaded_methods.get(qual_name, None) | |
if class_name_map is None: | |
return None | |
overloads = class_name_map.get(mod_class.__name__, None) | |
if overloads is None: | |
return None | |
method_line_no = get_source_lines_and_file(method)[1] | |
mod_class_fileno = get_source_lines_and_file(mod_class)[1] | |
mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0]) | |
if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno): | |
raise Exception( | |
"Overloads are not useable when a module is redeclared within the same file: " | |
+ str(method) | |
) | |
return overloads | |
def is_tuple(ann) -> bool: | |
if ann is Tuple: | |
raise_error_container_parameter_missing("Tuple") | |
# For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule | |
if not hasattr(ann, "__module__"): | |
return False | |
ann_origin = get_origin(ann) | |
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is tuple: | |
return True | |
return ann.__module__ == "typing" and (ann_origin is Tuple or ann_origin is tuple) | |
def is_list(ann) -> bool: | |
if ann is List: | |
raise_error_container_parameter_missing("List") | |
if not hasattr(ann, "__module__"): | |
return False | |
ann_origin = get_origin(ann) | |
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is list: | |
return True | |
return ann.__module__ == "typing" and (ann_origin is List or ann_origin is list) | |
def is_dict(ann) -> bool: | |
if ann is Dict: | |
raise_error_container_parameter_missing("Dict") | |
if not hasattr(ann, "__module__"): | |
return False | |
ann_origin = get_origin(ann) | |
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is dict: | |
return True | |
return ann.__module__ == "typing" and (ann_origin is Dict or ann_origin is dict) | |
def is_union(ann): | |
if ann is Union: | |
raise_error_container_parameter_missing("Union") | |
return isinstance(ann, BuiltinUnionType) or ( | |
hasattr(ann, "__module__") | |
and ann.__module__ == "typing" | |
and (get_origin(ann) is Union) | |
) | |
def is_optional(ann): | |
if ann is Optional: | |
raise_error_container_parameter_missing("Optional") | |
def is_optional_as_optional(ann): | |
return ( | |
hasattr(ann, "__module__") | |
and ann.__module__ == "typing" | |
and (get_origin(ann) is Optional) | |
) | |
def is_union_as_optional(ann): | |
ann_args = get_args(ann) | |
return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args) | |
return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann)) | |
def is_future(ann) -> bool: | |
if ann is Future: | |
raise RuntimeError( | |
"Attempted to use Future without a " | |
"contained type. Please add a contained type, e.g. " | |
"Future[int]" | |
) | |
return get_origin(ann) is Future | |
def is_await(ann) -> bool: | |
if ann is _Await: | |
return True | |
return get_origin(ann) is _Await | |
if torch.distributed.rpc.is_available(): | |
from torch._C._distributed_rpc import PyRRef | |
from torch.distributed.rpc import RRef | |
def is_rref(ann) -> bool: | |
if ann is RRef: | |
raise RuntimeError( | |
"Attempted to use RRef without a " | |
"contained type. Please add a contained type, e.g. " | |
"RRef[int]" | |
) | |
return get_origin(ann) is RRef | |
def is_rref_instance(obj) -> bool: | |
return isinstance(obj, PyRRef) | |
else: | |
def is_rref_instance(obj) -> bool: | |
# If the RPC module doesn't exist then RRefs don't exist either. | |
return False | |
def is_final(ann) -> bool: | |
return ( | |
hasattr(ann, "__module__") | |
and ann.__module__ in {"typing", "typing_extensions"} | |
and (get_origin(ann) is Final or isinstance(ann, type(Final))) | |
) | |
# allows BroadcastingList instance to be subscriptable | |
class BroadcastingListCls: | |
def __getitem__(self, types): | |
return | |
# mypy doesn't support parameters on types, so we have to explicitly type each | |
# list size | |
BroadcastingList1 = BroadcastingListCls() | |
for i in range(2, 7): | |
globals()[f"BroadcastingList{i}"] = BroadcastingList1 | |
def is_scripting() -> bool: | |
r""" | |
Function that returns True when in compilation and False otherwise. This | |
is useful especially with the @unused decorator to leave code in your | |
model that is not yet TorchScript compatible. | |
.. testcode:: | |
import torch | |
@torch.jit.unused | |
def unsupported_linear_op(x): | |
return x | |
def linear(x): | |
if torch.jit.is_scripting(): | |
return torch.linear(x) | |
else: | |
return unsupported_linear_op(x) | |
""" | |
return False | |
# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj. | |
def _qualified_name(obj, mangle_name=True) -> str: | |
# This special case allows us to override the qualified name on a type. | |
# It's currently used in conjunction with tracing, where we create a | |
# fake module to filter only supported attributes. However, since this | |
# new type is defined as a local class, we need a mechanism to override | |
# its qualname so it appears correctly in the TorchScript system. This, | |
# we set '_jit_override_qualname' with the original traced module's | |
# qualified name, which is picked up here | |
if hasattr(obj, "_jit_override_qualname"): | |
return obj._jit_override_qualname | |
# short-circuit in cases where the object already has a known qualified name | |
if isinstance(obj, torch._C.ScriptFunction): | |
return obj.qualified_name | |
if getattr(obj, "__name__", None): | |
name = obj.__name__ | |
# Enum classes do not have `__name__` attr, instead they have `name`. | |
elif isinstance(obj, enum.Enum): | |
name = obj.name | |
else: | |
raise RuntimeError("Could not get name of python class object") | |
if name == "<lambda>": | |
name = "_lambda" # make name a valid identifier | |
module_name = obj.__module__ | |
# If the module is actually a torchbind module, then we should short circuit | |
if module_name == "torch._classes": | |
return obj.qualified_name | |
# The Python docs are very clear that `__module__` can be None, but I can't | |
# figure out when it actually would be. | |
if module_name is None: | |
raise RuntimeError( | |
f"Could not get qualified name for class '{name}': " | |
"__module__ can't be None." | |
) | |
# if getattr(sys.modules[module_name], name) is not obj: | |
# raise RuntimeError(f"Could not get qualified name for class '{name}': " | |
# f"the attr {name} on module {module_name} is not the class") | |
# torch.package and TorchScript have separate mangling schemes to avoid | |
# name collisions from multiple packages. To avoid them interfering with | |
# each other, normalize the package manging here. | |
if package_mangling.is_mangled(module_name): | |
module_name = module_name.replace("<", "_") | |
module_name = module_name.replace(">", "_") | |
# The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h | |
# does not need mangle the python class name. | |
if mangle_name: | |
# __main__ is a builtin module, so rewrite it to "__torch__". | |
if module_name == "__main__": | |
module_name = "__torch__" | |
else: | |
# Everything else gets a "__torch__" prefix to avoid name collisions | |
# with the names of user values. | |
module_name = "__torch__." + module_name | |
if "." in name: | |
raise RuntimeError( | |
f"Could not get qualified name for class '{name}': " | |
f"'{name}' is not a valid identifier" | |
) | |
return module_name + "." + name | |
def _try_get_dispatched_fn(fn): | |
if not callable(fn): | |
return None | |
return boolean_dispatched.get(fn) | |
def _get_named_tuple_properties( | |
obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None, rcb=None | |
): | |
if loc is None: | |
loc = fake_range() | |
assert issubclass(obj, tuple) and hasattr(obj, "_fields") | |
if hasattr(obj, "_field_defaults"): | |
defaults = [ | |
obj._field_defaults[field] | |
for field in obj._fields | |
if field in obj._field_defaults | |
] | |
else: | |
defaults = [] | |
# In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function | |
# Also, annotations from base class are not inherited so they need to be queried explicitly | |
if sys.version_info[:2] < (3, 10): | |
obj_annotations = getattr(obj, "__annotations__", {}) | |
else: | |
obj_annotations = inspect.get_annotations(obj) | |
if len(obj_annotations) == 0 and hasattr(obj, "__base__"): | |
obj_annotations = inspect.get_annotations(obj.__base__) | |
annotations = [] | |
for field in obj._fields: | |
if field in obj_annotations: | |
field_type = obj_annotations[field] | |
# [Note: ForwardRef annotations in NamedTuple attributes] | |
# NamedTuple types are slightly different from normal types. | |
# | |
# Normally, annotations are evaluted like this (during jit.script): | |
# 1. Load strings of python code into c++ and parse. | |
# 2. Get annotations as strings | |
# 3. Use the PythonResolver's resolution callback (rcb) to convert | |
# the string into a python object | |
# 4. We call into annotations.py:ann_to_type to convert python obj | |
# from step 3 into a type that torchscript understands. | |
# | |
# NamedTuples are more complicated, because it has sub-types. | |
# Normally, once we have the NamedTuple type object from #3, | |
# we can just look at the annotation literal values and use | |
# ann_to_type directly on them. | |
# | |
# But sometimes, users will annotate with string literals, e.g. | |
# x: 'int' | |
# This also happens with PEP563 (from __forward__ import annotations) | |
# | |
# These annotations appear in the annotation dict as ForwardRef('int'). | |
# | |
# Then, we need to convert the string into a python object. This | |
# requires having local context for custom objects or imported types. | |
# rcb() is what gives us this. So, we plumb rcb through the stack so | |
# it can be used in this context for the if block below. | |
# | |
# FAQ: | |
# - Why do we need this special handling for NamedTuple but string | |
# annotations work fine for normal types? Normally, we parse the | |
# string directly and then call rcb() directly from C++. | |
# - Why not use ForwardRef._evaluate? For that, we need globals() | |
# and locals() for the local context where the NamedTuple was defined. | |
# rcb is what lets us look up into these. So, basically rcb does the | |
# hard work for us. | |
if isinstance(field_type, ForwardRef) and rcb is not None: | |
rcb_type = rcb(field_type.__forward_arg__) | |
# rcb returns None if it can't find anything. | |
if rcb_type is None: | |
raise ValueError( | |
f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}." | |
f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858." | |
f" Issue occurred at {loc.highlight()}" | |
) | |
field_type = rcb_type | |
the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb) | |
annotations.append(the_type) | |
else: | |
annotations.append(torch._C.TensorType.getInferred()) | |
return type(obj).__name__, obj._fields, annotations, defaults | |
def _create_named_tuple( | |
t, unqual_name: str, field_names: List[str], defaults: Tuple[Any, ...] | |
): | |
TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc] | |
return TupleType(*t) | |
def _disable_emit_hooks(): | |
hooks = torch._C._jit_get_emit_hooks() | |
torch._C._jit_set_emit_hooks(None, None) | |
try: | |
yield | |
finally: | |
torch._C._jit_set_emit_hooks(hooks[0], hooks[1]) | |
def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None: # noqa: F811 | |
def __enter__(self) -> None: | |
self.hooks = torch._C._jit_get_emit_hooks() | |
torch._C._jit_set_emit_hooks(None, None) | |
def __exit__(self, *args) -> None: | |
torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1]) | |
def _is_exception(obj) -> bool: | |
if not inspect.isclass(obj): | |
return False | |
return issubclass(obj, Exception) | |
def raise_error_container_parameter_missing(target_type) -> None: | |
if target_type == "Dict": | |
raise RuntimeError( | |
"Attempted to use Dict without " | |
"contained types. Please add contained type, e.g. " | |
"Dict[int, int]" | |
) | |
raise RuntimeError( | |
f"Attempted to use {target_type} without a " | |
"contained type. Please add a contained type, e.g. " | |
f"{target_type}[int]" | |
) | |
def check_args_exist(target_type) -> None: | |
if target_type is List or target_type is list: | |
raise_error_container_parameter_missing("List") | |
elif target_type is Tuple or target_type is tuple: | |
raise_error_container_parameter_missing("Tuple") | |
elif target_type is Dict or target_type is dict: | |
raise_error_container_parameter_missing("Dict") | |
elif target_type is None or target_type is Optional: | |
raise_error_container_parameter_missing("Optional") | |
def check_empty_containers(obj) -> None: | |
if obj == [] or obj == {} or obj == (): | |
warnings.warn( | |
"The inner type of a container is lost when " | |
"calling torch.jit.isinstance in eager mode. For " | |
"example, List[int] would become list and " | |
"therefore falsely return True for List[float] or" | |
" List[str]." | |
) | |
# supports List/Dict/Tuple and Optional types | |
# TODO support future | |
def container_checker(obj, target_type) -> bool: | |
origin_type = get_origin(target_type) | |
check_args_exist(target_type) | |
if origin_type is None: | |
return False | |
elif origin_type is list or origin_type is List: | |
check_empty_containers(obj) | |
if not isinstance(obj, list): | |
return False | |
arg_type = get_args(target_type)[0] | |
arg_origin = get_origin(arg_type) | |
for el in obj: | |
# check if nested container, ex: List[List[str]] | |
if arg_origin: # processes nested container, ex: List[List[str]] | |
if not container_checker(el, arg_type): | |
return False | |
elif not isinstance(el, arg_type): | |
return False | |
return True | |
elif origin_type is Dict or origin_type is dict: | |
check_empty_containers(obj) | |
if not isinstance(obj, dict): | |
return False | |
key_type = get_args(target_type)[0] | |
val_type = get_args(target_type)[1] | |
for key, val in obj.items(): | |
# check if keys are of right type | |
if not isinstance(key, key_type): | |
return False | |
val_origin = get_origin(val_type) | |
if val_origin: | |
if not container_checker(val, val_type): | |
return False | |
elif not isinstance(val, val_type): | |
return False | |
return True | |
elif origin_type is Tuple or origin_type is tuple: | |
check_empty_containers(obj) | |
if not isinstance(obj, tuple): | |
return False | |
arg_types = get_args(target_type) | |
if len(obj) != len(arg_types): | |
return False | |
for el, el_type in zip(obj, arg_types): | |
el_origin = get_origin(el_type) | |
if el_origin: | |
if not container_checker(el, el_type): | |
return False | |
elif not isinstance(el, el_type): | |
return False | |
return True | |
elif origin_type is Union or issubclass( | |
origin_type, BuiltinUnionType | |
): # also handles Optional | |
if obj is None: # check before recursion because None is always fine | |
return True | |
inner_types = get_args(target_type) | |
for t in inner_types: | |
t_origin = get_origin(t) | |
if t_origin: | |
return container_checker(obj, t) | |
elif isinstance(obj, t): | |
return True | |
return False | |
def _isinstance(obj, target_type) -> bool: | |
if isinstance(target_type, collections.abc.Container): | |
if not isinstance(target_type, tuple): | |
raise RuntimeError( | |
"The second argument to " | |
"`torch.jit.isinstance` must be a type " | |
"or a tuple of types" | |
) | |
for t_type in target_type: | |
if _isinstance(obj, t_type): | |
return True | |
return False | |
origin_type = get_origin(target_type) | |
if origin_type: | |
return container_checker(obj, target_type) | |
# Check to handle non-typed optional origin returns as none instead | |
# of as optional in 3.7-3.8 | |
check_args_exist(target_type) | |
# handle non-containers | |
return isinstance(obj, target_type) | |
class _TensorExtractor(pickle.Pickler): | |
def __init__(self, *args, tensors: List[torch.Tensor], **kwargs): | |
super().__init__(*args, **kwargs) | |
self.tensors = tensors | |
def persistent_id(self, obj): | |
if isinstance(obj, torch.Tensor): | |
self.tensors.append(obj) | |
return "" | |
# Since we just want to extract tensors, we don't mind if an object is | |
# unpicklable if it doesn't contain tensors, as we can just ignore/skip | |
# it. To play it safe, we only do so for common objects that we're sure | |
# don't contain tensors. Feel free to add new types here. Note also that | |
# even if a type isn't listed here this won't block users, since thet | |
# can just add a __getstate__ or __reduce__ method to their class. | |
if isinstance(obj, LockType): | |
return "" | |
# Futures and RRefs don't technically contain a value, they just offer | |
# the means to access a value. | |
if isinstance(obj, CFuture) or is_rref_instance(obj): | |
return "" | |
if isinstance(obj, CAwait): | |
return "" | |
if isinstance(obj, torch.cuda.Event): | |
return "" | |
if isinstance(obj, threading.Thread): | |
return "" | |
return None | |
def _extract_tensors(obj): | |
r""" | |
This function is exclusively called from C++. | |
See ``torch/csrc/jit/python/python_ivalue.h``. | |
It extracts the tensors contained in the given object, through pickling. | |
""" | |
tensors: List[torch.Tensor] = [] | |
extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors) | |
extractor.dump(obj) | |
return tensors | |
# In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass | |
# that were previously dropped. To preserve the behavior, explicitly drop them there | |
if sys.version_info > (3, 10): | |
_drop(enum.Enum.__new__) | |
_drop(enum.Enum.__format__) | |
_drop(enum.Enum.__repr__) | |
_drop(enum.Enum.__str__) | |