# Page-scrape residue, not part of the module: "Spaces: Running Running"
from contextlib import contextmanager

import torch
import torch._custom_ops
from torch._C import DispatchKey
from torch._higher_order_ops.strict_mode import strict_mode
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils import _pytree as pytree
# Higher-order operator used by the hooks below to mark module-call
# inputs/outputs; the per-backend behaviours are the functions that follow.
_export_tracepoint = HigherOrderOperator("_export_tracepoint")
def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
    """Handle `_export_tracepoint` while a proxy-tracing mode is active.

    Args:
        mode: Object exposing `enable_tracing` and a `tracer` that provides
            `unwrap_proxy` and `create_proxy`.
        *args: Positional operands of the tracepoint.
        **kwargs: Keyword operands of the tracepoint.

    Returns:
        If `mode.enable_tracing` is false, the result of calling
        `_export_tracepoint(*args, **kwargs)` directly. Otherwise the result
        of `track_tensor_tree` applied to `args` and the proxy created for
        the `call_function` node.
    """
    # NOTE(review): nothing in the visible file registers this function with
    # `_export_tracepoint`, and the `ProxyTorchDispatchMode` import is
    # otherwise unused -- a `py_impl(...)` decorator was presumably lost;
    # confirm against the upstream file.
    if not mode.enable_tracing:
        return _export_tracepoint(*args, **kwargs)

    # Swap every traced value in (args, kwargs) for its underlying proxy so
    # the graph node is built from proxies.
    p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
    proxy = mode.tracer.create_proxy(
        "call_function", _export_tracepoint, p_args, p_kwargs
    )
    return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)
def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
    """Return the positional operands unchanged while `mode` is entered.

    Args:
        mode: Context manager entered for the duration of the call.
        *args: Positional operands; returned as-is (as a tuple).
        **kwargs: Accepted but not used by this implementation.

    Returns:
        The `args` tuple.
    """
    # NOTE(review): the `FakeTensorMode` import is otherwise unused in the
    # visible file -- a `py_impl(FakeTensorMode)` registration decorator was
    # presumably lost; confirm against the upstream file.
    with mode:
        return args
def export_tracepoint_functional(ctx, *args, **kwargs):
    """Run `_export_tracepoint` on unwrapped operands and re-wrap the result.

    Args:
        ctx: Object providing `unwrap_tensors`, `wrap_tensors` and the
            `redispatch_to_next()` context manager.
        *args: Positional operands of the tracepoint.
        **kwargs: Keyword operands of the tracepoint.

    Returns:
        `ctx.wrap_tensors` applied to the redispatched call's output.
    """
    # NOTE(review): no registration of this function is visible in the file;
    # presumably a functionalization-registration decorator was lost --
    # confirm against the upstream file.
    unwrapped_args = ctx.unwrap_tensors(args)
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs)
        return ctx.wrap_tensors(out)
# Register the Autograd-key implementation of `_export_tracepoint` with the
# callable produced by `autograd_not_implemented(..., deferred_error=True)`.
_export_tracepoint.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(_export_tracepoint, deferred_error=True)
)
def export_tracepoint_cpu(*args, **kwargs):
    """Return the positional operands unchanged; keyword operands are ignored.

    Returns:
        The `args` tuple.
    """
    # NOTE(review): no registration of this function is visible in the file;
    # presumably a `py_impl(DispatchKey.CPU)` decorator was lost -- confirm
    # against the upstream file.
    return args
def _wrap_submodule(mod, path, module_call_specs): | |
assert isinstance(mod, torch.nn.Module) | |
assert path != "" | |
submodule = mod | |
for name in path.split("."): | |
if not hasattr(submodule, name): | |
raise RuntimeError(f"Couldn't find submodule at path {path}") | |
submodule = getattr(submodule, name) | |
def update_module_call_signatures(path, in_spec, out_spec): | |
if path in module_call_specs: | |
assert module_call_specs[path]["in_spec"] == in_spec | |
assert module_call_specs[path]["out_spec"] == out_spec | |
module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec} | |
def check_flattened(flat_args): | |
for a in flat_args: | |
if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None): | |
raise AssertionError( | |
f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}" | |
) | |
def pre_hook(module, args, kwargs): | |
flat_args, in_spec = pytree.tree_flatten((args, kwargs)) | |
check_flattened(flat_args) | |
flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path) | |
args, kwargs = pytree.tree_unflatten(flat_args, in_spec) | |
return args, kwargs | |
def post_hook(module, args, kwargs, res): | |
_, in_spec = pytree.tree_flatten((args, kwargs)) | |
flat_res, out_spec = pytree.tree_flatten(res) | |
check_flattened(flat_res) | |
flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path) | |
update_module_call_signatures(path, in_spec, out_spec) | |
return pytree.tree_unflatten(flat_res, out_spec) | |
pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True) | |
post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True) | |
return pre_handle, post_handle | |
def _wrap_submodules(f, preserve_signature, module_call_signatures): | |
handles = [] | |
try: | |
for path in preserve_signature: | |
handles.extend(_wrap_submodule(f, path, module_call_signatures)) | |
yield | |
finally: | |
for handle in handles: | |
handle.remove() | |
def _mark_strict_experimental(cls): | |
def call(self, *args): | |
return strict_mode(self, args) | |
cls.__call__ = call | |
return cls | |