Spaces:
Running
Running
import operator | |
from typing import Any, Callable, Dict, Tuple, Optional | |
import torch | |
import torch.fx | |
import torch.fx as fx | |
from torch.fx import Transformer, Proxy | |
from torch.fx.node import Argument, Target, Node, map_aggregate | |
from torch.fx.operator_schemas import ( | |
normalize_module, | |
normalize_function, | |
create_type_hint, | |
) | |
from .schema_type_annotation import AnnotateTypesWithSchema | |
class NormalizeArgs(Transformer):
    """
    Normalize arguments to Python targets.

    ``args``/``kwargs`` are matched up to the module's or functional's
    signature. If ``normalize_to_only_use_kwargs`` is true, they are
    rewritten to use keyword arguments exclusively (in positional order).
    Default values are populated as well. Positional-only parameters and
    varargs parameters (``*args``, ``**kwargs``) are not supported.

    If argument nodes carry ``'type'`` metadata (``node.meta["type"]``), it
    is passed to ``normalize_function`` to disambiguate operator overloads.
    Without that metadata, overload resolution may fail. Whether that raises
    or simply leaves the call un-normalized is decided by
    ``normalize_function``. Whenever normalization yields no result, the
    call is emitted unchanged.

    Example usage:

        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()
    """

    def __init__(
        self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
    ):
        super().__init__(module)
        # Maps each proxy produced by this transform back to the original
        # node it was generated from.
        self.node_map: Dict[Proxy, Node] = {}
        self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs

    def run_node(self, n: Node) -> Any:
        """Transform node ``n``, forwarding argument type hints for calls.

        For ``call_function`` nodes, the types of the arguments are computed
        from the *original* graph and handed to ``call_function`` for
        overload disambiguation. All other ops use the default
        ``Transformer`` behaviour. Except for the output node, the new node
        receives ``n``'s ``meta`` dict (the same object, not a copy) and its
        ``type`` annotation.
        """
        args, kwargs = self.fetch_args_kwargs_from_env(n)

        def get_type(arg):
            # A Node argument reports the type recorded in *its own* metadata
            # (None if absent). Any other value reports its Python type.
            if isinstance(arg, fx.Node):
                return arg.meta.get("type")
            return type(arg)

        arg_types = map_aggregate(n.args, get_type)
        assert isinstance(arg_types, tuple)
        arg_types = tuple(create_type_hint(i) for i in arg_types)
        # Derive kwarg types from the original Node kwargs (not the proxies
        # in `kwargs`), using the same pipeline as positional args. Otherwise
        # Node values would report the Proxy class instead of their type.
        kwarg_types = {
            k: create_type_hint(map_aggregate(v, get_type))
            for k, v in n.kwargs.items()
        }
        if n.op == "call_function":
            out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
        else:
            out = super().run_node(n)
        if n.op != "output":
            self.node_map[out] = n
            out.node.meta = n.meta
            out.node.type = n.type
        return out

    def call_function(
        self,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
        arg_types: Optional[Tuple[Any, ...]] = None,
        kwarg_types: Optional[Dict[str, Any]] = None,
    ):
        """Emit a ``call_function`` node with normalized arguments.

        Args:
            target: The callable being invoked.
            args: Positional arguments (values from the new graph).
            kwargs: Keyword arguments (values from the new graph).
            arg_types: Optional type hints for ``args``, used to pick an overload.
            kwarg_types: Optional type hints for ``kwargs``, used likewise.

        Returns:
            The proxy for the new node. If ``normalize_function`` produces no
            normalized form, the call is emitted with the original arguments.
        """
        assert callable(target)
        new_args_and_kwargs = normalize_function(
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            arg_types,  # type: ignore[arg-type]
            kwarg_types,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return self.tracer.create_proxy(
                "call_function", target, new_args, new_kwargs
            )
        return super().call_function(target, args, kwargs)

    def call_module(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        """Emit a ``call_module`` node with arguments normalized against the
        submodule's signature.

        Falls back to the original arguments if ``normalize_module`` produces
        no normalized form.
        """
        assert isinstance(target, str)
        new_args_and_kwargs = normalize_module(
            self.module,
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return super().call_module(target, new_args, new_kwargs)
        return super().call_module(target, args, kwargs)
class NormalizeOperators(AnnotateTypesWithSchema):
    """
    Normalize callsites that are different ways of "spelling" the same
    invocation into a single, canonical call. Currently supports:

    1. Rewriting binary ``torch`` functions (e.g. ``torch.add``) into the
       equivalent Python ``operator`` function (e.g. ``operator.add``),
       the spelling used by the Tensor magic methods (e.g. ``a + b``).
       The rewrite is only applied to calls with exactly two positional
       arguments and no keyword arguments. Extra keywords (e.g. ``alpha``
       on ``torch.add``) have no ``operator`` equivalent, so such calls are
       left untouched.

    Example usage:

        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeOperators(traced).transform()
    """

    # torch binary function -> Python `operator` function it is rewritten to.
    binary_magic_method_remap: Dict[
        Callable[[Any, Any], Any], Callable[[Any, Any], Any]
    ] = {
        torch.add: operator.add,
        torch.mul: operator.mul,
        torch.sub: operator.sub,
        torch.div: operator.truediv,
        torch.floor_divide: operator.floordiv,
        torch.remainder: operator.mod,
        torch.eq: operator.eq,
        torch.ne: operator.ne,
        torch.lt: operator.lt,
        torch.le: operator.le,
        torch.gt: operator.gt,
        torch.ge: operator.ge,
    }

    def call_function(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        """Emit ``target(*args, **kwargs)``, remapped to its ``operator``
        spelling when ``target`` is in ``binary_magic_method_remap`` and the
        call is a plain ``f(lhs, rhs)``."""
        # Normalize operators according to the magic methods implemented on tensors here:
        # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950
        assert callable(target)
        remapped = self.binary_magic_method_remap.get(target)
        # Any keyword argument (e.g. `alpha=` on torch.add, `rounding_mode=`
        # on torch.div) would be dropped by the rewrite and change the
        # result, so only bare two-argument calls are remapped.
        if remapped is None or len(args) != 2 or kwargs:
            return super().call_function(target, args, kwargs)
        lhs, rhs = args
        return super().call_function(
            target=remapped,
            args=(lhs, rhs),
            kwargs={},
        )