Spaces:
Sleeping
Sleeping
import collections | |
import functools | |
import inspect | |
import sys | |
import textwrap | |
import types | |
import warnings | |
from typing import Dict, List, Set, Type | |
import torch | |
import torch._jit_internal as _jit_internal | |
from torch._sources import fake_range | |
from torch.jit._builtins import _find_builtin | |
from torch.jit._check import AttributeTypeIsSupportedChecker | |
from torch.jit._state import _add_script_class, _get_script_class, _python_cu | |
from torch.jit.frontend import ( | |
get_class_properties, | |
get_default_args, | |
get_jit_class_def, | |
get_jit_def, | |
) | |
from torch.nn import Module | |
# Records handed to the compiler. A method stub carries a resolution callback,
# a parsed definition (`def_`) and the original Python callable; a property
# stub carries only the first two.
ScriptMethodStub = collections.namedtuple(
    "ScriptMethodStub",
    ["resolution_callback", "def_", "original_method"],
)
PropertyStub = collections.namedtuple(
    "PropertyStub",
    ["resolution_callback", "def_"],
)
# Names in a module's `__dict__` that are skipped when module attributes are
# turned into TorchScript attributes.
# TODO: there should be a more principled way of doing this.
ignored_attributes = [
    # bookkeeping containers
    "_version",
    "_parameters",
    "_buffers",
    "_non_persistent_buffers_set",
    # backward / forward hooks
    "_backward_hooks",
    "_backward_pre_hooks",
    "_forward_hooks",
    "_forward_hooks_with_kwargs",
    "_forward_pre_hooks",
    "_forward_pre_hooks_with_kwargs",
    "_forward_hooks_always_called",
    # state-dict hooks
    "_state_dict_hooks",
    "_state_dict_pre_hooks",
    "_load_state_dict_pre_hooks",
    "_load_state_dict_post_hooks",
    # remaining names
    "_modules",
    "_initializing",
    "dump_patches",
]
def _compile_and_register_class(obj, rcb, qualified_name):
    """Return the script class registered for ``obj``, compiling it if needed.

    Args:
        obj: The Python class to compile.
        rcb: Resolution callback passed to the class compiler.
        qualified_name: Qualified name under which the class is compiled.

    Returns:
        The script class already registered for ``obj`` or, when there is
        none, the newly compiled one (which is registered before returning).
    """
    registered = _get_script_class(obj)
    if registered:
        return registered

    class_def = get_jit_class_def(obj, obj.__name__)
    default_args = torch.jit.frontend.get_default_args_for_class(obj)
    compiled = torch._C._jit_script_class_compile(
        qualified_name, class_def, default_args, rcb
    )
    _add_script_class(obj, compiled)
    return compiled
def make_stub(func, name):
    """Create a ``ScriptMethodStub`` for ``func``, parsed under the name ``name``.

    The stub holds a resolution callback built from ``func``'s closure, the
    JIT definition of ``func`` and ``func`` itself.
    """
    resolution_callback = _jit_internal.createResolutionCallbackFromClosure(func)
    definition = get_jit_def(func, name, self_name="RecursiveScriptModule")
    return ScriptMethodStub(
        resolution_callback=resolution_callback,
        def_=definition,
        original_method=func,
    )
def make_stub_from_method(nn_module, method_name):
    """Return a ``ScriptMethodStub`` for the attribute ``method_name`` of ``nn_module``.

    If the attribute is already a ``ScriptMethodStub`` it is returned as is.
    """
    attribute = getattr(nn_module, method_name)
    if not isinstance(attribute, ScriptMethodStub):
        # Pass the requested name explicitly so the name in the resulting AST
        # matches it. The two only differ for aliases such as:
        #     def _forward(self):
        #         pass
        #     forward = _forward
        # where the function object is named `_forward` even though a stub
        # for `forward` was requested.
        return make_stub(attribute, method_name)
    return attribute
def make_stubs_from_exported_methods(mod):
    """Return stubs for every attribute of ``mod`` carrying the EXPORT modifier.

    Attributes are visited in ``dir(mod)`` order.
    """
    export_modifier = _jit_internal.FunctionModifiers.EXPORT
    return [
        make_stub_from_method(mod, attr_name)
        for attr_name in dir(mod)
        if _jit_internal.get_torchscript_modifier(getattr(mod, attr_name, None))
        is export_modifier
    ]
def jit_ignored_properties(module):
    """Return the names in ``__jit_ignored_attributes__`` that are properties.

    Only properties defined directly on ``type(module)`` are considered
    (``vars`` does not include inherited members).

    Returns:
        A set of property names; empty when the module declares no ignored
        attributes or none of them is a property.
    """
    declared_ignored = getattr(module, "__jit_ignored_attributes__", list())
    property_names = {
        attr_name
        for attr_name, member in vars(type(module)).items()
        if isinstance(member, property)
    }
    return {attr_name for attr_name in declared_ignored if attr_name in property_names}
