from typing import Any

from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding


# Utilities for training: stream whole books, split them into fixed-size
# chunks of token ids, and collate the chunks into padded batches.
class BooksBatcherIter:
    def __init__(self, data_iter, batch_size, tokenizer, chunk_size=1024):
        self.data_iter = data_iter
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        # batch_size chunk generators share data_iter, so each generator pulls
        # its own books and a batch mixes chunks from batch_size different books.
        self.batch_fns = [self._batch_fn() for _ in range(batch_size)]
        self.collate_fn = DataCollatorWithPadding(tokenizer)

    def _batch_fn(self):
        # Split each book (a sequence of token ids) into chunk_size pieces,
        # wrapped as feature dicts so DataCollatorWithPadding can pad them.
        for book in self.data_iter:
            for i in range(0, len(book), self.chunk_size):
                yield {"input_ids": book[i:i + self.chunk_size]}

    def __iter__(self) -> 'BooksBatcherIter':
        return self

    def __next__(self) -> Any:
        # Take one chunk from every generator; when any generator is exhausted,
        # its StopIteration ends this iterator too (a partial batch is dropped).
        batch = []
        for b in self.batch_fns:
            batch.append(next(b))
        return self.collate_fn(batch)


class BooksBatcher:
    def __init__(self, dataset, batch_size, tokenizer) -> None:
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.dataloader = DataLoader(
            dataset=dataset,
            batch_size=None,  # disable automatic batching: yield raw samples (whole books)
            shuffle=True,
            num_workers=2,
            prefetch_factor=4,
        )

    def __iter__(self) -> 'BooksBatcherIter':
        return BooksBatcherIter(iter(self.dataloader), self.batch_size, self.tokenizer)
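

# A minimal usage sketch, not part of the original file: it assumes a
# map-style dataset whose items are raw token-id lists (one per book) and
# uses bert-base-uncased as a stand-in checkpoint; swap in your own
# tokenizer and data.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Toy "books": each dataset item is one book's token ids.
    books = [tokenizer("some book text " * n)["input_ids"] for n in (100, 250, 400)]

    batcher = BooksBatcher(books, batch_size=2, tokenizer=tokenizer)
    for batch in batcher:
        # Each batch is a padded BatchEncoding with input_ids / attention_mask tensors.
        print(batch["input_ids"].shape)
        break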