File size: 3,486 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from .core import unify, reify  # type: ignore[attr-defined]
from .variable import isvar
from .utils import _toposort, freeze
from .unification_tools import groupby, first  # type: ignore[import]


class Dispatcher:
    """Route calls to implementations by unifying arguments with signatures.

    Implementations are stored in ``funcs`` keyed by their frozen signature,
    and ``ordering`` holds the signatures in the order ``resolve`` tries them.
    """

    def __init__(self, name):
        self.name = name
        # Frozen signature -> registered implementation.
        self.funcs = {}
        # Signatures in the order they are attempted by ``resolve``.
        self.ordering = []

    def add(self, signature, func):
        """Store ``func`` under ``signature`` and recompute ``ordering``."""
        key = freeze(signature)
        self.funcs[key] = func
        # The whole ordering is rebuilt from scratch on every registration.
        self.ordering = ordering(self.funcs)

    def __call__(self, *args, **kwargs):
        """Call the implementation whose signature matches ``args``.

        Only the positional arguments take part in matching; ``kwargs`` are
        forwarded to the implementation untouched.
        """
        target, _ = self.resolve(args)
        return target(*args, **kwargs)

    def resolve(self, args):
        """Find the first registered signature that unifies with ``args``.

        Returns a ``(func, substitution)`` pair, where ``substitution`` is
        whatever ``unify`` produced for the winning signature.

        Raises ``NotImplementedError`` when no signature of matching length
        unifies with ``args``.
        """
        arity = len(args)
        for candidate in self.ordering:
            # Signatures of a different length are skipped without unifying.
            if len(candidate) == arity:
                substitution = unify(freeze(args), candidate)
                if substitution is not False:
                    return self.funcs[candidate], substitution
        message = ("No match found. \nKnown matches: " + str(self.ordering)
                   + "\nInput: " + str(args))
        raise NotImplementedError(message)

    def register(self, *signature):
        """Decorator form of ``add``.

        The decorator returns this dispatcher (not the decorated function),
        so the decorated name ends up bound to the dispatcher itself.
        """
        def _(func):
            self.add(signature, func)
            return self
        return _


class VarDispatcher(Dispatcher):
    """ A dispatcher that calls functions with variable names

    The substitution found by ``resolve`` is turned into keyword arguments:
    each bound variable's ``token`` becomes the keyword name and its bound
    value becomes the argument.  The original positional arguments and any
    ``kwargs`` given to the call are not forwarded to the implementation.

    >>> # xdoctest: +SKIP
    >>> d = VarDispatcher('d')
    >>> x = var('x')
    >>> @d.register('inc', x)
    ... def f(x):
    ...     return x + 1
    >>> @d.register('double', x)
    ... def f(x):
    ...     return x * 2
    >>> d('inc', 10)
    11
    >>> d('double', 10)
    20
    """
    def __call__(self, *args, **kwargs):
        func, bindings = self.resolve(args)
        # Re-key the substitution by each variable's token so it can be
        # splatted as keyword arguments.
        named = {}
        for variable, value in bindings.items():
            named[variable.token] = value
        return func(**named)


# Default registry used by ``match`` when no ``namespace`` keyword is given:
# maps a decorated function's ``__name__`` to the dispatcher created for it.
global_namespace = {}  # type: ignore[var-annotated]


def match(*signature, **kwargs):
    """Decorator that registers a function under ``signature``.

    Functions sharing a ``__name__`` share one dispatcher, which is looked up
    in (or created in) the namespace.

    Keyword options:
        namespace: mapping of name -> dispatcher; defaults to
            ``global_namespace``.
        Dispatcher: class instantiated with the function name when no
            dispatcher exists for it yet; defaults to ``Dispatcher``.

    The decorator returns the dispatcher, not the decorated function.
    """
    registry = kwargs.get('namespace', global_namespace)
    dispatcher_cls = kwargs.get('Dispatcher', Dispatcher)

    def decorator(func):
        key = func.__name__
        if key in registry:
            target = registry[key]
        else:
            # First function with this name: create its dispatcher.
            target = registry[key] = dispatcher_cls(key)
        target.add(signature, func)
        return target

    return decorator


def supercedes(a, b):
    """ ``a`` is a more specific match than ``b``

    Returns ``True`` when ``a`` should be considered more specific than
    ``b`` and ``False`` otherwise.  The result is always a ``bool``; the
    function never falls off the end and returns ``None``.
    """
    # A non-variable is always more specific than a bare variable.
    if isvar(b) and not isvar(a):
        return True
    s = unify(a, b)
    if s is False:
        # The two terms do not unify, so neither one refines the other.
        return False
    # Drop bindings that merely map one variable onto another variable.
    s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
    # If applying the remaining bindings leaves ``a`` unchanged, ``a`` is
    # treated as the more specific term.
    if reify(a, s) == a:
        return True
    # Every remaining case (``b`` unchanged, or both terms changed) means
    # ``a`` does not supercede ``b``.
    return False


# Taken from multipledispatch
def edge(a, b, tie_breaker=hash):
    """ A should be checked before B

    True when ``a`` supercedes ``b``.  If each supercedes the other, the
    tie is broken by comparing ``tie_breaker(a)`` with ``tie_breaker(b)``;
    ``tie_breaker`` defaults to ``hash``.

    """
    if not supercedes(a, b):
        return False
    if not supercedes(b, a):
        return True
    # Mutual supercession: fall back to the tie breaker.
    return tie_breaker(a) > tie_breaker(b)


# Taken from multipledispatch
def ordering(signatures):
    """ A sane ordering of signatures to check, first to last

    Builds a graph with an entry ``a -> [b, ...]`` for every pair where
    ``edge(a, b)`` holds, gives every remaining signature an empty entry,
    and returns the result of ``_toposort`` on that graph.

    """
    sigs = [tuple(sig) for sig in signatures]
    pairs = [(src, dst) for src in sigs for dst in sigs if edge(src, dst)]
    # Group the pairs by source and keep only the destinations.
    graph = {
        src: [dst for _, dst in grouped]
        for src, grouped in groupby(first, pairs).items()
    }
    # Signatures with no outgoing edge still need to appear in the graph.
    for sig in sigs:
        if sig not in graph:
            graph[sig] = []
    return _toposort(graph)