# Base types whose values can be constants; tuples and lists of these base
# types are also considered constants.
# If you edit this tuple, then you also need to edit the handlers in
# ConstantValue in jit/script/init.cpp
_constant_types = (bool, float, int, str, type(None)) + (
    torch.device,
    torch.layout,
    torch.dtype,
)
def _get_valid_constant(attr, v, owner_type):
    """Validate ``v`` as a constant value and return it in canonical form.

    Args:
        attr: Name of the attribute holding ``v`` (used in the error message).
        v: The candidate constant.
        owner_type: Name of the owning type (used in the error message).

    Returns:
        ``v`` itself when it is an instance of one of ``_constant_types``; for
        a tuple or list, a tuple of the recursively validated elements.

    Raises:
        TypeError: If ``v`` (or a nested element) is neither of the above.
    """
    if isinstance(v, _constant_types):
        return v
    if isinstance(v, (tuple, list)):
        return tuple(
            _get_valid_constant(attr, element, owner_type) for element in v
        )

    allowed = ", ".join(torch.typename(typ) for typ in _constant_types)
    message = textwrap.dedent(
        f"""
        '{torch.typename(type(v))}' object in attribute '{owner_type}.{attr}' is not a valid constant.
        Valid constants are:
        1. a nn.ModuleList
        2. a value of type {{{allowed}}}
        3. a list or tuple of (2)
        """
    )
    raise TypeError(message)
class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
    """Python subclass of ``SourceRangeFactory``.

    The constructor forwards its four arguments unchanged to the base class.
    """

    def __init__(self, source, filename, file_lineno, leading_whitespace_len):
        super().__init__(
            source,
            filename,
            file_lineno,
            leading_whitespace_len,
        )
def get_annotations(obj):
    """Return the annotations dict that applies to ``obj``.

    Before Python 3.10 this is ``obj.__annotations__`` (or ``{}``). On 3.10+
    ``inspect.get_annotations`` is used; when it yields nothing, the class of
    ``obj`` (or ``obj`` itself if it is a class) and then its bases are
    searched depth-first, in ``__bases__`` order, and the first non-empty
    annotations dict is returned.
    """
    if sys.version_info < (3, 10):
        return getattr(obj, "__annotations__", {})

    # In Python-3.10+ it is recommended to use inspect.get_annotations
    # See https://docs.python.org/3.10/howto/annotations.html
    # But also, in 3.10 annotations from base class are not inherited
    # by unannotated derived one, so they must be manually extracted
    direct = inspect.get_annotations(obj)
    if direct:
        return direct

    # Explicit-stack pre-order walk over the class and its bases.
    pending = [obj if isinstance(obj, type) else type(obj)]
    while pending:
        klass = pending.pop()
        found = inspect.get_annotations(klass)
        if found:
            return found
        # Reversed so that the first base is popped (and fully explored) first.
        pending.extend(reversed(klass.__bases__))
    return {}
