#!/usr/bin/env python
"""
Data module for Cosmopedia dataset
Author: Shilpaj Bhalerao
Date: 2025-01-20
"""
# Standard Library Imports
from typing import Optional

# Third-Party Imports
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2Tokenizer

# Local Imports
from config import DataConfig


class CosmopediaDataModule(pl.LightningDataModule):
    """
    Data module for Cosmopedia dataset
    """
    def __init__(
        self,
        batch_size: int = DataConfig.batch_size,
        num_workers: int = DataConfig.num_workers,
        shuffle_buffer_size: int = DataConfig.shuffle_buffer_size,
        max_length: int = DataConfig.max_length,
    ):
        """
        Constructor
        :param batch_size: Batch size for dataloaders
        :param num_workers: Number of workers for dataloaders
        :param shuffle_buffer_size: Size of buffer for shuffling streaming data
        :param max_length: Maximum sequence length for tokenized text
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle_buffer_size = shuffle_buffer_size
        self.max_length = max_length
        
        # Dataset path on HuggingFace
        self.dataset_path = DataConfig.dataset_path
        self.dataset_name = DataConfig.dataset_name
        
        # Initialize tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
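        # GPT-2 ships without a dedicated pad token, so reuse EOS for padding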
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def setup(self, stage: Optional[str] = None):
        """
        Set up the training and validation datasets
        """
        # Load dataset in streaming mode
        self.dataset = load_dataset(
            self.dataset_path,
            self.dataset_name,
            split="train",  # Only train split is available
            streaming=DataConfig.streaming
        )
        
        # Size the validation split relative to the shuffle buffer, since
        # a streaming dataset has no known length
        val_size = int(DataConfig.validation_split * self.shuffle_buffer_size)

        # Split before shuffling so the validation examples stay fixed
        # across epochs and cannot leak into the training stream
        self.val_dataset = self.dataset.take(val_size)
        self.train_dataset = self.dataset.skip(val_size).shuffle(
            buffer_size=self.shuffle_buffer_size
        )

    def collate_fn(self, batch):
        """
        Tokenize and pad the texts in the batch
        """
        texts = [item['text'] for item in batch]
        
        # Tokenize all texts in the batch
        encodings = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        # Shift by one position for next-token prediction
        input_ids = encodings['input_ids'][:, :-1]
        labels = encodings['input_ids'][:, 1:].clone()
        attention_mask = encodings['attention_mask'][:, :-1]

        # Skip loss on padding (-100 is the CrossEntropyLoss ignore_index)
        labels[encodings['attention_mask'][:, 1:] == 0] = -100
        
        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask
        }

    def train_dataloader(self):
        """
        Return train dataloader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=DataConfig.pin_memory,
            collate_fn=self.collate_fn
        )

    def val_dataloader(self):
        """
        Return validation dataloader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=DataConfig.pin_memory,
            collate_fn=self.collate_fn
        )
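

if __name__ == "__main__":
    # Minimal smoke test: a sketch of how this module is used, assuming
    # config.py provides the DataConfig fields imported above. Pulls one
    # batch from the streaming dataset and prints the tensor shapes.
    datamodule = CosmopediaDataModule()
    datamodule.setup()

    batch = next(iter(datamodule.train_dataloader()))
    print(f"input_ids:      {batch['input_ids'].shape}")
    print(f"labels:         {batch['labels'].shape}")
    print(f"attention_mask: {batch['attention_mask'].shape}")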