File size: 8,196 Bytes
d1ed09d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
"""
Utilities for manipulating the Abstract Syntax Tree of Python constructs
"""
import re
import copy
import inspect
import ast
import textwrap


class NameVisitor(ast.NodeVisitor):
    """
    NodeVisitor that collects every identifier name found in an AST.

    After ``visit`` has been called on a tree, the ``names`` attribute holds
    the identifiers seen on ``Name`` nodes and on function argument nodes.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Identifiers collected so far
        self.names = set()

    def visit_Name(self, node):
        # Plain variable reference / assignment target
        self.names.add(node.id)

    def visit_arg(self, node):
        # Record the argument name from whichever attribute the node
        # carries: ``arg`` is tried first, ``id`` is the fallback.
        for attr in ('arg', 'id'):
            if hasattr(node, attr):
                self.names.add(getattr(node, attr))
                break

    def get_new_names(self, num_names):
        """
        Generate identifiers that do not clash with any collected name.

        Generated names have the form _N, where N is a non-negative integer.
        Numbering starts at 0 ('_0', '_1', '_2', ...) unless names of this
        form were already collected, in which case numbering starts one past
        the largest N that was seen.

        Parameters
        ----------
        num_names: int
            How many fresh names to generate

        Returns
        -------
        list of str
        """
        numbered = re.compile(r"^_(\d+)$")
        taken = [int(name[1:]) for name in self.names if numbered.match(name)]
        first = max(taken) + 1 if taken else 0
        return ["_%d" % index for index in range(first, first + num_names)]


class ExpandVarargTransformer(ast.NodeTransformer):
    """
    Node transformer that replaces the starred use of a variable in an AST
    with a collection of unstarred named variables.
    """
    def __init__(self, starred_name, expand_names, *args, **kwargs):
        """
        Parameters
        ----------
        starred_name: str
            The name of the starred variable to replace
        expand_names: list of str
            List of the new names that should be used to replace the starred
            variable

        """
        super(ExpandVarargTransformer, self).__init__(*args, **kwargs)
        self.starred_name = starred_name
        self.expand_names = expand_names


class ExpandVarargTransformerStarred(ExpandVarargTransformer):
    """
    Transformer that expands ``*<starred_name>`` (an ``ast.Starred`` node
    wrapping a bare name) into one ``ast.Name`` node per entry of
    ``expand_names``.
    """
    # Python 3
    def visit_Starred(self, node):
        """
        Replace ``*<starred_name>`` with the list of expanded names.

        Starred expressions that wrap anything other than the target name
        (another name, an attribute, a call, ...) are kept, but their
        children are still visited so that a nested ``*<starred_name>`` is
        expanded as well.
        """
        value = node.value
        # Only a bare Name has an ``id``; guard so that e.g. ``*obj.items``
        # or ``*make_args()`` does not raise AttributeError.
        if isinstance(value, ast.Name) and value.id == self.starred_name:
            return [ast.Name(id=name, ctx=node.ctx)
                    for name in self.expand_names]
        return self.generic_visit(node)


def function_to_ast(fn):
    """
    Parse the source code of a function into an AST.

    Parameters
    ----------
    fn: function
        Function whose source can be retrieved with ``inspect.getsource``

    Returns
    -------
    ast.Module
        The result of ``ast.parse`` on the function's dedented source
    """
    raw_source = inspect.getsource(fn)

    # Nested functions and methods are indented in their source file, so
    # strip the common leading whitespace before parsing.
    return ast.parse(textwrap.dedent(raw_source))


def ast_to_source(ast):
    """
    Convert an AST to a source code string.

    The ``astor`` package is used when it is installed. Otherwise the
    standard library's ``ast.unparse`` (Python 3.9+) is used as a fallback,
    so the function no longer hard-requires a third-party dependency.

    Parameters
    ----------
    ast: ast.AST
        The AST to convert. (The parameter name shadows the ``ast`` module
        inside this function; it is kept for backward compatibility.)

    Returns
    -------
    str
        Source code, terminated by a newline

    Raises
    ------
    ImportError
        If ``astor`` is not installed and ``ast.unparse`` is unavailable
    """
    try:
        import astor
    except ImportError:
        # The parameter shadows the module, so import it under another name
        import ast as ast_module
        if not hasattr(ast_module, "unparse"):
            raise
        # astor.to_source ends its output with a newline; match that here
        return ast_module.unparse(ast) + "\n"
    return astor.to_source(ast)


def compile_function_ast(fn_ast):
    """
    Compile a function's module AST into a code object.

    Parameters
    ----------
    fn_ast: ast.Module
        Module whose first statement is an ``ast.FunctionDef``

    Returns
    -------
    code
        Code object compiled in 'exec' mode; its filename is the function
        name wrapped in angle brackets (e.g. ``<my_fn>``)
    """
    assert isinstance(fn_ast, ast.Module)
    first_statement = fn_ast.body[0]
    assert isinstance(first_statement, ast.FunctionDef)

    pseudo_filename = "<{}>".format(first_statement.name)
    return compile(fn_ast, pseudo_filename, mode='exec')


def function_ast_to_function(fn_ast, stacklevel=1):
    """
    Build a function object from a function's module AST.

    The AST is compiled and executed in a namespace made of a copy of the
    globals of a calling frame, overlaid with that frame's locals.

    Parameters
    ----------
    fn_ast: ast.Module
        Module whose first statement is an ``ast.FunctionDef``
    stacklevel: int
        How many frames to walk up from this function to find the frame
        whose globals and locals are used (1 is the direct caller)

    Returns
    -------
    function
        The function defined by executing the compiled AST
    """
    assert isinstance(fn_ast, ast.Module)
    fndef_ast = fn_ast.body[0]
    assert isinstance(fndef_ast, ast.FunctionDef)

    code = compile_function_ast(fn_ast)

    # The frame lookup has to happen directly in this function so that
    # stacklevel is counted relative to it.
    frame = inspect.currentframe()
    for _hop in range(stacklevel):
        frame = frame.f_back

    # Snapshot the target frame's names; locals take precedence over globals
    namespace = copy.copy(frame.f_globals)
    namespace.update(frame.f_locals)
    del frame

    # Running the module code defines the function inside the namespace
    exec(code, namespace)
    return namespace[fndef_ast.name]


def _build_arg(name):
    """Return a new ``ast.arg`` node for a function parameter called *name*."""
    arg_node = ast.arg(arg=name)
    return arg_node


def expand_function_ast_varargs(fn_ast, expand_number):
    """
    Given a function AST that uses a variable length positional argument
    (e.g. *args), return a function AST that replaces the use of this
    argument with a fixed number of named arguments.

    To be supported, a function must have a starred argument in the function
    signature, and it may only use this argument in starred form as the
    input to other functions.

    For example, suppose expand_number is 3 and fn_ast is an AST
    representing this function...

    def my_fn1(a, b, *args):
        print(a, b)
        other_fn(a, b, *args)

    Then this function will return the AST of a function equivalent to...

    def my_fn1(a, b, _0, _1, _2):
        print(a, b)
        other_fn(a, b, _0, _1, _2)

    If the input function uses `args` for anything other than passing it to
    other functions in starred form, an error will be raised.

    Default values of positional parameters are dropped whenever at least
    one argument is generated. Python aligns positional defaults with the
    *last* positional parameters, so keeping them would silently move them
    onto the generated arguments, and required parameters cannot follow
    defaulted ones.

    Parameters
    ----------
    fn_ast: ast.Module
        Module whose first statement is the function definition to expand.
        The input AST is not modified.
    expand_number: int
        The number of fixed arguments that replace the variable length
        argument

    Returns
    -------
    ast.Module
        A new module AST containing the expanded function definition, with
        any decorators removed

    Raises
    ------
    ValueError
        If the function has no variable length positional argument, or if
        that argument is used in anything other than starred form
    """
    assert isinstance(fn_ast, ast.Module)

    # Copy ast so we don't modify the input
    fn_ast = copy.deepcopy(fn_ast)

    # Extract function definition
    fndef_ast = fn_ast.body[0]
    assert isinstance(fndef_ast, ast.FunctionDef)

    # Function variable arity argument
    fn_vararg = fndef_ast.args.vararg

    # Require vararg
    if not fn_vararg:
        raise ValueError("""\