def infer_concrete_type_builder(nn_module, share_types=True):
    """
    Build a ConcreteModuleTypeBuilder from an nn.Module.

    This ConcreteModuleType doesn't have a JIT type associated with it yet, it
    must be filled in by the caller.

    Args:
        nn_module: The module instance to describe.
        share_types: Forwarded to `get_module_concrete_type` when the concrete
            type of a submodule has to be computed.

    Returns:
        The populated `torch._C.ConcreteModuleTypeBuilder`.

    Raises:
        RuntimeError: If inferring an attribute's type raises a RuntimeError
            (re-raised with the attribute name and value prepended).
        TypeError: If a declared constant is not a valid constant value.
    """
    concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
    if isinstance(nn_module, (torch.nn.ModuleDict)):
        concrete_type_builder.set_module_dict()
    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
        concrete_type_builder.set_module_list()
    if isinstance(nn_module, (torch.nn.ParameterList)):
        concrete_type_builder.set_parameter_list()
    if isinstance(nn_module, (torch.nn.ParameterDict)):
        concrete_type_builder.set_parameter_dict()

    class_annotations = get_annotations(nn_module)
    if isinstance(nn_module, (torch.ao.quantization.QuantWrapper)):
        class_annotations = {}

    # Get user-annotated ignored attributes.
    user_annotated_ignored_attributes = getattr(
        nn_module, "__jit_ignored_attributes__", list()
    )
    concrete_type_builder.add_ignored_attributes(user_annotated_ignored_attributes)
    ignored_properties = jit_ignored_properties(nn_module)

    # try to infer the type from type annotation or from the object itself
    def infer_type(name, item):
        # The forward function from Module is special; never use this annotations; we
        # need to infer type directly using JIT.  I originally wanted to write
        # this test as isinstance(class_annotations[name], Callable) but
        # isinstance on typing things doesn't seem to work: isinstance(list, Callable)
        # is also true!
        inferred = False
        try:
            if (
                name in class_annotations
                and class_annotations[name]
                != torch.nn.Module.__annotations__["forward"]
            ):
                ann_to_type = torch.jit.annotations.ann_to_type(
                    class_annotations[name], fake_range()
                )
                attr_type = torch._C.InferredType(ann_to_type)
            elif isinstance(item, torch.jit.Attribute):
                ann_to_type = torch.jit.annotations.ann_to_type(item.type, fake_range())
                attr_type = torch._C.InferredType(ann_to_type)
            else:
                attr_type = torch._C._jit_try_infer_type(item)
                inferred = True
        except RuntimeError as err:
            raise RuntimeError(
                f"Error inferring type for {name}: {item}: {err}"
            ) from err

        return attr_type, inferred

    added_names = set()

    for name, item in nn_module._parameters.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        # We currently have the invariant in various places in our code
        # that parameters must be Tensors. However, the nn.Module API also
        # allows NoneType parameters. These parameters are not returned as
        # part of `parameters()` and its variants, but are available
        # through direct attribute access.
        concrete_type_builder.add_attribute(name, attr_type.type(), True, False)
        added_names.add(name)

    for name, item in nn_module._buffers.items():
        if name in user_annotated_ignored_attributes:
            continue

        assert item is None or isinstance(item, torch.Tensor)
        attr_type, _ = infer_type(name, item)
        concrete_type_builder.add_attribute(name, attr_type.type(), False, True)
        added_names.add(name)

    for name, item in nn_module._modules.items():
        if name in user_annotated_ignored_attributes:
            continue

        attr_type, _ = infer_type(name, item)
        if item is None:
            # Modules can be None. We don't have direct support for optional
            # Modules, so register it as a NoneType attribute instead.
            concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
            continue
        if attr_type.success():
            assert attr_type.type().is_interface_type()
            # if the type can be inferred, it should be a module interface type
            sub_concrete_type = torch._C.ConcreteModuleType.from_jit_type(
                attr_type.type()
            )
        else:
            # otherwise we get the concrete module type for item and add it to concrete_type
            sub_concrete_type = get_module_concrete_type(item, share_types)
        concrete_type_builder.add_module(name, sub_concrete_type)

        added_names.add(name)

    # populate constants_set
    constants_set = set(getattr(nn_module, "__constants__", ()))

    # Constants annotated via `Final[T]` rather than being added to `__constants__`
    for name, ann in class_annotations.items():
        if torch._jit_internal.is_final(ann):
            constants_set.add(name)

    for name in constants_set:
        if name in added_names:
            # TODO: We should really error in this case, but its bc-breaking so
            # we need to warn for at least one release
            if name in nn_module._modules:
                hint = "submodule"
            elif name in nn_module._buffers:
                hint = "buffer"
            elif name in nn_module._parameters:
                hint = "parameter"
            else:
                raise AssertionError(
                    "added_names must be submodule, parameter, or buffer"
                )

            warnings.warn(
                f"'{name}' was found in ScriptModule constants, "
                f" but it is a non-constant {hint}. Consider removing it."
            )
            continue
        if not hasattr(nn_module, name):
            # TODO: We should really error in this case, but its bc-breaking so
            # we need to warn for at least one release
            warnings.warn(
                f"'{name}' was found in ScriptModule constants, "
                "but was not actually set in __init__. "
                "Consider removing it."
            )
            continue
        value = getattr(nn_module, name)
        concrete_type_builder.add_constant(
            name, _get_valid_constant(name, value, type(nn_module).__name__)
        )
        added_names.add(name)

    # populate overloads
    # Work on a copy: `__overloads__` may be a dict stored on the class, and
    # updating it in place would leak this instance's annotated overloads into
    # every other module sharing that dict.
    overloads = dict(getattr(nn_module, "__overloads__", {}))
    # update with any annotated overloads
    overloads.update(
        get_overload_name_mapping(
            get_overload_annotations(nn_module, ignored_properties)
        )
    )
    for name, overloaded_names in overloads.items():
        concrete_type_builder.add_overload(name, overloaded_names)

    for name, value in nn_module.__dict__.items():
        if name in ignored_attributes or name.startswith("__"):
            # Python objects have lots of random attributes attached to them;
            # PyTorch adds a few more. Prevent these from getting compiled.
            continue

        if name in user_annotated_ignored_attributes:
            continue

        if name in added_names:
            # Don't re-add anything we already added
            continue

        isoverloadpacket = isinstance(value, torch._ops.OpOverloadPacket)
        if isoverloadpacket:
            value = value.op
        # Handle Python function attributes
        if inspect.isfunction(value):
            try:
                scripted_fn = torch.jit.script(value)
                concrete_type_builder.add_function_attribute(
                    name, torch._C._jit_try_infer_type(scripted_fn).type(), value
                )
            except Exception as e:
                # If we fail to script the function, it isn't a hard error.
                # Instead, we will add it to the list of attributes we failed
                # to convert, with the compilation error.
                hint = (
                    "(This function exists as an attribute on the Python module, "
                    "but we failed to compile it to a TorchScript function. "
                    f"\nThe error stack is reproduced here:\n{e}"
                )
                concrete_type_builder.add_failed_attribute(name, hint)

            continue

        # Handle calls to builtin functions (either bespoke builtins from torch.jit._builtins or
        # a call to an aten function like torch.add)
        builtin_symbol_name = _find_builtin(value)
        if builtin_symbol_name:
            concrete_type_builder.add_builtin_function(name, builtin_symbol_name)
            continue

        # Handle Script function attributes
        if isinstance(value, torch.jit.ScriptFunction):
            concrete_type_builder.add_function_attribute(
                name, torch._C._jit_try_infer_type(value).type(), value
            )
            continue

        # If we got here, this is a regular "data" attribute, add it to the concrete type
        attr_type, inferred = infer_type(name, value)
        if attr_type.success():
            concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
        else:
            # TODO: could add more detail here. For example, what the user should do
            # when the pytype is `list` or `NoneType`
            inferred_msg = (
                "Its type was inferred; try adding a type annotation for the attribute."
                if inferred
                else ""
            )
            additional_info = f"{attr_type.reason()}. {inferred_msg}"
            hint = (
                "(This attribute exists on the Python module, "
                f"but we failed to convert Python type: '{torch.typename(type(value))}' "
                f"to a TorchScript type. {additional_info})"
            )
            concrete_type_builder.add_failed_attribute(name, hint)

    # add hooks to concrete type
    for hook in nn_module._forward_hooks.values():
        concrete_type_builder.add_forward_hook(hook)
    for pre_hook in nn_module._forward_pre_hooks.values():
        concrete_type_builder.add_forward_pre_hook(pre_hook)

    return concrete_type_builder
