Spaces:
Running
on
Zero
Running
on
Zero
| from collections import deque | |
| from dask.core import istask, subs | |
def head(task):
    """Return the top level node of a task.

    For a dask task (as decided by ``istask``) this is its first element.
    For a plain ``list`` it is the ``list`` type itself, and for anything
    else the object is returned unchanged.
    """
    if istask(task):
        return task[0]
    if isinstance(task, list):
        # Lists have no callable at position 0; ``list`` stands in as head.
        return list
    return task
def args(task):
    """Get the arguments for the current task.

    Returns everything after the first element for a dask task, the list
    itself for a plain ``list``, and an empty tuple for any other object.
    """
    if istask(task):
        return task[1:]
    if isinstance(task, list):
        return task
    return ()
class Traverser:
    """Traverser interface for tasks.

    Class for storing the state while performing a preorder-traversal of a
    task.

    Parameters
    ----------
    term : task
        The task to be traversed
    stack : deque, optional
        Pending terms still to be visited. When omitted (or empty), a fresh
        stack holding only the ``END`` token is created.

    Attributes
    ----------
    term
        The current element in the traversal
    current
        The head of the current element in the traversal. This is simply `head`
        applied to the attribute `term`.
    """

    def __init__(self, term, stack=None):
        self.term = term
        if not stack:
            # END at the bottom of the stack marks the end of the traversal.
            self._stack = deque([END])
        else:
            self._stack = stack

    def __iter__(self):
        """Yield the head of every element, in preorder, until ``END``."""
        while self.current is not END:
            yield self.current
            self.next()

    def copy(self):
        """Copy the traverser in its current state.

        This allows the traversal to be pushed onto a stack, for easy
        backtracking. The pending-term stack is copied, so advancing the
        copy does not affect the original (and vice versa)."""
        return Traverser(self.term, deque(self._stack))

    def next(self):
        """Proceed to the next term in the preorder traversal."""
        subterms = args(self.term)
        if not subterms:
            # No subterms, pop off stack
            self.term = self._stack.pop()
        else:
            # Descend into the first subterm; the remaining ones are pushed
            # in reverse so that they are popped in left-to-right order.
            self.term = subterms[0]
            self._stack.extend(reversed(subterms[1:]))

    @property
    def current(self):
        """The head of the current element (``head(self.term)``).

        This must be a property: the rest of the module reads it as an
        attribute (e.g. ``traverser.current is END``). As a plain method,
        that comparison would be made against a bound method and never hold.
        """
        return head(self.term)

    def skip(self):
        """Skip over all subterms of the current level in the traversal"""
        self.term = self._stack.pop()
