Spaces:
Running
on
Zero
Running
on
Zero
from collections import deque | |
from dask.core import istask, subs | |
def head(task):
    """Return the top level node of a task.

    For a task (as judged by ``istask``) this is its first element; for a
    list it is the ``list`` type itself; any other object is its own head.
    """
    if istask(task):
        return task[0]
    if isinstance(task, list):
        return list
    return task
def args(task):
    """Get the arguments for the current task.

    Everything after the head for a task (as judged by ``istask``), the list
    itself for a list, and an empty tuple for any other object (a leaf).
    """
    if istask(task):
        return task[1:]
    return task if isinstance(task, list) else ()
class Traverser:
    """Traverser interface for tasks.

    Class for storing the state while performing a preorder-traversal of a
    task.

    Parameters
    ----------
    term : task
        The task to be traversed
    stack : deque, optional
        Pending terms still to be visited; the next one is at the right end
        (it is consumed with ``pop``). Used by ``copy``. If omitted or empty,
        a fresh stack containing only ``END`` is created.

    Attributes
    ----------
    term
        The current element in the traversal
    current
        The head of the current element in the traversal. This is simply `head`
        applied to the attribute `term`.
    """

    def __init__(self, term, stack=None):
        self.term = term
        if not stack:
            # END sits at the bottom of the stack, so once every subterm has
            # been consumed the final pop leaves ``term`` equal to END.
            self._stack = deque([END])
        else:
            self._stack = stack

    def __iter__(self):
        """Yield the head of each term in preorder until END is reached.

        Note that this advances (consumes) the traverser itself.
        """
        while self.current is not END:
            yield self.current
            self.next()

    def copy(self):
        """Copy the traverser in its current state.

        This allows the traversal to be pushed onto a stack, for easy
        backtracking."""
        return Traverser(self.term, deque(self._stack))

    def next(self):
        """Proceed to the next term in the preorder traversal."""
        subterms = args(self.term)
        if not subterms:
            # No subterms, pop off stack
            self.term = self._stack.pop()
        else:
            self.term = subterms[0]
            # Push the remaining siblings reversed so the leftmost is popped
            # first.
            self._stack.extend(reversed(subterms[1:]))

    @property
    def current(self):
        """The head of the current term, i.e. ``head(self.term)``."""
        # Must be a property: ``__iter__`` and ``_match`` compare
        # ``self.current`` against END and use it as a dict key.
        return head(self.term)

    def skip(self):
        """Skip over all subterms of the current level in the traversal"""
        self.term = self._stack.pop()