class ConcreteTypeStore:
    """Cache of concrete module types, grouped by Python module class."""

    type_store: Dict[Type[Module], List[torch._C.ConcreteModuleType]]
    methods_compiled: Set[torch._C.ConcreteModuleType]

    def __init__(self):
        # Python module type => List[ConcreteModuleType]
        self.type_store = {}
        # ConcreteTypes that have had their methods already compiled
        self.methods_compiled = set()

    def get_or_create_concrete_type(self, nn_module):
        """Infer a ConcreteType from this `nn.Module` instance. Underlying JIT types are re-used if possible."""
        builder = infer_concrete_type_builder(nn_module)

        # Fetch (creating on first use) the list of types known for this class
        # and look for an already-available JIT type.
        known_types = self.type_store.setdefault(type(nn_module), [])
        for candidate in known_types:
            if candidate.equals(builder):
                return candidate

        # We didn't find anything; generate a new JIT type from this concrete type
        fresh_type = builder.build()
        known_types.append(fresh_type)
        return fresh_type


concrete_type_store = ConcreteTypeStore()
def create_methods_and_properties_from_stubs(
    concrete_type, method_stubs, property_stubs
):
    """Unpack the stubs into parallel lists and hand them to ``concrete_type``.

    Args:
        concrete_type: Object providing ``_create_methods_and_properties``.
        method_stubs: Stubs exposing ``def_``, ``resolution_callback`` and
            ``original_method``.
        property_stubs: Stubs exposing ``def_`` and ``resolution_callback``.
    """
    defs_of_methods = [stub.def_ for stub in method_stubs]
    rcbs_of_methods = [stub.resolution_callback for stub in method_stubs]
    defaults_of_methods = [
        get_default_args(stub.original_method) for stub in method_stubs
    ]

    defs_of_properties = [stub.def_ for stub in property_stubs]
    rcbs_of_properties = [stub.resolution_callback for stub in property_stubs]

    concrete_type._create_methods_and_properties(
        defs_of_properties,
        rcbs_of_properties,
        defs_of_methods,
        rcbs_of_methods,
        defaults_of_methods,
    )
def create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs):
    """Unpack hook and pre-hook stubs and pass them to ``concrete_type._create_hooks``.

    Each stub must expose ``def_`` and ``resolution_callback``.
    """
    defs_of_hooks = [stub.def_ for stub in hook_stubs]
    rcbs_of_hooks = [stub.resolution_callback for stub in hook_stubs]

    defs_of_pre_hooks = [stub.def_ for stub in pre_hook_stubs]
    rcbs_of_pre_hooks = [stub.resolution_callback for stub in pre_hook_stubs]

    concrete_type._create_hooks(
        defs_of_hooks,
        rcbs_of_hooks,
        defs_of_pre_hooks,
        rcbs_of_pre_hooks,
    )
def get_module_concrete_type(nn_module, share_types=True):
    """
    Get a concrete type for nn_modules.

    If share_types is True, the concrete type is fetched from concrete_type_store.
    If it is False, a new concrete type is created without first searching concrete_type_store.

    Args:
        nn_module: The original Python nn.Module that we are creating a ScriptModule for.
        share_types: Whether to share underlying JIT types between modules (if possible).

    Returns:
        A concrete type for nn_module.
    """
    assert isinstance(nn_module, Module)

    # A ScriptModule that already carries its concrete type is used as is.
    is_script_module = isinstance(nn_module, torch.jit.ScriptModule)
    if is_script_module and hasattr(nn_module, "_concrete_type"):
        return nn_module._concrete_type

    if share_types:
        # Look into the store of cached JIT types
        return concrete_type_store.get_or_create_concrete_type(nn_module)

    # Get a concrete type directly, without trying to re-use an existing JIT
    # type from the type store.
    builder = infer_concrete_type_builder(nn_module, share_types)
    builder.set_poisoned()
    return builder.build()
def create_script_class(obj):
    """
    Create and return a RecursiveScriptClass instance from a Python object.

    Arguments:
        obj: A Python object.
    """
    obj_type = type(obj)
    qualified_class_name = _jit_internal._qualified_name(obj_type)
    resolution_callback = _jit_internal.createResolutionCallbackForClassMethods(
        obj_type
    )
    # Script the type of obj if it hasn't already been scripted.
    _compile_and_register_class(obj_type, resolution_callback, qualified_class_name)

    # Create an empty torch._C.ScriptObject with the scripted type.
    scripted_type = _python_cu.get_class(qualified_class_name)
    cpp_object = torch._C._create_object_with_type(scripted_type)

    # Copy all of the attributes over to the torch._C.ScriptObject.
    for attr_name, attr_value in obj.__dict__.items():
        cpp_object.setattr(attr_name, attr_value)

    # Wrap the torch._C.ScriptObject in a RecursiveScriptClass instance.
    return wrap_cpp_class(cpp_object)
def create_script_module(nn_module, stubs_fn, share_types=True, is_tracing=False):
    """
    Create a new ScriptModule from an nn.Module.

    Args:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.
        share_types:  Whether to share underlying JIT types between modules (if possible).
            NOTE: Only set this to False when we cannot guarantee type sharing will work
                correctly. This only happens today for traced modules, where the same
                module can produce different traced methods depending on the inputs.
        is_tracing: Whether this function is called during tracing or scripting. If tracing,
                we don't need to do AttributeTypeIsSupportedChecker because all the unsupported
                attributes will be baked as constant in the tracing graph. In addition,
                this check significantly slows down the traced modules when the module size is big.
    """
    assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
    check_module_initialized(nn_module)

    module_concrete_type = get_module_concrete_type(nn_module, share_types)
    if not is_tracing:
        attribute_checker = AttributeTypeIsSupportedChecker()
        attribute_checker.check(nn_module)

    return create_script_module_impl(nn_module, module_concrete_type, stubs_fn)
