Spaces:
Running
Running
""" Generic Unification algorithm for expression trees with lists of children

This implementation is a direct translation of

Artificial Intelligence: A Modern Approach by Stuart Russell and Peter Norvig
Second edition, section 9.2, page 276

It is modified in the following ways:

1.  We allow associative and commutative Compound expressions. This results in
    combinatorial blowup.
2.  We explore the tree lazily.
3.  We provide generic interfaces to symbolic algebra libraries in Python.

A more traditional version can be found here:
http://aima.cs.berkeley.edu/python/logic.html
"""
from sympy.utilities.iterables import kbins | |
class Compound:
    """ A little class to represent an interior node in the tree

    This is analogous to SymPy.Basic for non-Atoms

    op   - the operator labelling this node.
    args - the children of this node.  ``__hash__`` hashes ``args`` directly,
           so it must be hashable (e.g. a tuple) if the Compound is hashed.
    """
    def __init__(self, op, args):
        self.op = op
        self.args = args

    def __eq__(self, other):
        return (type(self) is type(other) and self.op == other.op and
                self.args == other.args)

    def __hash__(self):
        return hash((type(self), self.op, self.args))

    def __str__(self):
        return "%s[%s]" % (str(self.op), ', '.join(map(str, self.args)))

    # Use the readable form inside containers (lists, substitution dicts)
    # and at the REPL too, instead of the default ``<... object at 0x...>``.
    __repr__ = __str__
class Variable:
    """ A Wild token

    arg - the name/label of the token; two Variables are equal (and hash
          equally) exactly when their ``arg`` values are equal.
    """
    def __init__(self, arg):
        self.arg = arg

    def __eq__(self, other):
        return type(self) is type(other) and self.arg == other.arg

    def __hash__(self):
        return hash((type(self), self.arg))

    def __str__(self):
        return "Variable(%s)" % str(self.arg)

    # Substitutions are dicts keyed by Variables; without this their repr
    # would show ``<... object at 0x...>`` rather than ``{Variable(a): 'y'}``.
    __repr__ = __str__
class CondVariable:
    """ A wild token that matches conditionally.

    arg - a wild token.
    valid - an additional constraining function on a match.

    Equality and hashing include ``valid``; for plain functions that means
    two CondVariables are equal only if they share the same function object.
    """
    def __init__(self, arg, valid):
        self.arg = arg
        self.valid = valid

    def __eq__(self, other):
        return (type(self) is type(other) and
                self.arg == other.arg and
                self.valid == other.valid)

    def __hash__(self):
        return hash((type(self), self.arg, self.valid))

    def __str__(self):
        return "CondVariable(%s)" % str(self.arg)

    # Readable representation inside containers / substitution dicts.
    __repr__ = __str__
def unify(x, y, s=None, **fns):
    """ Unify two expressions.

    Parameters
    ==========

    x, y - expression trees containing leaves, Compounds and Variables.
    s - a mapping of variables to subtrees.  A missing or empty mapping
        starts from a fresh ``{}``.  unify itself never mutates ``s``, but
        the mapping it yields may be ``s`` itself.
    fns - optional keyword predicates ``is_commutative(node)`` and
          ``is_associative(node)`` applied to Compound nodes.  Both default
          to always returning False and are forwarded to recursive calls.

    Returns
    =======

    lazy sequence of mappings {Variable: subtree}

    Notes
    =====

    Tuples and lists are unified element by element.  Sets have no
    positional order, so two unequal sets do not unify.  Earlier code
    failed with ``TypeError`` when it tried to index them.

    Examples
    ========

    >>> from sympy.unify.core import unify, Compound, Variable
    >>> expr = Compound("Add", ("x", "y"))
    >>> pattern = Compound("Add", ("x", Variable("a")))
    >>> next(unify(expr, pattern, {}))
    {Variable(a): 'y'}
    """
    s = s or {}

    if x == y:
        yield s
    elif isinstance(x, (Variable, CondVariable)):
        yield from unify_var(x, y, s, **fns)
    elif isinstance(y, (Variable, CondVariable)):
        yield from unify_var(y, x, s, **fns)
    elif isinstance(x, Compound) and isinstance(y, Compound):
        is_commutative = fns.get('is_commutative', lambda node: False)
        is_associative = fns.get('is_associative', lambda node: False)
        # The operators must unify first.  Each resulting substitution is
        # then extended by unifying the arguments.
        for sop in unify(x.op, y.op, s, **fns):
            if is_associative(x) and is_associative(y):
                # Regroup the longer argument list into as many chunks as
                # the shorter one has, then unify the chunks pairwise.
                a, b = (x, y) if len(x.args) < len(y.args) else (y, x)
                if is_commutative(x) and is_commutative(y):
                    combs = allcombinations(a.args, b.args, 'commutative')
                else:
                    combs = allcombinations(a.args, b.args, 'associative')
                for aaargs, bbargs in combs:
                    # A one-element chunk collapses back to the bare element.
                    aa = [unpack(Compound(a.op, arg)) for arg in aaargs]
                    bb = [unpack(Compound(b.op, arg)) for arg in bbargs]
                    yield from unify(aa, bb, sop, **fns)
            elif len(x.args) == len(y.args):
                yield from unify(x.args, y.args, sop, **fns)
    elif (is_args(x) and is_args(y) and len(x) == len(y)
          and not isinstance(x, set) and not isinstance(y, set)):
        # Positional unification needs indexing/slicing, which sets lack.
        if len(x) == 0:
            yield s
        else:
            for shead in unify(x[0], y[0], s, **fns):
                yield from unify(x[1:], y[1:], shead, **fns)
