from typing import Dict, Iterator, List, TextIO

import torch
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase  # assumed: tokenizers passed in are Hugging Face tokenizers

def blocks(files: TextIO, size: int = 65536) -> Iterator[str]:
    """Yield successive chunks of ``size`` characters from an open file."""
    while True:
        b = files.read(size)
        if not b:
            break
        yield b

def count_lines(input_path: str) -> int:
    """Count newline characters without loading the whole file into memory."""
    with open(input_path, "r", encoding="utf8") as f:
        return sum(bl.count("\n") for bl in blocks(f))
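
# Usage sketch (illustrative; "data.txt" and batch_size are stand-ins, not
# names from this file): an IterableDataset exposes no __len__, so
# count_lines can supply the total for progress reporting, e.g.
#
#     total_batches = math.ceil(count_lines("data.txt") / batch_size)
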
class DatasetReader(IterableDataset):
    """Stream a text file line by line, tokenizing each line on the fly."""

    def __init__(self, filename: str, tokenizer: PreTrainedTokenizerBase, max_length: int = 128):
        self.filename = filename
        self.tokenizer = tokenizer
        self.max_length = max_length
    def preprocess(self, text: str) -> Dict[str, torch.Tensor]:
        # Tokenize a single line, padded/truncated to a fixed length.
        return self.tokenizer(
            text.strip(),
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        # Stream lines lazily; the with-block closes the file once iteration ends.
        with open(self.filename, "r", encoding="utf8") as f:
            for line in f:
                yield self.preprocess(line)
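
# Note (assumption about multi-process loading, not covered by this file):
# with DataLoader(num_workers > 0), every worker runs DatasetReader.__iter__
# over the whole file, so each sample is emitted once per worker. A common
# sharding sketch inside __iter__ uses torch.utils.data.get_worker_info():
#
#     info = torch.utils.data.get_worker_info()
#     start, step = (0, 1) if info is None else (info.id, info.num_workers)
#     for line in itertools.islice(f, start, None, step):
#         yield self.preprocess(line)
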
def collate_function(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # Each item carries a leading batch dimension of 1 (return_tensors="pt"),
    # so index [0] before stacking into a single batch tensor.
    return {
        "input_ids": torch.stack([item["input_ids"][0] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"][0] for item in batch]),
    }

def get_dataloader(
    filename: str,
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int,
    max_length: int,
) -> torch.utils.data.DataLoader:
    dataset = DatasetReader(filename, tokenizer, max_length)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_function,
    )
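
# Usage sketch (assumptions: the transformers library is installed;
# "bert-base-uncased" and "data.txt" are illustrative stand-ins, not names
# taken from this file).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")
    loader = get_dataloader("data.txt", tok, batch_size=32, max_length=128)
    batch = next(iter(loader))
    # Both tensors have shape (batch_size, max_length).
    print(batch["input_ids"].shape, batch["attention_mask"].shape)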