def create_script_module_impl(nn_module, concrete_type, stubs_fn):
    """
    Convert an nn.Module to a RecursiveScriptModule.

    Args:
        nn_module:  The original Python nn.Module that we are creating a ScriptModule for.
        concrete_type:  The fully initialized ConcreteType of the module.
        stubs_fn:  Lambda that takes an nn.Module and generates a list of ScriptMethodStubs to compile.

    Returns:
        The RecursiveScriptModule built around a new C++ module of
        `concrete_type.jit_type`.
    """
    cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
    method_stubs = stubs_fn(nn_module)
    property_stubs = get_property_stubs(nn_module)
    hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module)

    ignored_properties = jit_ignored_properties(nn_module)

    def init_fn(script_module):
        # Initialize the ScriptModule:
        # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule.
        for name in concrete_type.get_attributes().keys():
            orig_value = getattr(nn_module, name)
            orig_value = (
                orig_value.value
                if isinstance(orig_value, torch.jit.Attribute)
                else orig_value
            )
            cpp_module.setattr(name, orig_value)

        # 2. Copy the submodules from the original `nn_module` to the new ScriptModule,
        #    recursively scripting them.
        for name, sub_concrete_type in concrete_type.get_modules():
            orig_value = getattr(nn_module, name)
            assert isinstance(
                orig_value, Module
            ), f"Expected Module but got {type(orig_value)}"
            module_type = sub_concrete_type.jit_type
            if isinstance(module_type, torch._C.InterfaceType):
                # use the interface inference rule to compile the module
                scripted = interface_script(module_type, orig_value)
            elif isinstance(orig_value, torch.jit.ScriptModule):
                scripted = orig_value
            else:
                # always reuse the provided stubs_fn to infer the methods to compile
                scripted = create_script_module_impl(
                    orig_value, sub_concrete_type, stubs_fn
                )

            cpp_module.setattr(name, scripted)
            script_module._modules[name] = scripted

        # 3. Copy @ignored/@unused methods and attrs from the original `nn_module` to the new ScriptModule.
        #    This ensures we can access these Python methods on the ScriptModule.
        for name in dir(nn_module):
            if name in ignored_properties:
                continue
            item = getattr(nn_module, name, None)
            if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
                # Re-bind the underlying function to the new ScriptModule.
                unbound_function = getattr(nn_module, name).__func__
                bound_method = unbound_function.__get__(script_module)
                setattr(script_module, name, bound_method)
            elif concrete_type.is_ignored_attribute(name):
                setattr(script_module, name, item)

        # For convenience, attach the concrete type to the new ScriptModule
        script_module._concrete_type = concrete_type

    # Actually create the ScriptModule, initializing it with the function we just defined
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)

    # Compile methods if necessary
    if concrete_type not in concrete_type_store.methods_compiled:
        create_methods_and_properties_from_stubs(
            concrete_type, method_stubs, property_stubs
        )
        # Create hooks after methods to ensure no name collisions between hooks and methods.
        # If done before, hooks can overshadow methods that aren't exported.
        create_hooks_from_stubs(concrete_type, hook_stubs, pre_hook_stubs)
        torch._C._run_emit_module_hook(cpp_module)
        concrete_type_store.methods_compiled.add(concrete_type)

    # Copy the forward hooks and pre-hooks to the new ScriptModule
    # to allow the hooks to be run from eager as ScriptFunctions
    for idx, fn in enumerate(script_module._c._get_forward_pre_hooks()):
        script_module._forward_pre_hooks[idx] = fn
    for idx, fn in enumerate(script_module._c._get_forward_hooks()):
        script_module._forward_hooks[idx] = fn

    # Special handling so methods like __len__ work in script methods on classes derived from containers
    if (
        isinstance(
            nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)
        )
        and "__len__" not in cpp_module._method_names()
    ):
        script_module.define(f"def __len__(self):\n return {len(nn_module)}\n")
    if (
        isinstance(nn_module, torch.nn.ModuleDict)
        and "__contains__" not in cpp_module._method_names()
    ):
        if len(nn_module.keys()):
            keys = repr(list(nn_module.keys()))
            script_module.define(
                f"def __contains__(self, key: str):\n return key in {keys}\n"
            )
        else:
            script_module.define("def __contains__(self, key: str):\n return False\n")

    # Make the compiled methods available to the Python ScriptModule class.
    for method_stub in method_stubs:
        if method_stub.original_method is None:
            # define()'d methods don't have an Python original_method, so we
            # don't need to do any Python re-wrapping stuff
            continue

        name = method_stub.original_method.__name__
        if name != method_stub.def_.name().name:
            # TODO: Why skip this? Because @torch.jit._overload_method will
            # mangle the name of the function.
            continue
        script_method = cpp_module._get_method(name)

        # Wrap the original to propagate docstrings and such.
        # TODO: we don't currently do this functions that are recursively
        # compiled, we should.
        wrapped_script_method = functools.wraps(method_stub.original_method)(
            script_method
        )

        # Add the methods to the script_module directly. This ensures they will
        # be found first when `name` is looked up (as opposed to the stubs or
        # nn.Module.forward)
        script_module.__dict__[name] = wrapped_script_method

    # Make module properties available on the Python ScriptModule class.
    for property_stub in property_stubs:
        property_name = property_stub.def_.name().name
        fget = cpp_module._get_method(property_stub.def_.getter_name().name)
        # Setter is optional, so it may not exist.
        setter_name = property_stub.def_.setter_name()
        fset = cpp_module._get_method(setter_name.name) if setter_name else None
        # `property` takes (fget, fset, fdel, doc): the getter must come first.
        # Passing the property name first would store the name as the getter
        # and shift the real getter/setter into the fset/fdel slots.
        script_module.__dict__[property_name] = property(fget, fset)

    # copy over python methods to script module if they aren't defined on the script module
    # this is currently an internal api used only on module containers
    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if (
            _jit_internal.get_torchscript_modifier(item)
            is _jit_internal.FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
        ):
            add_python_attr_to_scripted_model(script_module, nn_module, name)

    return script_module
