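"""Load a JSON or JSONL text dataset, tokenize it with a Hugging Face fast
tokenizer, and serve fixed-length batches through a PyTorch DataLoader."""
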
import json

from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizerFast


class CustomDataset(Dataset):
    """Map-style dataset that tokenizes one text record per example."""

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["text"]
        # Pad/truncate every example to the same fixed length so the default
        # collate function can stack them into a batch.
        inputs = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dimension of 1; squeeze it
        # so the DataLoader yields (batch_size, max_length) tensors.
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
        }


class DataLoaderHandler:
    """Builds a tokenizer and wraps a JSON/JSONL file in a DataLoader."""

    def __init__(self, dataset_path, tokenizer_path, batch_size=8, max_length=512):
        self.dataset_path = dataset_path
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        # A tokenizer loaded from a bare tokenizer.json has no pad token set,
        # and padding="max_length" raises without one. "[PAD]" is an assumed
        # default; use the pad token your tokenizer was actually trained with.
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.batch_size = batch_size
        self.max_length = max_length

    def load_dataset(self):
        if self.dataset_path.endswith(".json"):
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        elif self.dataset_path.endswith(".jsonl"):
            # JSONL: one JSON object per line. Use a context manager so the
            # file handle is closed, and skip blank lines.
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                data = [json.loads(line) for line in f if line.strip()]
        else:
            raise ValueError("Unsupported dataset format. Use JSON or JSONL.")
        return data
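    # Note: the Hugging Face `datasets` library can load the same files with
    # datasets.load_dataset("json", data_files=...), which handles both JSON
    # and JSONL; the manual parsing above avoids that extra dependency.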
    def get_dataloader(self):
        data = self.load_dataset()
        dataset = CustomDataset(data, self.tokenizer, self.max_length)
        # shuffle=True reorders examples each epoch, the usual choice for
        # training; use shuffle=False for evaluation.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
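    # A sketch of an alternative (not part of the original design): pad each
    # batch only to its longest sequence via a collate_fn instead of always
    # padding to max_length in __getitem__; this saves compute on short texts.
    def get_dynamic_dataloader(self):
        def collate(items):
            # Tokenize the whole batch at once; padding=True pads to the
            # longest sequence in this batch rather than to max_length.
            return self.tokenizer(
                [item["text"] for item in items],
                max_length=self.max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )

        # A plain list works as a map-style dataset; collate_fn receives the
        # raw dicts and produces the padded batch tensors.
        data = self.load_dataset()
        return DataLoader(data, batch_size=self.batch_size, shuffle=True,
                          collate_fn=collate)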


if __name__ == "__main__":
    dataset_path = "data/dataset.jsonl"
    tokenizer_path = "tokenizer.json"
    batch_size = 16

    data_loader_handler = DataLoaderHandler(dataset_path, tokenizer_path, batch_size)
    dataloader = data_loader_handler.get_dataloader()

    # Sanity check: both tensors should be (batch_size, max_length).
    for batch in dataloader:
        print(batch["input_ids"].shape, batch["attention_mask"].shape)
        break