def unify_var(var, x, s, **fns):
    """ Unify the wild token ``var`` with the expression ``x`` under ``s``.

    Yields zero or more substitution mappings:

    * if ``var`` is already bound in ``s``, its binding is unified with ``x``;
    * if ``var`` occurs inside ``x``, nothing is yielded (occurs check);
    * a ``CondVariable`` is bound to ``x`` only when ``var.valid(x)`` is true;
    * a plain ``Variable`` is always bound to ``x``.

    New bindings are added to a copy of ``s``; ``s`` itself is not modified.
    """
    if var in s:
        # Already bound: the existing binding has to unify with x.
        yield from unify(s[var], x, s, **fns)
        return
    if occur_check(var, x):
        # Binding var to a term that contains var would be circular.
        return
    if isinstance(var, CondVariable) and var.valid(x):
        yield assoc(s, var, x)
    elif isinstance(var, Variable):
        yield assoc(s, var, x)
def occur_check(var, x):
    """ var occurs in subtree owned by x?

    Returns True when ``var`` equals ``x`` or equals some node reachable
    from ``x`` through Compound arguments or tuple/list/set elements.
    """
    if var == x:
        return True
    if isinstance(x, Compound):
        # Only the arguments are searched, not the operator.
        return occur_check(var, x.args)
    if is_args(x):
        return any(occur_check(var, child) for child in x)
    return False
def assoc(d, key, val):
    """ Return a copy of ``d`` in which ``key`` maps to ``val``.

    ``d`` is left untouched; the copy comes from ``d.copy()``.
    """
    extended = d.copy()
    extended[key] = val
    return extended
def is_args(x):
    """ Is x a traditional iterable?

    Only exact ``tuple``, ``list`` and ``set`` instances qualify.  Subclasses
    (e.g. namedtuples), strings, dicts and other iterables do not.
    """
    kind = type(x)
    return kind in (tuple, list, set)
def unpack(x):
    """ Collapse a single-argument Compound to its lone argument.

    Anything else, including Compounds with zero or several arguments,
    is returned unchanged.
    """
    if not isinstance(x, Compound) or len(x.args) != 1:
        return x
    return x.args[0]
def allcombinations(A, B, ordered):
    """
    Restructure A and B to have the same number of elements.

    Parameters
    ==========

    ordered must be either 'commutative' or 'associative'.

    A and B can be rearranged so that the larger of the two lists is
    reorganized into smaller sublists.

    Each yielded pair ``(A', B')`` wraps every element of the shorter input
    in a 1-tuple and splits the longer input into that many groups.  With
    'associative' the groups are contiguous runs in the original order.
    With 'commutative' the elements may also be reordered (compare the
    examples below).  Any other value of ``ordered`` is handed unchanged
    to ``kbins`` as its ``ordered`` argument.

    Examples
    ========

    >>> from sympy.unify.core import allcombinations
    >>> for x in allcombinations((1, 2, 3), (5, 6), 'associative'): print(x)
    (((1,), (2, 3)), ((5,), (6,)))
    (((1, 2), (3,)), ((5,), (6,)))

    >>> for x in allcombinations((1, 2, 3), (5, 6), 'commutative'): print(x)
    (((1,), (2, 3)), ((5,), (6,)))
    (((1, 2), (3,)), ((5,), (6,)))
    (((1,), (3, 2)), ((5,), (6,)))
    (((1, 3), (2,)), ((5,), (6,)))
    (((2,), (1, 3)), ((5,), (6,)))
    (((2, 1), (3,)), ((5,), (6,)))
    (((2,), (3, 1)), ((5,), (6,)))
    (((2, 3), (1,)), ((5,), (6,)))
    (((3,), (1, 2)), ((5,), (6,)))
    (((3, 1), (2,)), ((5,), (6,)))
    (((3,), (2, 1)), ((5,), (6,)))
    (((3, 2), (1,)), ((5,), (6,)))
    """
    # Translate the public mode names into kbins ordering flags.
    if ordered == "commutative":
        kbins_flag = 11
    elif ordered == "associative":
        kbins_flag = None
    else:
        kbins_flag = ordered

    if len(A) < len(B):
        small, big = A, B
    else:
        small, big = B, A

    positions = list(range(len(big)))
    for grouping in kbins(positions, len(small), ordered=kbins_flag):
        # Value comparison (not identity): with equal lengths and equal
        # contents this picks the first branch.
        if big == B:
            yield tuple((item,) for item in A), partition(B, grouping)
        else:
            yield partition(A, grouping), tuple((item,) for item in B)
def partition(it, part):
    """ Partition a tuple/list into pieces defined by indices.

    ``part`` is a sequence of index sequences.  Piece ``k`` of the result
    holds ``it[i]`` for each ``i`` in ``part[k]``, in that order.  The
    result and each piece have the same type as ``it``.

    Examples
    ========

    >>> from sympy.unify.core import partition
    >>> partition((10, 20, 30, 40), [[0, 1, 2], [3]])
    ((10, 20, 30), (40,))
    """
    pieces = []
    for indices in part:
        pieces.append(index(it, indices))
    return type(it)(pieces)
def index(it, ind):
    """ Fancy indexing into an indexable iterable (tuple, list).

    Returns a container of the same type as ``it`` holding ``it[i]`` for
    each ``i`` in ``ind``, in the order given by ``ind``.

    Examples
    ========

    >>> from sympy.unify.core import index
    >>> index([10, 20, 30], (1, 2, 0))
    [20, 30, 10]
    """
    container = type(it)
    return container([it[pos] for pos in ind])