Charm_15 / data_loader.py
import json

from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizerFast


class CustomDataset(Dataset):
    """Tokenize the "text" field of each record on access."""

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["text"]
        inputs = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # The tokenizer returns (1, max_length) tensors; drop the batch
        # dimension so the DataLoader can stack items into (batch, max_length).
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
        }
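

# A minimal sketch, not part of the original file: if these batches feed a
# causal language model, the model also expects a "labels" tensor. The
# subclass name CausalLMDataset is hypothetical; copying input_ids into
# labels follows the usual Hugging Face causal-LM convention (the model
# shifts the labels internally when computing the loss).
class CausalLMDataset(CustomDataset):
    def __getitem__(self, idx):
        item = super().__getitem__(idx)
        item["labels"] = item["input_ids"].clone()
        return item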


class DataLoaderHandler:
    """Build a shuffled DataLoader from a JSON or JSONL text dataset."""

    def __init__(self, dataset_path, tokenizer_path, batch_size=8, max_length=512):
        self.dataset_path = dataset_path
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        # A tokenizer built from a bare tokenizer.json often has no padding
        # token, which makes padding="max_length" raise at call time; register
        # a [PAD] token as a fallback (assumes the downstream model can handle
        # a resized vocabulary if [PAD] was not already in the file).
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.batch_size = batch_size
        self.max_length = max_length

    def load_dataset(self):
        """Return a list of records, each expected to carry a "text" field."""
        if self.dataset_path.endswith(".json"):
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        elif self.dataset_path.endswith(".jsonl"):
            # Read inside a context manager so the file handle is closed
            # promptly; skip blank lines, which json.loads would reject.
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                data = [json.loads(line) for line in f if line.strip()]
        else:
            raise ValueError("Unsupported dataset format. Use JSON or JSONL.")
        return data

    def get_dataloader(self):
        data = self.load_dataset()
        dataset = CustomDataset(data, self.tokenizer, self.max_length)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
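

# Optional alternative, sketched here because the original file imports the
# `datasets` library without using it: the same JSON/JSONL files can be
# loaded through datasets.load_dataset, whose "json" builder handles both
# formats. The function name load_with_hf_datasets is hypothetical.
def load_with_hf_datasets(dataset_path):
    import datasets

    return datasets.load_dataset("json", data_files=dataset_path, split="train")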


if __name__ == "__main__":
    dataset_path = "data/dataset.jsonl"  # Update with actual dataset path
    tokenizer_path = "tokenizer.json"  # Update with actual tokenizer path
    batch_size = 16

    data_loader_handler = DataLoaderHandler(dataset_path, tokenizer_path, batch_size)
    dataloader = data_loader_handler.get_dataloader()

    # Sanity check: print the shapes of the first batch.
    for batch in dataloader:
        print(batch["input_ids"].shape, batch["attention_mask"].shape)
        break
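
# Note: with batch_size=16, max_length=512, and padding="max_length", the
# sanity check above prints torch.Size([16, 512]) torch.Size([16, 512]),
# assuming the dataset contains at least 16 records.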