# We define shims of certain attributes on the RecursiveScriptModule to support
# magic methods. To check if a script model defines an attribute we need
# to also check that the attribute is not the shim
def script_model_defines_attr(script_model, attr):
    """Return whether ``script_model`` has ``attr`` and it is not the shim.

    The result is ``False`` when the attribute is missing (or ``None``) on the
    model, or missing (or ``None``) on ``torch.jit.RecursiveScriptModule``;
    otherwise it is the outcome of comparing the two values with ``!=``.
    """
    model_value = getattr(script_model, attr, None)
    if model_value is not None:
        shim_value = getattr(torch.jit.RecursiveScriptModule, attr, None)
        if shim_value is not None:
            return model_value != shim_value
    return False
def add_python_attr_to_scripted_model(script_model, orig, attr):
    """Copy ``attr`` from ``orig`` onto ``script_model`` when both checks pass.

    Nothing happens unless ``orig`` has ``attr`` and
    ``script_model_defines_attr(script_model, attr)`` is true.
    """
    if not hasattr(orig, attr):
        return
    if script_model_defines_attr(script_model, attr):
        setattr(script_model, attr, getattr(orig, attr))
def get_overload_annotations(mod, jit_ignored_properties):
    """Collect the registered overloads of each callable attribute of ``mod``.

    Args:
        mod: The module instance to inspect.
        jit_ignored_properties: Attribute names to skip.

    Returns:
        A dict mapping original function => [(mangled overload name, overload function)].

    Raises:
        RuntimeError: If a method's own function appears among its overloads.
    """
    # original function => [(mangled overload name, overload function)]
    collected = {}
    mod_class = mod.__class__

    for name in dir(type(mod)):
        if name in jit_ignored_properties:
            continue
        item = getattr(mod, name, None)
        if not callable(item):
            continue

        # builtin functions like repr() in python 2 do not have __module__ defined
        if getattr(item, "__module__", None) is None:
            continue

        method_overloads = _jit_internal._get_overloaded_methods(item, mod_class)
        if method_overloads is None:
            continue

        if item.__func__ in method_overloads:
            raise RuntimeError(
                _jit_internal.get_overload_no_implementation_error_message(
                    "method", item.__func__
                )
            )

        mangled_names = [f"{name}__{index}" for index in range(len(method_overloads))]
        collected[item] = list(zip(mangled_names, method_overloads))

    return collected
def get_overload_name_mapping(overload_info):
    """Map each original function's ``__name__`` to its overload names.

    Args:
        overload_info: Dict of original function => iterable of
            (overload name, overload function) pairs.

    Returns:
        A dict in the same format as ``__overloads__``:
        original function name => [overload names]. Functions sharing a
        ``__name__`` share one list; a function with no overloads still gets
        an (empty) entry.
    """
    # Same format as __overloads__
    # original function => [overload names]
    overload_name_mappings: Dict[str, List[str]] = {}
    for orig_fn, overloads in overload_info.items():
        names_for_fn = overload_name_mappings.setdefault(orig_fn.__name__, [])
        names_for_fn.extend(overload_name for overload_name, _ in overloads)
    return overload_name_mappings
def _check_no_signature(func):
    """Raise if no signature can be obtained for the overload ``func``.

    Raises:
        RuntimeError: If ``torch.jit.annotations.get_signature`` returns
            ``None`` for ``func``.
    """
    is_method = inspect.ismethod(func)
    signature = torch.jit.annotations.get_signature(
        func, None, fake_range(), is_method
    )
    if signature is not None:
        return

    qual_name = _jit_internal._qualified_name(func)
    raise RuntimeError(
        f"Must explicitly add type annotations to overloaded functions: {qual_name}"
    )
def make_stubs_for_overloads(overload_info):
    """Build one ``ScriptMethodStub`` per overload listed in ``overload_info``.

    For every (overload name, overload function) pair, the overload's
    declaration is combined with the original function's definition under the
    overload name. The resolution callback comes from the original function's
    closure, while the stub's ``original_method`` is the overload function.

    Args:
        overload_info: Dict of original function =>
            [(overload name, overload function)].

    Returns:
        The list of stubs, in iteration order.
    """
    stubs = []
    for original_fn, overloads in overload_info.items():
        original_def = get_jit_def(
            original_fn, original_fn.__name__, self_name="RecursiveScriptModule"
        )
        for overload_name, overload_fn in overloads:
            _check_no_signature(overload_fn)
            overload_def = get_jit_def(
                overload_fn, overload_fn.__name__, self_name="RecursiveScriptModule"
            )
            combined_def = torch._C._replace_overloaded_method_decl(
                overload_def.decl(), original_def, overload_name
            )
            resolution_callback = _jit_internal.createResolutionCallbackFromClosure(
                original_fn
            )
            stubs.append(
                ScriptMethodStub(resolution_callback, combined_def, overload_fn)
            )
    return stubs
