File size: 619 Bytes
2eafbc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from typing import Generator, Iterable, List, TypeVar, Union

# Generic element type used in the function signatures below; in
# create_batches it ties the element type of the input iterable to the
# element type of the yielded lists.
B = TypeVar("B")


def calculate_input_elements(input_value: Union[B, List[B]]) -> int:
    """Return how many elements *input_value* represents.

    A ``list`` (including subclasses of ``list``) counts as its length; any
    other value -- including tuples, strings and other sequences -- counts
    as a single element.

    Args:
        input_value: Either a single element or a list of elements.

    Returns:
        ``len(input_value)`` when it is a list, otherwise ``1``.
    """
    # isinstance is the idiomatic subclass-aware type check; it replaces the
    # roundabout issubclass(type(...), list).
    return len(input_value) if isinstance(input_value, list) else 1


def create_batches(
    sequence: Iterable[B], batch_size: int
) -> Generator[List[B], None, None]:
    """Split *sequence* into consecutive lists of at most *batch_size* items.

    Elements keep their original order. Every batch except possibly the
    last holds exactly ``batch_size`` elements; the final batch holds
    whatever is left over. An empty *sequence* yields nothing.

    Args:
        sequence: Any iterable; it is consumed lazily, one element at a time.
        batch_size: Maximum number of elements per batch. Values below 1
            are clamped to 1 rather than rejected.

    Yields:
        Lists of elements; each yielded batch is a new list object.
    """
    # Clamp so that every batch contains at least one element.
    batch_size = max(batch_size, 1)
    current_batch: List[B] = []
    for element in sequence:
        current_batch.append(element)
        # Yield as soon as the batch is full instead of waiting for the next
        # element: for lazy or slow iterables this avoids pulling one extra
        # item (and triggering its side effects) before the consumer gets
        # the completed batch.
        if len(current_batch) == batch_size:
            yield current_batch
            current_batch = []
    # Emit the partial tail batch, if any.
    if current_batch:
        yield current_batch