|  | """Module for testing streaming dataset sequence packing""" | 
					
						
						|  | import functools | 
					
						
						|  | import unittest | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  | from datasets import load_dataset | 
					
						
						|  | from torch.utils.data import DataLoader | 
					
						
						|  | from transformers import AutoTokenizer | 
					
						
						|  |  | 
					
						
						|  | from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset | 
					
						
						|  | from axolotl.utils.dict import DictDefault | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						


class TestPretrainingPacking(unittest.TestCase):
    """
    Test class for packing streaming dataset sequences
    """

    def setUp(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
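        # Llama tokenizers ship without a pad token, so reuse EOS ("</s>") for padding.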
        self.tokenizer.pad_token = "</s>"

    def test_packing_stream_dataset(self):
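        # Stream C4 rather than downloading it; the test only consumes a few batches below.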
        dataset = load_dataset(
            "c4",
            "en",
            streaming=True,
        )["train"]
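
        # Minimal packed-pretraining config: 2048-token context with a micro-batch of 2.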
        cfg = DictDefault(
            {
                "pretraining_dataset": [
                    {
                        "path": "c4",
                        "name": "en",
                        "type": "pretrain",
                    }
                ],
                "sample_packing": True,
                "pad_to_sequence_len": True,
                "sequence_len": 2048,
                "micro_batch_size": 2,
            }
        )
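
        # Bind everything except the raw dataset so wrap_pretraining_dataset can
        # finish constructing the tokenized wrapper itself.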
        ds_wrapper_partial = functools.partial(
            get_dataset_wrapper,
            cfg.pretraining_dataset[0],
            self.tokenizer,
            cfg,
            cfg.pretraining_dataset[0]["type"] or "pretrain",
        )
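
        # Packing merges batch_size samples of max_tokens tokens into one flattened
        # row, so keep the pre-packing batch size around for the shape checks.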
        original_bsz = cfg.micro_batch_size
        train_dataset = wrap_pretraining_dataset(
            dataset,
            self.tokenizer,
            cfg,
            ds_wrapper_partial,
            max_tokens=cfg.sequence_len,
            batch_size=cfg.micro_batch_size,
            seed=cfg.seed or 42,
        )
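
        # batch_size=1 at the loader level: each row emitted by the packed dataset
        # already contains a full micro-batch worth of tokens.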
        trainer_loader = DataLoader(
            train_dataset,
            batch_size=1,
            collate_fn=None,
            drop_last=True,
        )
        idx = 0
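        # Spot-check the first batches: every tensor should flatten to
        # [1, original_bsz * sequence_len] once packing merges the micro-batch.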
        for data in trainer_loader:
            if idx > 10:
                break
            assert data["input_ids"].shape == torch.Size(
                [1, original_bsz * cfg.sequence_len]
            )
            assert data["position_ids"].shape == torch.Size(
                [1, original_bsz * cfg.sequence_len]
            )
            assert data["labels"].shape == torch.Size(
                [1, original_bsz * cfg.sequence_len]
            )
            assert data["attention_mask"].shape == torch.Size(
                [1, original_bsz * cfg.sequence_len]
            )
            idx += 1
						|  | if __name__ == "__main__": | 
					
						
						|  | unittest.main() | 
					
						
						|  |  |