Input function AST does not have a variable length positional argument
(e.g. *args) in the function signature""")

    # Get vararg name. The vararg is normally an ast.arg node; the plain
    # string case is presumably kept for older AST layouts.
    if isinstance(fn_vararg, str):
        vararg_name = fn_vararg
    else:
        vararg_name = fn_vararg.arg

    # Compute new unique names to use in place of the variable argument
    before_name_visitor = NameVisitor()
    before_name_visitor.visit(fn_ast)
    expand_names = before_name_visitor.get_new_names(expand_number)

    # Replace use of *args in function body
    new_fn_ast = ExpandVarargTransformerStarred(
        vararg_name, expand_names
    ).visit(fn_ast)
    new_fndef_ast = new_fn_ast.body[0]
    new_args = new_fndef_ast.args

    # Defaults are right-aligned with the positional parameters, so after
    # appending new parameters they would bind to the generated names
    # instead of the parameters they were written for. Drop them.
    if expand_names and new_args.defaults:
        new_args.defaults = []

    # Replace vararg with additional args in function signature
    new_args.args.extend(_build_arg(name=name) for name in expand_names)
    new_args.vararg = None

    # Run a new NameVisitor and see if there were any other non-starred uses
    # of the variable length argument. If so, raise an exception
    after_name_visitor = NameVisitor()
    after_name_visitor.visit(new_fn_ast)
    if vararg_name in after_name_visitor.names:
        raise ValueError("""\
The variable length positional argument {n} is used in an unsupported context
""".format(n=vararg_name))

    # Remove decorators if present to avoid recursion
    new_fndef_ast.decorator_list = []

    # Add missing source code locations
    ast.fix_missing_locations(new_fn_ast)

    # Return result
    return new_fn_ast


def expand_varargs(expand_number):
    """
    Build a decorator that rewrites a function so that its variable length
    (starred) argument is replaced by a fixed number of named arguments.

    Parameters
    ----------
    expand_number: int
        How many fixed arguments should take the place of the variable
        length argument; must be a non-negative integer

    Returns
    -------
    function
        Decorator that takes a function and returns its expanded version

    Raises
    ------
    ValueError
        If expand_number is not a non-negative integer
    """
    is_valid = isinstance(expand_number, int) and expand_number >= 0
    if not is_valid:
        raise ValueError("expand_number must be a non-negative integer")

    def _expand_varargs(fn):
        expanded_ast = expand_function_ast_varargs(
            function_to_ast(fn), expand_number
        )
        # stacklevel=2: evaluate in the scope of whoever applied the decorator
        return function_ast_to_function(expanded_ast, stacklevel=2)

    return _expand_varargs