File size: 6,717 Bytes
d1ed09d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
from __future__ import annotations

import heapq
import itertools
import weakref
from collections import OrderedDict, UserDict
from collections.abc import Callable, Hashable, Iterator
from typing import MutableSet  # TODO move to collections.abc (requires Python >=3.9)
from typing import Any, TypeVar, cast

# Element type for the generic containers in this module; elements must be hashable.
T = TypeVar("T", bound=Hashable)


# TODO change to UserDict[K, V] (requires Python >=3.9)
class LRU(UserDict):
    """Limited size mapping, evicting the least recently looked-up key when full

    Parameters
    ----------
    maxsize: float
        Number of keys the mapping may hold. Inserting a *new* key while the mapping
        already holds this many keys evicts the least recently looked-up one first.

    Note
    ----
    Overwriting the value of a key that is already present never evicts anything and
    does not change that key's position in the eviction order.
    """

    def __init__(self, maxsize: float):
        super().__init__()
        # Swap UserDict's plain dict for an OrderedDict, which offers the
        # move_to_end() / popitem(last=False) operations used below.
        self.data = OrderedDict()
        self.maxsize = maxsize

    def __getitem__(self, key):
        """Return the value stored under ``key`` and mark the key as the most
        recently looked-up one.
        """
        value = super().__getitem__(key)
        cast(OrderedDict, self.data).move_to_end(key)
        return value

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``. If ``key`` is new and the mapping is full,
        the least recently looked-up key is evicted first.
        """
        # Replacing the value of an existing key does not grow the mapping, so
        # evicting in that case would needlessly drop an unrelated entry.
        if key not in self.data and len(self) >= self.maxsize:
            cast(OrderedDict, self.data).popitem(last=False)
        super().__setitem__(key, value)


class HeapSet(MutableSet[T]):
    """Set-like collection whose `pop` method hands back the smallest element, where
    "smallest" is decided by an arbitrary key function. When two elements share the
    same key, the one that was added first wins.

    Values must be compatible with :mod:`weakref`.

    Parameters
    ----------
    key: Callable
        Function that receives a single element of the collection and returns its
        sorting key. The key itself needs to be neither hashable nor compatible with
        :mod:`weakref`.

    Note
    ----
    The key computed for an element should not change over time. If it does, the
    element's position in the heap is not updated, not even when the element is added
    again, and it *may* remain unchanged even when the element is discarded and added
    again at a later time.
    """

    __slots__ = ("key", "_data", "_heap", "_inc", "_sorted")
    key: Callable[[T], Any]
    # Elements currently in the collection
    _data: set[T]
    # (sorting key, insertion counter, weak reference to the element) entries.
    # May contain stale entries for elements no longer in _data.
    _heap: list[tuple[Any, int, weakref.ref[T]]]
    # Insertion counter, used as tie-breaker among equal sorting keys
    _inc: int
    # True if _heap is known to be fully sorted
    _sorted: bool

    def __init__(self, *, key: Callable[[T], Any]):
        self._heap = []
        self._data = set()
        self._sorted = True
        self._inc = 0
        self.key = key

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"<{cls_name}: {len(self)} items>"

    def __reduce__(self) -> tuple[Callable, tuple]:
        # Weak references can't be pickled: swap them for the elements they point to,
        # leaving out the entries whose element is no longer in the collection.
        live = []
        for k, i, vref in self._heap:
            v = vref()
            if v in self._data:
                live.append((k, i, v))
        return HeapSet._unpickle, (self.key, self._inc, live)

    @staticmethod
    def _unpickle(
        key: Callable[[T], Any], inc: int, heap: list[tuple[Any, int, T]]
    ) -> HeapSet[T]:
        """Rebuild a HeapSet from the arguments produced by ``__reduce__``"""
        obj = object.__new__(HeapSet)
        obj.key = key
        obj._inc = inc
        obj._data = set()
        obj._heap = []
        for k, i, v in heap:
            obj._data.add(v)
            obj._heap.append((k, i, weakref.ref(v)))
        # Entries may have been filtered out when pickling; restore the heap invariant
        heapq.heapify(obj._heap)
        # A heapified list is not necessarily sorted, unless it's empty
        obj._sorted = not obj._heap
        return obj

    def __contains__(self, value: object) -> bool:
        """Return True if ``value`` is currently in the collection"""
        return value in self._data

    def __len__(self) -> int:
        """Return the number of elements currently in the collection"""
        return len(self._data)

    def add(self, value: T) -> None:
        """Add ``value`` to the collection. Do nothing if it's already there."""
        if value in self._data:
            return
        entry = (self.key(value), self._inc, weakref.ref(value))
        heapq.heappush(self._heap, entry)
        self._sorted = False
        self._data.add(value)
        self._inc += 1

    def discard(self, value: T) -> None:
        """Remove ``value`` from the collection, if present. Its heap entry is left
        behind and skipped later, unless the collection becomes empty.
        """
        self._data.discard(value)
        if len(self._data) == 0:
            self.clear()

    def peek(self) -> T:
        """Return the smallest element without removing it"""
        if not self._data:
            raise KeyError("peek into empty set")
        heap = self._heap
        candidate = heap[0][2]()
        # Drop stale entries off the top of the heap until a live element shows up
        while candidate not in self._data:
            heapq.heappop(heap)
            self._sorted = False
            candidate = heap[0][2]()
        return candidate

    def peekn(self, n: int) -> Iterator[T]:
        """Iterate over the n smallest elements without removing them.
        This is O(1) for n == 1; O(n*logn) otherwise.
        """
        if n <= 0 or not self:
            return  # empty iterator
        if n == 1:
            yield self.peek()
            return
        # NOTE: we could pop N items off the queue, then push them back.
        # But copying the list N times is probably slower than just sorting it
        # with fast C code.
        # If we had a `heappop` that sliced the list instead of popping from it,
        # we could implement an optimized version for small `n`s.
        yield from itertools.islice(self.sorted(), n)

    def pop(self) -> T:
        """Remove and return the smallest element"""
        if not self._data:
            raise KeyError("pop from an empty set")
        while True:
            vref = heapq.heappop(self._heap)[2]
            self._sorted = False
            value = vref()
            if value not in self._data:
                continue  # stale entry
            self._data.discard(value)
            if not self._data:
                self.clear()
            return value

    def peekright(self) -> T:
        """Return one of the largest elements (not necessarily the largest!) without
        removing it. It's guaranteed that ``self.peekright() >= self.peek()``.
        """
        if not self._data:
            raise KeyError("peek into empty set")
        heap = self._heap
        # Trim stale entries off the tail of the heap
        while (value := heap[-1][2]()) not in self._data:
            heap.pop()
        return value

    def popright(self) -> T:
        """Remove and return one of the largest elements (not necessarily the largest!)
        It's guaranteed that ``self.popright() >= self.peek()``.
        """
        if not self._data:
            raise KeyError("pop from an empty set")
        while True:
            value = self._heap.pop()[2]()
            if value not in self._data:
                continue  # stale entry
            self._data.discard(value)
            if not self._data:
                self.clear()
            return value

    def __iter__(self) -> Iterator[T]:
        """Iterate over all elements. This is a O(n) operation which returns the
        elements in pseudo-random order.
        """
        return iter(self._data)

    def sorted(self) -> Iterator[T]:
        """Iterate over all elements. This is a O(n*logn) operation which returns the
        elements in order, from smallest to largest according to the key and insertion
        order.
        """
        if not self._sorted:
            self._heap.sort()  # A sorted list maintains the heap invariant
            self._sorted = True
        # The heap may hold more than one entry for the same element
        emitted: set[T] = set()
        for entry in self._heap:
            item = entry[2]()
            if item not in self._data or item in emitted:
                continue
            yield item
            emitted.add(item)

    def clear(self) -> None:
        """Remove all elements and all heap entries"""
        self._heap.clear()
        self._data.clear()
        self._sorted = True