class Token:
    """A named sentinel used in the traversal of a task or pattern.

    Token defines no ``__eq__``/``__hash__`` of its own, so two tokens are
    equal only when they are the very same object. ``name`` is just the
    label shown by ``repr``.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        label = self.name
        return label
# Module-level sentinels used while traversing tasks and patterns:
#   VAR -- stands for *all* variables in a discrimination net: every pattern
#          variable is collapsed onto this single edge label.
#   END -- marks the end of the traversal of an expression. We can't use
#          `None`, `False`, etc. here, as anything may be an argument to a
#          function.
VAR, END = Token("?"), Token("end")
class Node(tuple):
    """A Discrimination Net node.

    Stored as the 2-tuple ``(edges, patterns)``; each new node gets its own
    fresh dict and list unless (truthy) ones are supplied.
    """

    __slots__ = ()

    def __new__(cls, edges=None, patterns=None):
        # Create new containers per node; never share a default dict/list.
        edges = edges if edges else {}
        patterns = patterns if patterns else []
        return tuple.__new__(cls, (edges, patterns))

    @property
    def edges(self):
        """A dictionary, where the keys are edges, and the values are nodes"""
        return self[0]

    @property
    def patterns(self):
        """A list of all patterns that currently match at this node"""
        return self[1]
class RewriteRule:
    """A rewrite rule.

    Expresses `lhs` -> `rhs`, for variables `vars`.

    Parameters
    ----------
    lhs : task
        The left-hand-side of the rewrite rule.
    rhs : task or function
        The right-hand-side of the rewrite rule. If it's a task, variables in
        `rhs` will be replaced by terms in the subject that match the variables
        in `lhs`. If it's a function, the function will be called with a dict
        of such matches. (Any callable `rhs` is treated as a function.)
    vars: tuple, optional
        Tuple of variables found in the lhs. Variables can be represented as
        any hashable object; a good convention is to use strings. If there are
        no variables, this can be omitted.

    Raises
    ------
    TypeError
        If `vars` is not a tuple.

    Examples
    --------
    Here's a `RewriteRule` to replace all nested calls to `list`, so that
    `(list, (list, 'x'))` is replaced with `(list, 'x')`, where `'x'` is a
    variable.

    >>> import dask.rewrite as dr
    >>> lhs = (list, (list, 'x'))
    >>> rhs = (list, 'x')
    >>> variables = ('x',)
    >>> rule = dr.RewriteRule(lhs, rhs, variables)

    Here's a more complicated rule that uses a callable right-hand-side. A
    callable `rhs` takes in a dictionary mapping variables to their matching
    values. This rule replaces all occurrences of `(list, 'x')` with `'x'` if
    `'x'` is a list itself.

    >>> lhs = (list, 'x')
    >>> def repl_list(sd):
    ...     x = sd['x']
    ...     if isinstance(x, list):
    ...         return x
    ...     else:
    ...         return (list, x)
    >>> rule = dr.RewriteRule(lhs, repl_list, variables)
    """

    def __init__(self, lhs, rhs, vars=()):
        if not isinstance(vars, tuple):
            raise TypeError("vars must be a tuple of variables")
        self.lhs = lhs
        if callable(rhs):
            self.subs = rhs
        else:
            self.subs = self._apply
        self.rhs = rhs
        # Variables in the order they occur in a preorder traversal of lhs,
        # repeats included; matched subterms are paired with this list.
        self._varlist = [t for t in Traverser(lhs) if t in vars]
        # Reduce vars down to just variables found in lhs
        try:
            self.vars = tuple(sorted(set(self._varlist)))
        except TypeError:
            # Variables may be any hashable objects, which need not be
            # mutually orderable (e.g. a mix of str and int); fall back to
            # first-appearance order instead of crashing.
            self.vars = tuple(dict.fromkeys(self._varlist))

    def _apply(self, sub_dict):
        """Substitute the matched values into ``rhs`` simultaneously.

        Each variable is first replaced by a fresh placeholder object, and
        only then are the placeholders replaced by their values. Substituting
        the values directly one after another would re-substitute a matched
        value that happens to equal another variable (e.g. with
        ``{'x': 'y', 'y': 1}``, ``'x'`` would end up as ``1``).
        """
        term = self.rhs
        placeholders = {}
        for key in sub_dict:
            placeholder = object()  # unique; cannot collide with any term
            placeholders[key] = placeholder
            term = subs(term, key, placeholder)
        for key, placeholder in placeholders.items():
            term = subs(term, placeholder, sub_dict[key])
        return term

    def __str__(self):
        return f"RewriteRule({self.lhs}, {self.rhs}, {self.vars})"

    def __repr__(self):
        return str(self)
class RuleSet:
    """A set of rewrite rules.

    Forms a structure for fast rewriting over a set of rewrite rules. This
    allows for syntactic matching of terms to patterns for many patterns at
    the same time.

    Matching is performed on the flat preorder sequence of heads produced by
    ``Traverser``, which is why ``(f, (g, 'a', 3))`` below matches the
    pattern ``(f, (g, 'x'), 'y')``.

    Examples
    --------
    >>> import dask.rewrite as dr
    >>> def f(*args): pass
    >>> def g(*args): pass
    >>> def h(*args): pass
    >>> from operator import add
    >>> rs = dr.RuleSet(
    ...         dr.RewriteRule((add, 'x', 0), 'x', ('x',)),
    ...         dr.RewriteRule((f, (g, 'x'), 'y'),
    ...                        (h, 'x', 'y'),
    ...                        ('x', 'y')))
    >>> rs.rewrite((add, 2, 0))
    2
    >>> rs.rewrite((f, (g, 'a', 3)))    # doctest: +ELLIPSIS
    (<function h at ...>, 'a', 3)
    >>> dsk = {'a': (add, 2, 0),
    ...        'b': (f, (g, 'a', 3))}
    >>> from toolz import valmap
    >>> valmap(rs.rewrite, dsk)         # doctest: +ELLIPSIS
    {'a': 2, 'b': (<function h at ...>, 'a', 3)}

    Attributes
    ----------
    rules : list
        A list of `RewriteRule`s included in the `RuleSet`.
    """

    def __init__(self, *rules):
        """Create a `RuleSet` for a number of rules

        Parameters
        ----------
        rules
            One or more instances of RewriteRule
        """
        self._net = Node()
        self.rules = []
        for p in rules:
            self.add(p)

    def add(self, rule):
        """Add a rule to the RuleSet.

        The rule's lhs is walked in preorder; each head becomes an edge in
        the discrimination net (variables are collapsed onto the VAR edge),
        and the rule's index is recorded on the node where the walk ends.

        Parameters
        ----------
        rule : RewriteRule

        Raises
        ------
        TypeError
            If `rule` is not a RewriteRule.
        """
        if not isinstance(rule, RewriteRule):
            raise TypeError("rule must be instance of RewriteRule")
        rule_vars = rule.vars
        curr_node = self._net
        ind = len(self.rules)
        for t in Traverser(rule.lhs):
            if t in rule_vars:
                t = VAR
            if t not in curr_node.edges:
                curr_node.edges[t] = Node()
            curr_node = curr_node.edges[t]
        # We've reached a leaf node. Add the rule index to this leaf.
        curr_node.patterns.append(ind)
        self.rules.append(rule)

    def iter_matches(self, term):
        """A generator that lazily finds matchings for term from the RuleSet.

        Parameters
        ----------
        term : task

        Yields
        ------
        Tuples of `(rule, subs)`, where `rule` is the rewrite rule being
        matched, and `subs` is a dictionary mapping the variables in the lhs
        of the rule to their matching values in the term."""
        S = Traverser(term)
        for m, syms in _match(S, self._net):
            for i in m:
                rule = self.rules[i]
                # Named ``sd`` so it does not shadow the imported ``subs``.
                sd = _process_match(rule, syms)
                if sd is not None:
                    yield rule, sd

    def _rewrite(self, term):
        """Apply the rewrite rules in RuleSet to top level of term.

        Only the first match produced by ``iter_matches`` is applied; if
        nothing matches, ``term`` is returned unchanged.
        """
        for rule, sd in self.iter_matches(term):
            return rule.subs(sd)
        return term

    def rewrite(self, task, strategy="bottom_up"):
        """Apply the `RuleSet` to `task`.

        This applies the most specific matching rule in the RuleSet to the
        task, using the provided strategy. (Concretely, the first match
        found, where the search tries literal edges before variables.)

        Parameters
        ----------
        task: a task
            The task to be rewritten
        strategy: str, optional
            The rewriting strategy to use. Options are "bottom_up" (default),
            or "top_level".

        Raises
        ------
        KeyError
            If `strategy` is not one of the known strategies.

        Examples
        --------
        Suppose there was a function `add` that returned the sum of 2 numbers,
        and another function `double` that returned twice its input:

        >>> add = lambda x, y: x + y
        >>> double = lambda x: 2*x

        Now suppose `double` was *significantly* faster than `add`, so
        you'd like to replace all expressions `(add, x, x)` with `(double,
        x)`, where `x` is a variable. This can be expressed as a rewrite rule:

        >>> rule = RewriteRule((add, 'x', 'x'), (double, 'x'), ('x',))
        >>> rs = RuleSet(rule)

        This can then be applied to terms to perform the rewriting:

        >>> term = (add, (add, 2, 2), (add, 2, 2))
        >>> rs.rewrite(term)  # doctest: +SKIP
        (double, (double, 2))

        If we only wanted to apply this to the top level of the term, the
        `strategy` kwarg can be set to "top_level".

        >>> rs.rewrite(term, strategy="top_level")  # doctest: +SKIP
        (double, (add, 2, 2))
        """
        return strategies[strategy](self, task)
def _top_level(net, term):
    """Rewrite only the outermost level of ``term`` using the RuleSet ``net``."""
    rewritten = net._rewrite(term)
    return rewritten
def _bottom_up(net, term):
    """Rewrite ``term`` from the leaves upward using the RuleSet ``net``.

    The arguments of a task (or the items of a list) are rewritten
    recursively first, then the rebuilt term itself is rewritten once at its
    top level.
    """
    if istask(term):
        top = head(term)
        rewritten_args = tuple(_bottom_up(net, arg) for arg in args(term))
        term = (top,) + rewritten_args
    elif isinstance(term, list):
        term = [_bottom_up(net, item) for item in args(term)]
    return net._rewrite(term)
# Maps the ``strategy`` name accepted by ``RuleSet.rewrite`` to its driver.
strategies = dict(top_level=_top_level, bottom_up=_bottom_up)
def _match(S, N):
    """Structural matching of term S to discrimination net node N.

    Walks the net depth-first, driven by the Traverser ``S``. At each step
    an exact edge for ``S.current`` is tried first, with the state saved on
    a stack. On backtracking, the VAR edge is tried instead, which consumes
    the whole current subterm as a variable match.

    Yields
    ------
    (patterns, matches)
        ``patterns`` is the list of rule indices stored at the node reached
        when the traversal hits END (it may be empty). ``matches`` is a tuple
        of the subterms consumed by VAR edges, in traversal order.
    """
    stack = deque()
    restore_state_flag = False
    # matches are stored in a tuple, because all mutations result in a copy,
    # preventing operations from changing matches stored on the stack.
    matches = ()
    while True:
        if S.current is END:
            yield N.patterns, matches
            # The whole term has been consumed, so no edge can be followed
            # from here. Following a VAR edge (present when a longer pattern
            # shares this prefix) would call S.skip() on an empty stack and
            # raise IndexError. Fall through to backtracking instead.
        else:
            try:
                # This try-except block is to catch hashing errors from
                # un-hashable types. This allows for variables to be matched
                # with un-hashable objects.
                n = N.edges.get(S.current, None)
                if n and not restore_state_flag:
                    stack.append((S.copy(), N, matches))
                    N = n
                    S.next()
                    continue
            except TypeError:
                pass
            n = N.edges.get(VAR, None)
            if n:
                restore_state_flag = False
                matches = matches + (S.term,)
                S.skip()
                N = n
                continue
        # Backtrack here; an empty stack means the search is exhausted.
        if not stack:
            return
        S, N, matches = stack.pop()
        # Skip the exact edge already explored from this state; try VAR next.
        restore_state_flag = True
def _process_match(rule, syms):
    """Turn a raw match into a substitution, or reject it.

    ``rule._varlist`` lists the rule's variables in traversal order (repeats
    included) and is paired positionally with ``syms``. A variable that
    occurs more than once must be bound to equal subterms every time.

    Parameters
    ----------
    rule : RewriteRule
    syms : iterable
        Iterable of subterms that match a corresponding variable.

    Returns
    -------
    A dictionary of {vars : subterms} describing the substitution to make the
    pattern equivalent with the term. Returns `None` if the match is
    invalid.

    Raises
    ------
    RuntimeError
        If the number of subterms differs from the number of variables.
    """
    varlist = rule._varlist
    if len(varlist) != len(syms):
        raise RuntimeError("length of varlist doesn't match length of syms.")
    bindings = {}
    for var, sym in zip(varlist, syms):
        if var in bindings and bindings[var] != sym:
            return None
        # Rebind even when equal: the latest matching subterm is kept.
        bindings[var] = sym
    return bindings