Spaces:
Runtime error
Runtime error
File size: 619 Bytes
2eafbc4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
from typing import Generator, Iterable, List, TypeVar, Union
# Generic element type; used by calculate_input_elements and create_batches
# so that the element type of their inputs carries through their signatures.
B = TypeVar("B")
def calculate_input_elements(input_value: Union[B, List[B]]) -> int:
    """Return how many elements ``input_value`` represents.

    A ``list`` (or an instance of a ``list`` subclass) counts as its length;
    any other value -- including tuples, strings and other sequences -- counts
    as a single element.

    Args:
        input_value: Either a single value or a list of values.

    Returns:
        ``len(input_value)`` when it is a list, otherwise ``1``.
    """
    # isinstance is the idiomatic type check and, like the previous
    # issubclass(type(x), list) form, also accepts list subclasses.
    return len(input_value) if isinstance(input_value, list) else 1
def create_batches(
    sequence: Iterable[B], batch_size: int
) -> Generator[List[B], None, None]:
    """Split ``sequence`` into consecutive lists of at most ``batch_size`` items.

    Every batch except possibly the last holds exactly ``batch_size`` items.
    The last batch holds whatever remains. An empty ``sequence`` yields
    nothing.

    Args:
        sequence: Any iterable; it is consumed lazily, one element at a time.
        batch_size: Maximum number of items per batch. Values below 1 are
            clamped to 1.

    Yields:
        A new ``list`` for each batch, preserving input order.
    """
    batch_size = max(batch_size, 1)
    current_batch: List[B] = []
    for element in sequence:
        current_batch.append(element)
        # Emit the batch as soon as it is full instead of waiting for the next
        # element: this avoids pulling one extra item from a lazy or blocking
        # source before the caller receives the completed batch.
        if len(current_batch) == batch_size:
            yield current_batch
            current_batch = []
    # Flush the trailing partial batch, if any.
    if current_batch:
        yield current_batch
|