import json

from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizerFast

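# Wraps a list of records and tokenizes each one on access. Every record is
# expected to be a dict with a "text" field; samples are padded/truncated to
# a fixed max_length so the DataLoader can batch them.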
class CustomDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["text"]
        inputs = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
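        # return_tensors="pt" adds a leading batch dimension of size 1; squeeze(0)
        # drops it so the DataLoader can stack individual samples into batches.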
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0)
        }

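# Loads a JSON or JSONL dataset, tokenizes it with a tokenizer built from a
# tokenizer.json file, and exposes it as a shuffled PyTorch DataLoader.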
class DataLoaderHandler:
    def __init__(self, dataset_path, tokenizer_path, batch_size=8, max_length=512):
        self.dataset_path = dataset_path
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        # padding="max_length" requires a pad token; add a generic one if the
        # tokenizer file does not already define it.
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.batch_size = batch_size
        self.max_length = max_length

    def load_dataset(self):
        # Load records from a JSON array or a JSONL file (one JSON object per line).
        if self.dataset_path.endswith(".json"):
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        elif self.dataset_path.endswith(".jsonl"):
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                data = [json.loads(line) for line in f if line.strip()]
        else:
            raise ValueError("Unsupported dataset format. Use a .json or .jsonl file.")
        return data

    def get_dataloader(self):
        data = self.load_dataset()
        dataset = CustomDataset(data, self.tokenizer, self.max_length)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

if __name__ == "__main__":
    dataset_path = "data/dataset.jsonl"  # Update with actual dataset path
    tokenizer_path = "tokenizer.json"    # Update with actual tokenizer path
    batch_size = 16

    data_loader_handler = DataLoaderHandler(dataset_path, tokenizer_path, batch_size)
    dataloader = data_loader_handler.get_dataloader()

    # Sanity check: print the shapes of the first batch, then stop.
    for batch in dataloader:
        print(batch["input_ids"].shape, batch["attention_mask"].shape)
        break