import warnings
from itertools import islice
from numbers import Integral

import numpy as np

from .._config import get_config
from ._param_validation import Interval, validate_params


def chunk_generator(gen, chunksize):
    """Chunk generator ``gen`` into lists of length ``chunksize``. The last
    chunk may have a length less than ``chunksize``.
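
    Examples
    --------
    A minimal sketch of the intended use (output worked out by hand):

    >>> list(chunk_generator(iter(range(7)), 3))
    [[0, 1, 2], [3, 4, 5], [6]]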
    """
    while True:
        chunk = list(islice(gen, chunksize))
        if chunk:
            yield chunk
        else:
            return


@validate_params(
    {
        "n": [Interval(Integral, 1, None, closed="left")],
        "batch_size": [Interval(Integral, 1, None, closed="left")],
        "min_batch_size": [Interval(Integral, 0, None, closed="left")],
    },
    prefer_skip_nested_validation=True,
)
def gen_batches(n, batch_size, *, min_batch_size=0):
    """Generator to create slices containing `batch_size` elements from 0 to `n`.

    The last slice may contain fewer than `batch_size` elements, when
    `batch_size` does not divide `n`.

    Parameters
    ----------
    n : int
        Size of the sequence.
    batch_size : int
        Number of elements in each batch.
    min_batch_size : int, default=0
        Minimum number of elements in each batch.

    Yields
    ------
    slice of `batch_size` elements

    See Also
    --------
    gen_even_slices : Generator to create `n_packs` slices going up to `n`.

    Examples
    --------
    >>> from sklearn.utils import gen_batches
    >>> list(gen_batches(7, 3))
    [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
    >>> list(gen_batches(6, 3))
    [slice(0, 3, None), slice(3, 6, None)]
    >>> list(gen_batches(2, 3))
    [slice(0, 2, None)]
    >>> list(gen_batches(7, 3, min_batch_size=0))
    [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
    >>> list(gen_batches(7, 3, min_batch_size=2))
    [slice(0, 3, None), slice(3, 7, None)]
    """
    start = 0
    for _ in range(int(n // batch_size)):
        end = start + batch_size
        if end + min_batch_size > n:
            # Yielding this batch would leave fewer than ``min_batch_size``
            # elements behind; merge the remainder into the final slice
            # below instead.
            continue
        yield slice(start, end)
        start = end
    if start < n:
        yield slice(start, n)


@validate_params(
    {
        "n": [Interval(Integral, 1, None, closed="left")],
        "n_packs": [Interval(Integral, 1, None, closed="left")],
        "n_samples": [Interval(Integral, 1, None, closed="left"), None],
    },
    prefer_skip_nested_validation=True,
)
def gen_even_slices(n, n_packs, *, n_samples=None):
    """Generator to create `n_packs` evenly spaced slices going up to `n`.

    If `n_packs` does not divide `n`, the first `n % n_packs` slices contain
    one more element than the remaining ones.

    Parameters
    ----------
    n : int
        Size of the sequence.
    n_packs : int
        Number of slices to generate.
    n_samples : int, default=None
        Number of samples. Pass `n_samples` when the slices are to be used
        for sparse matrix indexing; slicing off-the-end raises an exception,
        while it works for NumPy arrays.

    Yields
    ------
    `slice` representing a set of indices from 0 to n.

    See Also
    --------
    gen_batches : Generator to create slices containing `batch_size` elements
        from 0 to `n`.

    Examples
    --------
    >>> from sklearn.utils import gen_even_slices
    >>> list(gen_even_slices(10, 1))
    [slice(0, 10, None)]
    >>> list(gen_even_slices(10, 10))
    [slice(0, 1, None), slice(1, 2, None), ..., slice(9, 10, None)]
    >>> list(gen_even_slices(10, 5))
    [slice(0, 2, None), slice(2, 4, None), ..., slice(8, 10, None)]
    >>> list(gen_even_slices(10, 3))
    [slice(0, 4, None), slice(4, 7, None), slice(7, 10, None)]
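
    A further sketch (worked out by hand) showing how `n_samples` clips the
    final slice:

    >>> list(gen_even_slices(10, 3, n_samples=8))
    [slice(0, 4, None), slice(4, 7, None), slice(7, 8, None)]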
    """
    start = 0
    for pack_num in range(n_packs):
        this_n = n // n_packs
        if pack_num < n % n_packs:
            # Distribute the remainder over the first ``n % n_packs`` slices.
            this_n += 1
        if this_n > 0:
            end = start + this_n
            if n_samples is not None:
                # Clip to ``n_samples`` to avoid slicing past the end of a
                # sparse matrix.
                end = min(n_samples, end)
            yield slice(start, end, None)
            start = end


def get_chunk_n_rows(row_bytes, *, max_n_rows=None, working_memory=None):
    """Calculate how many rows can be processed within `working_memory`.

    Parameters
    ----------
    row_bytes : int
        The expected number of bytes of memory that will be consumed
        during the processing of each row.
    max_n_rows : int, default=None
        The maximum return value.
    working_memory : int or float, default=None
        The number of rows to fit inside this number of MiB will be
        returned. When None (default), the value of
        ``sklearn.get_config()['working_memory']`` is used.

    Returns
    -------
    int
        The number of rows which can be processed within `working_memory`.

    Warns
    -----
    Issues a UserWarning if `row_bytes` exceeds `working_memory` MiB.
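
    Examples
    --------
    A hand-checked sketch: with 1 MiB of working memory (1048576 bytes) and
    1024-byte rows, 1024 rows fit, optionally capped by `max_n_rows`:

    >>> get_chunk_n_rows(row_bytes=1024, working_memory=1)
    1024
    >>> get_chunk_n_rows(row_bytes=1024, max_n_rows=10, working_memory=1)
    10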
    """
    if working_memory is None:
        working_memory = get_config()["working_memory"]

    # ``working_memory`` is expressed in MiB; 2**20 converts it to bytes.
    chunk_n_rows = int(working_memory * (2**20) // row_bytes)
    if max_n_rows is not None:
        chunk_n_rows = min(chunk_n_rows, max_n_rows)
    if chunk_n_rows < 1:
        warnings.warn(
            "Could not adhere to working_memory config. "
            "Currently %.0fMiB, %.0fMiB required."
            % (working_memory, np.ceil(row_bytes * 2**-20))
        )
        chunk_n_rows = 1
    return chunk_n_rows