"""
Data module for Cosmopedia dataset

Author: Shilpaj Bhalerao
Date: 2025-01-20
"""
from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import GPT2Tokenizer

from config import DataConfig


class CosmopediaDataModule(pl.LightningDataModule):
    """
    Data module for Cosmopedia dataset
    """
    def __init__(
        self,
        batch_size: int = DataConfig.batch_size,
        num_workers: int = DataConfig.num_workers,
        shuffle_buffer_size: int = DataConfig.shuffle_buffer_size,
        max_length: int = DataConfig.max_length,
    ):
        """
        Constructor

        :param batch_size: Batch size for dataloaders
        :param num_workers: Number of workers for dataloaders
        :param shuffle_buffer_size: Size of buffer for shuffling streaming data
        :param max_length: Maximum sequence length for tokenized text
        """
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle_buffer_size = shuffle_buffer_size
        self.max_length = max_length

        self.dataset_path = DataConfig.dataset_path
        self.dataset_name = DataConfig.dataset_name

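        # GPT-2 ships without a pad token, so padding reuses the EOS token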
        self.tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def setup(self, stage: Optional[str] = None):
        """
        Setup datasets for training and validation
        """
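        # Load the Cosmopedia split; with DataConfig.streaming=True this returns an
        # IterableDataset and avoids downloading the full corpus up front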
        self.dataset = load_dataset(
            self.dataset_path,
            self.dataset_name,
            split="train",
            streaming=DataConfig.streaming
        )

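        # Approximate shuffling: streamed examples are shuffled within a fixed-size buffer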
        self.dataset = self.dataset.shuffle(buffer_size=self.shuffle_buffer_size)

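        # The streamed dataset has no known length, so the validation size is derived
        # from the shuffle buffer size: take() yields the first val_size examples and
        # skip() drops those same examples for training, keeping the splits disjoint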
        val_size = int(DataConfig.validation_split * self.shuffle_buffer_size)
        self.train_dataset = self.dataset.skip(val_size)
        self.val_dataset = self.dataset.take(val_size)

    def collate_fn(self, batch):
        """
        Tokenize and pad the texts in the batch
        """
        texts = [item['text'] for item in batch]

        encodings = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

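        # Shift by one token for causal language modeling: position t of input_ids
        # is trained to predict position t of labels (the next token in the text)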
        input_ids = encodings['input_ids'][:, :-1]
        labels = encodings['input_ids'][:, 1:]
        attention_mask = encodings['attention_mask'][:, :-1]

        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': attention_mask
        }

    def train_dataloader(self):
        """
        Return train dataloader
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=DataConfig.pin_memory,
            collate_fn=self.collate_fn
        )

    def val_dataloader(self):
        """
        Return validation dataloader
        """
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=DataConfig.pin_memory,
            collate_fn=self.collate_fn
        )
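

if __name__ == "__main__":
    # Minimal usage sketch (not from the original module): assumes DataConfig points
    # at a reachable copy of the Cosmopedia dataset and GPT-2 tokenizer files.
    # num_workers=0 keeps the streamed iteration simple for a quick smoke test.
    dm = CosmopediaDataModule(batch_size=2, num_workers=0)
    dm.setup()
    batch = next(iter(dm.train_dataloader()))
    print(batch['input_ids'].shape, batch['labels'].shape, batch['attention_mask'].shape)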