GeminiFan207 committed
Commit 6b81dd1 · verified · 1 Parent(s): 7d82dd3

Create data_loader.py

Files changed (1)
  1. data_loader.py +63 -0
data_loader.py ADDED
import json

from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizerFast


class CustomDataset(Dataset):
    """Wraps a list of records and tokenizes the "text" field on access."""

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["text"]
        inputs = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # squeeze(0) drops the batch dimension the tokenizer adds, so the
        # DataLoader's default collation can stack samples into batches.
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0)
        }


class DataLoaderHandler:
    """Loads a JSON/JSONL dataset and exposes it as a PyTorch DataLoader."""

    def __init__(self, dataset_path, tokenizer_path, batch_size=8, max_length=512):
        self.dataset_path = dataset_path
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        # A bare tokenizer.json often defines no pad token, and
        # padding="max_length" raises without one; register a conventional
        # [PAD] token in that case.
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.batch_size = batch_size
        self.max_length = max_length

    def load_dataset(self):
        if self.dataset_path.endswith(".json"):
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        elif self.dataset_path.endswith(".jsonl"):
            # JSONL: one JSON object per line; skip blank lines.
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                data = [json.loads(line) for line in f if line.strip()]
        else:
            raise ValueError("Unsupported dataset format. Use JSON or JSONL.")
        return data

    def get_dataloader(self):
        data = self.load_dataset()
        dataset = CustomDataset(data, self.tokenizer, self.max_length)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)


if __name__ == "__main__":
    dataset_path = "data/dataset.jsonl"  # Update with actual dataset path
    tokenizer_path = "tokenizer.json"  # Update with actual tokenizer path
    batch_size = 16

    data_loader_handler = DataLoaderHandler(dataset_path, tokenizer_path, batch_size)
    dataloader = data_loader_handler.get_dataloader()

    # Sanity check: print tensor shapes for the first batch only.
    for batch in dataloader:
        print(batch["input_ids"].shape, batch["attention_mask"].shape)
        break
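
Each record in the dataset file only needs a "text" field, since that is all CustomDataset.__getitem__ reads; any extra keys are ignored. As a sketch of the JSONL layout the loader expects (the sentences below are illustrative, not taken from the commit):

    {"text": "The quick brown fox jumps over the lazy dog."}
    {"text": "A second training example goes on its own line."}

For the .json branch, the same records go in a single top-level array instead, since load_dataset returns whatever json.load produces and the dataset indexes it as a list of dicts.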