class Token:
    """A token object.

    Marker objects used to express special positions in the traversal of a
    task or pattern. The ``repr`` of a token is its ``name``.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name
# Single placeholder standing in for *every* variable in a discrimination net.
VAR = Token("?")

# Marks the end of the traversal of an expression. `None`, `False`, etc. are
# unsuitable here because any value may legitimately appear as an argument
# to a function.
END = Token("end")
class Node(tuple):
    """A Discrimination Net node.

    Stored as the 2-tuple ``(edges, patterns)``. A falsy ``edges`` or
    ``patterns`` argument is replaced by a fresh empty dict / list, so nodes
    never share these containers by default.
    """

    __slots__ = ()

    def __new__(cls, edges=None, patterns=None):
        edges = edges if edges else {}
        patterns = patterns if patterns else []
        return tuple.__new__(cls, (edges, patterns))

    # Both accessors are properties because the rest of the module uses them
    # as attributes (``node.edges[t]``, ``node.patterns.append(i)``).

    @property
    def edges(self):
        """A dictionary, where the keys are edges, and the values are nodes"""
        return self[0]

    @property
    def patterns(self):
        """A list of all patterns that currently match at this node"""
        return self[1]
class RewriteRule:
    """A rewrite rule.

    Expresses `lhs` -> `rhs`, for variables `vars`.

    Parameters
    ----------
    lhs : task
        The left-hand-side of the rewrite rule.
    rhs : task or function
        The right-hand-side of the rewrite rule. If it's a task, variables in
        `rhs` will be replaced by terms in the subject that match the variables
        in `lhs`. If it's a function, the function will be called with a dict
        of such matches.
    vars: tuple, optional
        Tuple of variables found in the lhs. Variables can be represented as
        any hashable object; a good convention is to use strings. If there are
        no variables, this can be omitted.

    Examples
    --------
    Here's a `RewriteRule` to replace all nested calls to `list`, so that
    `(list, (list, 'x'))` is replaced with `(list, 'x')`, where `'x'` is a
    variable.

    >>> import dask.rewrite as dr
    >>> lhs = (list, (list, 'x'))
    >>> rhs = (list, 'x')
    >>> variables = ('x',)
    >>> rule = dr.RewriteRule(lhs, rhs, variables)

    Here's a more complicated rule that uses a callable right-hand-side. A
    callable `rhs` takes in a dictionary mapping variables to their matching
    values. This rule replaces all occurrences of `(list, 'x')` with `'x'` if
    `'x'` is a list itself.

    >>> lhs = (list, 'x')
    >>> def repl_list(sd):
    ...     x = sd['x']
    ...     if isinstance(x, list):
    ...         return x
    ...     else:
    ...         return (list, x)
    >>> rule = dr.RewriteRule(lhs, repl_list, variables)
    """

    def __init__(self, lhs, rhs, vars=()):
        if not isinstance(vars, tuple):
            raise TypeError("vars must be a tuple of variables")
        self.lhs = lhs
        # A callable rhs performs the substitution itself; a task rhs is
        # rewritten by `_apply`.
        self.subs = rhs if callable(rhs) else self._apply
        self.rhs = rhs
        # Variables in the order they occur in a preorder traversal of lhs
        # (repeats kept).
        self._varlist = [tok for tok in Traverser(lhs) if tok in vars]
        # Reduce vars down to just variables found in lhs
        self.vars = tuple(sorted(set(self._varlist)))

    def _apply(self, sub_dict):
        """Substitute each variable in `rhs` with its value from `sub_dict`."""
        result = self.rhs
        for variable, replacement in sub_dict.items():
            result = subs(result, variable, replacement)
        return result

    def __str__(self):
        return f"RewriteRule({self.lhs}, {self.rhs}, {self.vars})"

    def __repr__(self):
        return str(self)
class RuleSet:
    """A set of rewrite rules.

    Forms a structure for fast rewriting over a set of rewrite rules. This
    allows for syntactic matching of terms to patterns for many patterns at
    the same time.

    Examples
    --------

    >>> import dask.rewrite as dr
    >>> def f(*args): pass
    >>> def g(*args): pass
    >>> def h(*args): pass
    >>> from operator import add

    >>> rs = dr.RuleSet(
    ...         dr.RewriteRule((add, 'x', 0), 'x', ('x',)),
    ...         dr.RewriteRule((f, (g, 'x'), 'y'),
    ...                        (h, 'x', 'y'),
    ...                        ('x', 'y')))

    >>> rs.rewrite((add, 2, 0))
    2

    >>> rs.rewrite((f, (g, 'a', 3)))  # doctest: +ELLIPSIS
    (<function h at ...>, 'a', 3)

    >>> dsk = {'a': (add, 2, 0),
    ...        'b': (f, (g, 'a', 3))}

    >>> from toolz import valmap
    >>> valmap(rs.rewrite, dsk)  # doctest: +ELLIPSIS
    {'a': 2, 'b': (<function h at ...>, 'a', 3)}

    Attributes
    ----------
    rules : list
        A list of `RewriteRule`s included in the `RuleSet`.
    """

    def __init__(self, *rules):
        """Create a `RuleSet` for a number of rules

        Parameters
        ----------
        rules
            One or more instances of RewriteRule
        """
        self._net = Node()
        self.rules = []
        for rule in rules:
            self.add(rule)

    def add(self, rule):
        """Add a rule to the RuleSet.

        Parameters
        ----------
        rule : RewriteRule

        Raises
        ------
        TypeError
            If ``rule`` is not a `RewriteRule`.
        """
        if not isinstance(rule, RewriteRule):
            raise TypeError("rule must be instance of RewriteRule")
        rule_vars = rule.vars
        node = self._net
        # The rule's position in `self.rules` is what gets stored in the net.
        index = len(self.rules)
        # Walk (and extend where needed) the net along the preorder traversal
        # of the lhs; every variable is collapsed onto the single VAR edge.
        for token in Traverser(rule.lhs):
            key = VAR if token in rule_vars else token
            edges = node.edges
            if key not in edges:
                edges[key] = Node()
            node = edges[key]
        # We've reached a leaf node. Add the term index to this leaf.
        node.patterns.append(index)
        self.rules.append(rule)

    def iter_matches(self, term):
        """A generator that lazily finds matchings for term from the RuleSet.

        Parameters
        ----------
        term : task

        Yields
        ------
        Tuples of `(rule, subs)`, where `rule` is the rewrite rule being
        matched, and `subs` is a dictionary mapping the variables in the lhs
        of the rule to their matching values in the term."""
        traverser = Traverser(term)
        for pattern_indices, syms in _match(traverser, self._net):
            for index in pattern_indices:
                rule = self.rules[index]
                # None means the structural match bound one variable to two
                # different subterms, so the candidate is dropped.
                sub_dict = _process_match(rule, syms)
                if sub_dict is not None:
                    yield rule, sub_dict

    def _rewrite(self, term):
        """Apply the rewrite rules in RuleSet to top level of term"""
        # Only the first match is wanted; `next` pulls it without running the
        # match generator any further.
        first = next(self.iter_matches(term), None)
        if first is None:
            return term
        rule, sub_dict = first
        return rule.subs(sub_dict)

    def rewrite(self, task, strategy="bottom_up"):
        """Apply the `RuleSet` to `task`.

        This applies the most specific matching rule in the RuleSet to the
        task, using the provided strategy.

        Parameters
        ----------
        task : a task
            The task to be rewritten
        strategy : str, optional
            The rewriting strategy to use. Options are "bottom_up" (default),
            or "top_level". Any other value raises ``KeyError``.

        Examples
        --------
        Suppose there was a function `add` that returned the sum of 2 numbers,
        and another function `double` that returned twice its input:

        >>> add = lambda x, y: x + y
        >>> double = lambda x: 2*x

        Now suppose `double` was *significantly* faster than `add`, so
        you'd like to replace all expressions `(add, x, x)` with `(double,
        x)`, where `x` is a variable. This can be expressed as a rewrite rule:

        >>> rule = RewriteRule((add, 'x', 'x'), (double, 'x'), ('x',))
        >>> rs = RuleSet(rule)

        This can then be applied to terms to perform the rewriting:

        >>> term = (add, (add, 2, 2), (add, 2, 2))
        >>> rs.rewrite(term)  # doctest: +SKIP
        (double, (double, 2))

        If we only wanted to apply this to the top level of the term, the
        `strategy` kwarg can be set to "top_level".

        >>> rs.rewrite(term, strategy="top_level")  # doctest: +SKIP
        (double, (add, 2, 2))
        """
        return strategies[strategy](self, task)
| def _top_level(net, term): | |
| return net._rewrite(term) | |
def _bottom_up(net, term):
    """Rewrite ``term`` from the leaves upwards.

    The arguments of a task (or the items of a list) are rewritten
    recursively first; the rebuilt term is then rewritten at its top level.
    """
    if istask(term):
        top = head(term)
        new_args = tuple(_bottom_up(net, sub) for sub in args(term))
        term = (top,) + new_args
    elif isinstance(term, list):
        term = [_bottom_up(net, sub) for sub in args(term)]
    return net._rewrite(term)
# Maps the `strategy` names accepted by `RuleSet.rewrite` to their functions.
strategies = dict(top_level=_top_level, bottom_up=_bottom_up)
def _match(S, N):
    """Structural matching of term S to discrimination net node N.

    Parameters
    ----------
    S : Traverser
        Traverser positioned at the start of the term to match.
    N : Node
        Node of the discrimination net to start matching from.

    Yields
    ------
    Tuples of ``(patterns, matches)``: the pattern list of the node reached
    when the traversal hits ``END``, and a tuple of the subterms that were
    consumed by ``VAR`` edges on the way there, in traversal order.
    """
    # Saved (traverser, node, matches) states to return to when backtracking.
    stack = deque()
    # True right after backtracking: the exact edge of the restored state was
    # already explored, so only its VAR edge may be tried.
    restore_state_flag = False
    # matches are stored in a tuple, because all mutations result in a copy,
    # preventing operations from changing matches stored on the stack.
    matches = ()
    while True:
        if S.current is END:
            yield N.patterns, matches
        try:
            # This try-except block is to catch hashing errors from un-hashable
            # types. This allows for variables to be matched with un-hashable
            # objects.
            n = N.edges.get(S.current, None)
            if n and not restore_state_flag:
                # Save the state first so the VAR edge can be tried later.
                stack.append((S.copy(), N, matches))
                N = n
                S.next()
                continue
        except TypeError:
            pass
        n = N.edges.get(VAR, None)
        if n:
            restore_state_flag = False
            # A variable consumes the whole current subterm.
            matches = matches + (S.term,)
            S.skip()
            N = n
            continue
        try:
            # Backtrack here
            (S, N, matches) = stack.pop()
            restore_state_flag = True
        except IndexError:
            # Nothing left to backtrack to: every alternative was explored.
            return
| def _process_match(rule, syms): | |
| """Process a match to determine if it is correct, and to find the correct | |
| substitution that will convert the term into the pattern. | |
| Parameters | |
| ---------- | |
| rule : RewriteRule | |
| syms : iterable | |
| Iterable of subterms that match a corresponding variable. | |
| Returns | |
| ------- | |
| A dictionary of {vars : subterms} describing the substitution to make the | |
| pattern equivalent with the term. Returns `None` if the match is | |
| invalid.""" | |
| subs = {} | |
| varlist = rule._varlist | |
| if not len(varlist) == len(syms): | |
| raise RuntimeError("length of varlist doesn't match length of syms.") | |
| for v, s in zip(varlist, syms): | |
| if v in subs and subs[v] != s: | |
| return None | |
| else: | |
| subs[v] = s | |
| return subs | |