Spaces:
Sleeping
Sleeping
import ast | |
import inspect | |
import textwrap | |
import copy | |
import functools | |
from types import FunctionType | |
from typing import cast, Union, Callable, Dict, Optional, Any | |
from torch.fx._symbolic_trace import Tracer | |
from torch.fx.graph import Graph | |
from torch._sources import normalize_source_lines | |
import torch | |
class AST_Rewriter(ast.NodeTransformer):
    """
    Take a FunctionType object representing a `forward` method, then
    perform an AST rewrite to swap out nodes that are not symbolically
    traceable with a callsite to the FX alternative.

    To support swapping out an AST node, define a new `visit` method on
    that node. For more details, see:
    https://docs.python.org/3/library/ast.html#ast.NodeTransformer
    """

    def rewrite(self, fn: FunctionType):
        """
        Re-compile `fn` from its source after applying this transformer.

        Args:
            fn: the function (or bound method) whose source is rewritten.
                Its source must be retrievable through `inspect`.

        Returns:
            A new function object built from the rewritten code whose
            `__globals__` is the very same dict as `fn.__globals__`.

        Raises:
            AssertionError: if executing the rewritten source does not
                bind exactly one top-level name.
        """
        # Normalize the source lines
        sourcelines, _ = inspect.getsourcelines(fn)
        sourcelines = normalize_source_lines(sourcelines)
        source = ''.join(sourcelines)
        normalized_str = textwrap.dedent(source)

        # Rewrite the original AST
        source_ast = ast.parse(normalized_str)
        dest_ast = ast.fix_missing_locations(self.visit(source_ast))

        # Pull out the compiled function from the newly-created Module.
        # The code is executed against a *copy* of the globals so the
        # caller's module namespace is never mutated, and with a separate
        # local namespace so the freshly defined function can be found even
        # when its name already exists in the globals (e.g. a module-level
        # function). Diffing the globals keys before/after would find
        # nothing in that case.
        code = compile(dest_ast, "", "exec")
        globals_dict = copy.copy(fn.__globals__)
        namespace = {}
        exec(code, globals_dict, namespace)
        assert len(namespace) == 1, (
            "expected the rewritten source to define exactly one name, "
            f"got {sorted(namespace)}"
        )
        (fn_compiled,) = namespace.values()

        # Return the correct FunctionType object, bound to the original globals
        return self._change_func_globals(fn_compiled, fn.__globals__)

    @staticmethod
    def _change_func_globals(f, globals):
        """
        Return a copy of `f` whose `__globals__` is `globals`.

        Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)
        """
        # __globals__ is read-only on function objects, so we have to copy
        # the function, f, and all of its members, except f.__globals__
        g = FunctionType(
            f.__code__,
            globals,
            name=f.__name__,
            argdefs=f.__defaults__,
            closure=f.__closure__,
        )
        g = functools.update_wrapper(g, f)
        g.__kwdefaults__ = copy.copy(f.__kwdefaults__)
        return g

    def visit_Assert(self, node):
        """
        Swap out the Assert node (Python's `assert`) with a callsite to the
        symbolically-traceable torch._assert function.

        The generated call refers to the name `torch`, which is looked up
        in the rewritten function's globals when the call executes.
        """
        # Create the Call node
        n = ast.parse('torch._assert()', mode='eval')
        assert isinstance(n, ast.Expression)
        call_node = n.body
        assert isinstance(call_node, ast.Call)

        # `assert cond` without a message becomes `torch._assert(cond, "")`
        msg = node.msg if node.msg is not None else ast.Constant(value="", kind=None)
        call_node.args = [node.test, msg]

        # Ensure that the new node conforms to the Python AST grammar:
        # a call used as a statement has to be wrapped in an Expr node
        expr_wrapper = ast.Expr(value=call_node)

        # Return the new Expr node to signify that we want to use it as
        # a replacement for the original Assert node
        return ast.copy_location(expr_wrapper, node)

    def visit_AnnAssign(self, node):
        """
        Swap out Python's AnnAssign with an Assign node where the annotation function is called.

        Example:
             Original:
             y: Tensor_Type(1,2,3, Dyn) = f2(x)
            Output:
             y = annotate(f2(x), Tensor_Type(1,2,3, Dyn))

        A bare annotation without a value (`y: SomeType`) has nothing to
        wrap and is returned unchanged.
        """
        if node.value is None:
            # No right-hand side: building `annotate(None-node, ...)` would
            # produce an invalid AST, so leave the statement as it is.
            return node

        annotate_call = ast.Call(
            func=ast.Name(id='annotate', ctx=ast.Load()),
            args=[node.value, node.annotation],
            keywords=[],
        )
        assign = ast.Assign(targets=[node.target], value=annotate_call)
        return ast.copy_location(assign, node)
class RewritingTracer(Tracer):
    """Tracer that passes `root` through `_rewrite` before tracing it."""

    def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
        """
        Rewrite `root` and hand the result to the parent class's `trace`.

        Args:
            root: the module or callable to trace.
            concrete_args: forwarded unchanged to the parent `trace`.

        Returns:
            Whatever the parent class's `trace` returns for the rewritten root.
        """
        rewritten_root = _rewrite(root)
        return super().trace(rewritten_root, concrete_args)
def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
    """
    Apply `AST_Rewriter` to a free function or to a module's `forward`.

    For a plain callable the rewritten function is returned. For a
    `torch.nn.Module` a new module object is returned whose class is a
    fresh `torch.nn.Module` subclass carrying the rewritten `forward`, and
    whose instance `__dict__` entries are shallow copies of the original's.
    """
    if not isinstance(fn, torch.nn.Module):
        # Rewrite this single free function
        return AST_Rewriter().rewrite(cast(FunctionType, fn))

    def rewrite_module(m : torch.nn.Module):
        # A new class is created per module so that each one can carry
        # its own rewritten `forward` as a class attribute.
        class RewrittenModule(torch.nn.Module):
            def __init__(self, orig):
                super().__init__()
                # Mirror the original's instance state; values that are
                # themselves modules are rewritten recursively first.
                # NOTE(review): torch.nn.Module normally keeps registered
                # children inside `_modules` rather than as direct
                # `__dict__` values, so the module branch below may rarely
                # trigger -- confirm whether children are meant to be
                # rewritten through `_modules` instead.
                for attr_name, attr_value in orig.__dict__.items():
                    is_module = isinstance(attr_value, torch.nn.Module)
                    replacement = rewrite_module(attr_value) if is_module else attr_value
                    self.__dict__[attr_name] = copy.copy(replacement)

        RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward))
        return RewrittenModule(m)

    # Rewrite this module's `forward` and return the new module object
    return rewrite_module(fn)