def check_module_initialized(mod):
    """Verify that ``mod`` is an initialized module without lazy tensors.

    Raises:
        RuntimeError: If ``mod`` has no ``_parameters`` attribute, or if
            (for modules without a ``remote_parameters`` attribute) any
            parameter or buffer is reported lazy by
            ``torch.nn.parameter.is_lazy``.
    """
    assert isinstance(mod, torch.nn.Module)

    module_typename_needed = not hasattr(mod, "_parameters")
    if module_typename_needed:
        raise RuntimeError(
            f"'{torch.typename(type(mod))}' has not been initialized, did you forget to call 'super()'?"
        )

    # This is to avoid importing torch.distributed.nn
    if hasattr(mod, "remote_parameters"):
        return

    for name, param in mod._parameters.items():
        if param is None or not torch.nn.parameter.is_lazy(param):
            continue
        raise RuntimeError(
            f"'{torch.typename(type(mod))}' has uninitialized parameters {name}. Did you forget to run a forward pass?"
        )

    for name, buf in mod._buffers.items():
        if buf is None or not torch.nn.parameter.is_lazy(buf):
            continue
        raise RuntimeError(
            f"'{torch.typename(type(mod))}' has uninitialized buffers {name}. Did you forget to run a forward pass?"
        )
def infer_methods_to_compile(nn_module):
    """Implement the default rules for which methods should act as starting points for compilation.

    (TODO add a link when the rules are published).

    The candidates are `forward` (unless it is ignored or is the one inherited
    from `torch.nn.Module`) followed by every exported attribute. Names that
    have overloads are dropped in favour of the overload stubs. As a side
    effect, `nn_module.__overloads__` is replaced with the merged mapping.

    Returns:
        The overload stubs followed by the stubs of the remaining methods.
    """
    check_module_initialized(nn_module)
    ignored_properties = jit_ignored_properties(nn_module)

    candidates: List[str] = []
    if hasattr(nn_module, "forward") and not _jit_internal.is_ignored_fn(
        nn_module.forward
    ):
        own_forward = getattr(nn_module.forward, "__func__", None)
        base_forward = getattr(torch.nn.Module, "forward", None)
        if own_forward != base_forward:
            candidates.append("forward")

    for name in dir(nn_module):
        if name in ignored_properties:
            continue
        item = getattr(nn_module, name, None)
        if (
            _jit_internal.get_torchscript_modifier(item)
            is _jit_internal.FunctionModifiers.EXPORT
        ):
            candidates.append(name)

    overload_name_mappings = dict(getattr(nn_module, "__overloads__", {}))
    overload_info = get_overload_annotations(nn_module, ignored_properties)
    overload_name_mappings.update(get_overload_name_mapping(overload_info))
    overload_stubs = make_stubs_for_overloads(overload_info)

    nn_module.__overloads__ = overload_name_mappings

    # we shouldn't directly compile overloaded methods, just its overloads.
    # Unique the methods with an insertion-ordered dict rather than a set,
    # because a set introduces non-determinism to compile order.
    unique_methods = dict.fromkeys(
        name for name in candidates if name not in overload_name_mappings
    )

    method_stubs = [
        make_stub_from_method(nn_module, name) for name in unique_methods
    ]
    return overload_stubs + method_stubs
def get_hook_stubs(nn_module):
    """Return forward hook and pre_hook ScriptModuleStubs.

    Walks ``nn_module._forward_hooks`` and then ``nn_module._forward_pre_hooks``,
    building one stub per registered entry. A single name registry is shared by
    both walks, so two *different* callables with the same ``__name__`` raise a
    ``RuntimeError`` even when one is a hook and the other a pre-hook.

    Returns:
        A ``(hook_stubs, pre_hook_stubs)`` tuple of lists.
    """
    check_module_initialized(nn_module)
    seen_by_name: Dict = {}

    forward_stubs = []
    for fwd_hook in nn_module._forward_hooks.values():
        fwd_name = fwd_hook.__name__
        # setdefault records first-time names and hands back the earlier
        # callable for names that were already registered.
        if seen_by_name.setdefault(fwd_name, fwd_hook) is not fwd_hook:
            raise RuntimeError(
                f"Hook '{fwd_name}' on {type(nn_module).__name__} "
                "has at least two different python definitions."
                " Please use unique names for all hooks."
            )
        forward_stubs.append(make_stub(fwd_hook, fwd_name))

    forward_pre_stubs = []
    for pre in nn_module._forward_pre_hooks.values():
        pre_name = pre.__name__
        if seen_by_name.setdefault(pre_name, pre) is not pre:
            raise RuntimeError(
                f"Pre-hook '{pre_name}' on {type(nn_module).__name__} "
                "has at least two different python definitions."
                " Please use unique names for all hooks."
            )
        forward_pre_stubs.append(make_stub(pre, pre_name))

    return forward_stubs, forward_pre_stubs
def get_property_stubs(nn_module):
    """Create property stubs for the properties of the module.

    The property ASTs come from ``get_class_properties`` on the module's class.
    Each one is paired with a resolution callback built from the closure of the
    matching property's getter.

    Args:
        nn_module: The module instance whose class is inspected.

    Returns:
        A list of ``PropertyStub`` objects, one per property AST.

    Raises:
        RuntimeError: If a property found on the class has no getter.
    """
    module_ty = type(nn_module)
    properties_asts = get_class_properties(module_ty, self_name="RecursiveScriptModule")
    rcbs = {}

    for name in dir(module_ty):
        item = getattr(module_ty, name, None)
        if not isinstance(item, property):
            continue
        if not item.fget:
            # Report the class name: `nn_module` is an instance and has no
            # `__name__`, so reading it there would raise AttributeError and
            # hide this error.
            raise RuntimeError(
                f"Property {name} of {module_ty.__name__} must have a getter"
            )
        rcbs[name] = _jit_internal.createResolutionCallbackFromClosure(item.fget)

    return [PropertyStub(rcbs[ast.name().name], ast) for ast in properties_asts]
