import json

from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizerFast


class CustomDataset(Dataset):
    """Wraps a list of {"text": ...} records and tokenizes them on access."""

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["text"]
        inputs = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # squeeze(0) drops the batch dimension added by return_tensors="pt"
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
        }


class DataLoaderHandler:
    """Loads a JSON/JSONL dataset and exposes it as a shuffled DataLoader."""

    def __init__(self, dataset_path, tokenizer_path, batch_size=8, max_length=512):
        self.dataset_path = dataset_path
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        # A tokenizer loaded from a raw tokenizer.json file may not define a pad
        # token, which padding="max_length" requires; register a generic one.
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.batch_size = batch_size
        self.max_length = max_length

    def load_dataset(self):
        if self.dataset_path.endswith(".json"):
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        elif self.dataset_path.endswith(".jsonl"):
            # Use a context manager so the file handle is closed, and skip blank lines.
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                data = [json.loads(line) for line in f if line.strip()]
        else:
            raise ValueError("Unsupported dataset format. Use JSON or JSONL.")
        return data

    def get_dataloader(self):
        data = self.load_dataset()
        dataset = CustomDataset(data, self.tokenizer, self.max_length)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)


if __name__ == "__main__":
    dataset_path = "data/dataset.jsonl"  # Update with actual dataset path
    tokenizer_path = "tokenizer.json"    # Update with actual tokenizer path
    batch_size = 16

    data_loader_handler = DataLoaderHandler(dataset_path, tokenizer_path, batch_size)
    dataloader = data_loader_handler.get_dataloader()

    for batch in dataloader:
        print(batch["input_ids"].shape, batch["attention_mask"].shape)
        break
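
# --- Optional: dynamic padding sketch -----------------------------------------
# padding="max_length" above pads every example to max_length tokens, which
# wastes compute on short texts. A minimal alternative sketch (not part of the
# handler above; make_dynamic_dataloader and _UnpaddedDataset are hypothetical
# names) tokenizes without padding and lets transformers'
# DataCollatorWithPadding pad each batch to its own longest sequence.

def make_dynamic_dataloader(data, tokenizer, batch_size=8, max_length=512):
    from transformers import DataCollatorWithPadding

    class _UnpaddedDataset(Dataset):
        def __init__(self, records):
            self.records = records

        def __len__(self):
            return len(self.records)

        def __getitem__(self, idx):
            # No padding here; the collator pads per batch instead.
            return tokenizer(
                self.records[idx]["text"], max_length=max_length, truncation=True
            )

    collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")
    return DataLoader(
        _UnpaddedDataset(data),
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
    )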