Spaces:
Running
Running
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A key-value[] store that implements reservoir sampling on the values.""" | |
import collections | |
import random | |
import threading | |
class Reservoir:
    """A map-to-arrays container, with deterministic Reservoir Sampling.

    Items are added with an associated key. Items may be retrieved by key, and
    a list of keys can also be retrieved. If size is not zero, then it dictates
    the maximum number of items that will be stored with each key. Once there
    are more items for a given key, they are replaced via reservoir sampling,
    such that each item has an equal probability of being included in the
    sample.

    Deterministic means that for any given seed and bucket size, the sequence
    of values that are kept for any given tag will always be the same, and that
    this is independent of any insertions on other tags. That is, after::

        separate_reservoir = reservoir.Reservoir(10)
        interleaved_reservoir = reservoir.Reservoir(10)
        for i in range(100):
            separate_reservoir.AddItem('key1', i)
        for i in range(100):
            separate_reservoir.AddItem('key2', i)
        for i in range(100):
            interleaved_reservoir.AddItem('key1', i)
            interleaved_reservoir.AddItem('key2', i)

    separate_reservoir and interleaved_reservoir will be in identical states.

    See: https://en.wikipedia.org/wiki/Reservoir_sampling

    Adding items has amortized O(1) runtime.

    Fields:
      always_keep_last: Whether the latest seen sample is always at the
        end of the reservoir. Defaults to True.
      size: An integer of the maximum number of samples.
    """

    def __init__(self, size, seed=0, always_keep_last=True):
        """Creates a new reservoir.

        Args:
          size: The number of values to keep in the reservoir for each tag. If
            0, all values will be kept.
          seed: The seed of the random number generator to use when sampling.
            Different values for |seed| will produce different samples from the
            same input items.
          always_keep_last: Whether to always keep the latest seen item in the
            end of the reservoir. Defaults to True.

        Raises:
          ValueError: If size is negative or not an integer.
        """
        if size < 0 or size != round(size):
            raise ValueError("size must be nonnegative integer, was %s" % size)
        # Each new key gets its own bucket with its own identically seeded RNG,
        # so the sample kept for one key never depends on other keys.
        self._buckets = collections.defaultdict(
            lambda: _ReservoirBucket(
                size, random.Random(seed), always_keep_last
            )
        )
        # _mutex guards the key map (creating new keys, retrieving by key,
        # iterating over buckets). The items inside each bucket are guarded by
        # that bucket's own internal mutex.
        self._mutex = threading.Lock()
        self.size = size
        self.always_keep_last = always_keep_last

    def Keys(self):
        """Return all the keys in the reservoir.

        Returns:
          ['list', 'of', 'keys'] in the Reservoir.
        """
        with self._mutex:
            return list(self._buckets)

    def Items(self, key):
        """Return items associated with given key.

        Args:
          key: The key for which we are finding associated items.

        Raises:
          KeyError: If the key is not found in the reservoir.

        Returns:
          [list, of, items] associated with that key.
        """
        with self._mutex:
            # Membership test (not indexing) so a missing key is not created
            # by the defaultdict.
            if key not in self._buckets:
                # Format a 1-tuple so tuple keys produce this KeyError rather
                # than a TypeError from the % operator.
                raise KeyError("Key %s was not found in Reservoir" % (key,))
            bucket = self._buckets[key]
        # The bucket's own mutex protects its contents.
        return bucket.Items()

    def AddItem(self, key, item, f=lambda x: x):
        """Add a new item to the Reservoir with the given tag.

        If the reservoir has not yet reached full size, the new item is
        guaranteed to be added. If the reservoir is full, then behavior depends
        on the always_keep_last boolean.

        If always_keep_last was set to true, the new item is guaranteed to be
        added to the reservoir, and either the previous last item will be
        replaced, or (with low probability) an older item will be replaced.

        If always_keep_last was set to false, then the new item will replace an
        old item with low probability.

        If f is provided, it will be applied to transform item (lazily, iff
        item is going to be included in the reservoir).

        Args:
          key: The key to store the item under.
          item: The item to add to the reservoir.
          f: An optional function to transform the item prior to addition.
        """
        with self._mutex:
            # Indexing the defaultdict creates the bucket on first use.
            bucket = self._buckets[key]
        # The bucket serializes its own mutations, so the key-map lock does not
        # need to be held while sampling.
        bucket.AddItem(item, f)

    def FilterItems(self, filterFn, key=None):
        """Filter items within a Reservoir, using a filtering function.

        Args:
          filterFn: A function that returns True for the items to be kept.
          key: An optional bucket key to filter. If None (the default), all
            buckets are filtered. Any other value, including falsy keys such as
            "" or 0, filters only that bucket.

        Returns:
          The number of items removed.
        """
        with self._mutex:
            if key is not None:
                if key in self._buckets:
                    return self._buckets[key].FilterItems(filterFn)
                return 0
            # Keep the lock while iterating so no bucket is added mid-loop.
            return sum(
                bucket.FilterItems(filterFn)
                for bucket in self._buckets.values()
            )
class _ReservoirBucket:
    """Stores a sample of a stream's items using reservoir sampling.

    Every incoming item is kept until the bucket holds ``_max_size`` items (a
    ``_max_size`` of zero means the bucket never fills). After that, each
    incoming item replaces a uniformly chosen stored item with probability
    ``_max_size / (items seen so far + 1)``.

    An item that loses that draw is handled according to ``always_keep_last``:

    * True: the item overwrites the final slot, so the most recent item is
      always the last one stored.
    * False: the item is discarded.
    """

    def __init__(self, _max_size, _random=None, always_keep_last=True):
        """Create the _ReservoirBucket.

        Args:
          _max_size: Maximum number of items retained. Zero means unbounded.
          _random: An object providing ``randint`` (e.g. ``random.Random``)
            used for the sampling draws. A fresh ``random.Random(0)`` is used
            when omitted.
          always_keep_last: If True, the newest item always occupies the final
            slot of the bucket.

        Raises:
          ValueError: if ``_max_size`` is negative or not integral.
        """
        if _max_size < 0 or _max_size != round(_max_size):
            raise ValueError(
                "_max_size must be nonnegative int, was %s" % _max_size
            )
        self._max_size = _max_size
        self.always_keep_last = always_keep_last
        self._random = random.Random(0) if _random is None else _random
        # Count of items offered via AddItem (rescaled by FilterItems).
        self._num_items_seen = 0
        self.items = []
        # Serializes AddItem, FilterItems and Items on this bucket.
        self._mutex = threading.Lock()

    def AddItem(self, item, f=lambda x: x):
        """Offer one item to the bucket, evicting a stored item if needed.

        Below capacity (or when unbounded), ``f(item)`` is always appended.

        At capacity, a value ``slot`` is drawn uniformly from
        ``[0, _num_items_seen]``:

        * If ``slot < _max_size``, the stored item at ``slot`` is removed and
          ``f(item)`` is appended.
        * Otherwise, if ``always_keep_last`` is set, ``f(item)`` overwrites the
          last slot.
        * Otherwise the item is dropped and ``f`` is never called.

        The O(n) removal only happens with probability
        ``_max_size / (_num_items_seen + 1)``, which is why the amortized cost
        stays low.

        Args:
          item: The item to offer.
          f: Transform applied to ``item`` only when it is actually stored.
        """
        with self._mutex:
            capacity = self._max_size
            stored = self.items
            if capacity == 0 or len(stored) < capacity:
                stored.append(f(item))
            else:
                # randint is inclusive on both ends, so the draw spans
                # seen + 1 values: one per item including this new one.
                slot = self._random.randint(0, self._num_items_seen)
                if slot < capacity:
                    stored.pop(slot)
                    stored.append(f(item))
                elif self.always_keep_last:
                    stored[-1] = f(item)
            self._num_items_seen += 1

    def FilterItems(self, filterFn):
        """Drop stored items for which ``filterFn`` is falsy.

        The bucket only remembers the items that survived sampling, not every
        item it has seen. The seen-counter that drives the replacement rate is
        therefore rescaled by the fraction of stored items that passed the
        filter. This estimates how many of all seen items would have passed.

        Args:
          filterFn: A function that returns True for items to be kept.

        Returns:
          The number of items removed from the bucket.
        """
        with self._mutex:
            count_before = len(self.items)
            self.items = list(filter(filterFn, self.items))
            count_after = len(self.items)
            if count_before > 0:
                survival_ratio = count_after / float(count_before)
            else:
                survival_ratio = 0
            self._num_items_seen = int(
                round(self._num_items_seen * survival_ratio)
            )
            return count_before - count_after

    def Items(self):
        """Return a shallow copy of the items currently held in the bucket."""
        with self._mutex:
            snapshot = list(self.items)
        return snapshot