def interface_script(mod_interface, nn_module):
    """
    Make a ScriptModule from an nn.Module, compiling the methods named by an interface type.

    A module that is already a ``torch.jit.ScriptModule`` is returned unchanged.

    Args:
        mod_interface: the interface type the module is expected to satisfy;
            its ``getMethodNames()`` lists the methods to compile.
        nn_module: The original Python nn.Module that we are creating a ScriptModule for.
    """
    if isinstance(nn_module, torch.jit.ScriptModule):
        return nn_module

    check_module_initialized(nn_module)

    def infer_interface_methods_to_compile(nn_module):
        """Build one stub per method name declared by the interface type.

        These stubs are the starting points for compilation.
        """
        return [
            make_stub_from_method(nn_module, method_name)
            for method_name in mod_interface.getMethodNames()
        ]

    return create_script_module(nn_module, infer_interface_methods_to_compile)
def try_compile_fn(fn, loc):
    """Script a plain Python function or method encountered during recursive compilation.

    Args:
        fn: The object to compile.
        loc: Not used by this function.

    Returns:
        ``None`` for @ignore'd functions and for ``torch.nn.Module`` instances,
        otherwise the result of ``torch.jit.script``.

    Raises:
        RuntimeError: If ``fn`` is neither a Python function nor a method.
    """
    # Nothing to compile for @ignore'd functions.
    if _jit_internal.is_ignored_fn(fn):
        return None

    # Modules are callable, so pybind reports them as functions; they are
    # deliberately left alone here.
    if isinstance(fn, torch.nn.Module):
        return None

    is_python_callable = inspect.isfunction(fn) or inspect.ismethod(fn)
    if not is_python_callable:
        raise RuntimeError(
            f"`{fn}` is not a function. Recursive scripting only supports "
            "Python functions or methods currently.\n"
            f"Consider manually annotating `{fn}` with @torch.jit.script."
        )

    # __prepare_scriptable__ may hand back an object with a different closure,
    # so swap it in before the resolution callback is built.
    if hasattr(fn, "__prepare_scriptable__"):
        fn = fn.__prepare_scriptable__()  # type: ignore[operator]

    # The defining scope is not available here; the closed-over variables on
    # the function object carry the information the callback needs.
    resolution_cb = _jit_internal.createResolutionCallbackFromClosure(fn)
    return torch.jit.script(fn, _rcb=resolution_cb)
def wrap_cpp_class(cpp_class):
    """Wrap this torch._C.Object in a Python RecursiveScriptClass.

    Args:
        cpp_class: The object to wrap.

    Returns:
        A ``torch.jit.RecursiveScriptClass`` constructed from ``cpp_class``.
    """
    wrapped = torch.jit.RecursiveScriptClass(cpp_class)
    return wrapped
def wrap_cpp_module(cpp_module):
    """Wrap this torch._C.ScriptModule in a Python ScriptModule, recursively for all submodules.

    Returns:
        The ``torch.jit.RecursiveScriptModule`` constructed around ``cpp_module``.
    """

    def init_fn(script_module):
        core = script_module._c

        # Wrap each child module and attach it under its own name.
        for child_name, child_cpp in torch._C.ModuleDict(core).items():
            setattr(script_module, child_name, wrap_cpp_module(child_cpp))

        script_module._concrete_type = torch._C.ConcreteModuleType.from_jit_type(
            core._type()
        )

        # Re-register the hooks stored on the C++ module, keyed by position.
        for position, pre_hook in enumerate(core._get_forward_pre_hooks()):
            script_module._forward_pre_hooks[position] = pre_hook
        for position, fwd_hook in enumerate(core._get_forward_hooks()):
            script_module._forward_hooks[position] = fwd_hook

    return torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
def compile_unbound_method(concrete_type, fn):
    """Create a stub for ``fn`` and compile it against ``concrete_type``.

    Returns:
        The stub that was compiled, or ``None`` when ``fn`` is an ignored function.
    """
    if not _jit_internal.is_ignored_fn(fn):
        method_stub = make_stub(fn, fn.__name__)
        # Emit hooks stay disabled here: the graph that is calling this
        # function is not yet complete.
        with torch._jit_internal._disable_emit_hooks():
            create_methods_and_properties_from_stubs(
                concrete_type, (method_stub,), ()
            )
        return method_stub
    return None
def lazy_bind(concrete_type, unbound_method):
    """
    Return a function that lazily binds `unbound_method` to a provided Module IValue, then invokes the method.

    Binding is deferred so that any Python shenanigans that would poison type
    sharing are impossible at compile time.

    The returned function carries ``original_fn``, the ``__name__`` of
    ``unbound_method`` and its TorchScript modifier.
    """

    def lazy_binding_method(cpp_module, *args):
        def init_fn(script_module):
            py_class = concrete_type.py_class

            # Carry over @ignored/@unused methods from the original class so
            # they are available during execution.
            for attr_name in dir(py_class):
                candidate = getattr(py_class, attr_name, None)
                if not _jit_internal.is_ignored_fn(candidate):
                    continue
                setattr(script_module, attr_name, candidate)

            # Constants are copied too, for the same reason.
            for const_name, const_value in concrete_type.get_constants().items():
                setattr(script_module, const_name, const_value)

        module_wrapper = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
        bound = types.MethodType(unbound_method, module_wrapper)
        return bound(*args)

    # make the lazy binding method "look like" the original method
    lazy_binding_method.original_fn = unbound_method  # type: ignore[attr-defined]
    lazy_binding_method.__name__ = unbound_method.__name__
    torch._jit_internal.copy_torchscript_modifier(unbound_method, lazy_binding_method)

    return lazy_binding_method