File size: 3,007 Bytes
c61ccee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Public names exported by ``from <this module> import *``.
__all__ = [
    "hashable",
    "transitive_get",
    "raises",
    "reverse_dict",
    "xfail",
    "freeze",
]


def hashable(x):
    """Return True if ``hash(x)`` succeeds, False if it raises TypeError.

    Only ``TypeError`` is interpreted as "unhashable"; any other exception
    raised by a custom ``__hash__`` propagates to the caller.
    """
    try:
        hash(x)
    except TypeError:
        return False
    else:
        return True


def transitive_get(key, d):
    """Transitive ``dict.get``: follow ``key -> d[key] -> d[d[key]] -> ...``.

    Lookups continue while the current value is hashable and present in
    ``d``; the first value that is unhashable or absent from ``d`` is
    returned. If ``key`` itself is not in ``d`` it is returned unchanged.

    Note: a cycle in ``d`` (e.g. ``{1: 2, 2: 1}``) makes this loop forever.

    >>> d = {1: 2, 2: 3, 3: 4}
    >>> d.get(1)
    2
    >>> transitive_get(1, d)
    4
    """
    # The hashable() guard matters because, for a dict ``d``, ``key in d``
    # raises TypeError on unhashable keys (e.g. lists); such a value simply
    # ends the chain instead.
    while hashable(key) and key in d:
        key = d[key]
    return key


def raises(err, lamda):
    """Report whether calling ``lamda()`` raises an exception matching ``err``.

    ``err`` is anything accepted by an ``except`` clause (an exception class
    or a tuple of them). Exceptions that do not match ``err`` propagate.

    Returns True if a matching exception was raised, False if ``lamda()``
    returned normally.
    """
    outcome = False
    try:
        lamda()
    except err:
        outcome = True
    return outcome


# Taken from theano/theano/gof/sched.py
# Avoids licensing issues because this was written by Matthew Rocklin
def _toposort(edges):
    """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)

    inputs:

        edges - a dict of the form {a: {b, c}} where b and c depend on a

    outputs:

        L - an ordered list of nodes that satisfy the dependencies of edges

    >>> # xdoctest: +SKIP

    >>> _toposort({1: (2, 3), 2: (3, )})

    [1, 2, 3]

    Closely follows the wikipedia page [2]

    [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",

    Communications of the ACM

    [2] http://en.wikipedia.org/wiki/Toposort#Algorithms

    """
    incoming_edges = reverse_dict(edges)
    incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
    S = ({v for v in edges if v not in incoming_edges})
    L = []

    while S:
        n = S.pop()
        L.append(n)
        for m in edges.get(n, ()):
            assert n in incoming_edges[m]
            incoming_edges[m].remove(n)
            if not incoming_edges[m]:
                S.add(m)
    if any(incoming_edges.get(v, None) for v in edges):
        raise ValueError("Input has cycles")
    return L


def reverse_dict(d):
    """Reverses direction of dependence dict.

    Every value ``v`` found in the iterable ``d[k]`` becomes a key of the
    result, mapped to a tuple of all keys ``k`` whose iterable contained it.

    >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
    >>> reverse_dict(d)
    {1: ('a',), 2: ('a', 'b'), 3: ('b',)}

    :note: the output follows the iteration order of ``d``: result keys
        appear in first-seen order, and each tuple lists source keys in
        the order ``d`` yields them. A value repeated inside one ``d[k]``
        contributes ``k`` once per occurrence.
    """
    # Accumulate into lists (amortised O(1) append) and convert once at the
    # end; repeated ``tuple + (key,)`` would copy the whole tuple every step.
    grouped: dict = {}
    for key in d:
        for val in d[key]:
            grouped.setdefault(val, []).append(key)
    return {val: tuple(keys) for val, keys in grouped.items()}


def xfail(func):
    """Run ``func`` as an expected failure.

    ``func`` is called immediately with no arguments. If it raises any
    ``Exception``, that is the expected outcome and ``None`` is returned.
    If it returns normally, ``Exception("XFailed test passed")`` is raised.
    """
    try:
        func()
    except Exception:
        # Expected outcome: the wrapped test is known to fail.
        return None
    # Raised outside the ``try`` so the handler above cannot swallow it.
    raise Exception("XFailed test passed")  # pragma:nocover


def freeze(d):
    """Freeze container to hashable form.

    Recursively converts ``dict`` to a ``frozenset`` of frozen
    ``(key, value)`` pairs, ``set`` to ``frozenset``, and ``list``/``tuple``
    to ``tuple``. Any other object is returned unchanged.

    >>> freeze(1)
    1
    >>> freeze([1, 2])
    (1, 2)
    >>> freeze({1: 2})
    frozenset({(1, 2)})
    """
    if isinstance(d, dict):
        # Each item is a (key, value) tuple, so freezing it also freezes value.
        return frozenset(map(freeze, d.items()))
    if isinstance(d, set):
        return frozenset(map(freeze, d))
    if isinstance(d, (tuple, list)):
        return tuple(map(freeze, d))
    return d