# NOTE(review): the three lines that were here ("Spaces:", "Runtime error",
# "Runtime error") are extraction artifacts, not Python source; they are
# preserved in this comment so the module parses.
| import unittest | |
| from dataclasses import dataclass, is_dataclass | |
| import torch | |
| from torch.utils.data import DataLoader, Dataset | |
| from transformers import AutoTokenizer | |
| from trlx.pipeline import MiniBatchIterator | |
| from trlx.pipeline.offline_pipeline import ( | |
| ILQLRolloutStorage, | |
| ILQLSeq2SeqRolloutStorage, | |
| PromptPipeline, | |
| ) | |
@dataclass
class DataclassBatch:
    """Container for one batch of rollout tensors.

    The ``@dataclass`` decorator is required: the tests build instances via
    ``DataclassBatch(query_tensors=..., ...)`` (and positionally in
    ``test_batch``) and assert ``is_dataclass(batch)`` — without the
    decorator the class has no generated ``__init__`` and instantiation
    raises ``TypeError``.
    """

    # All five fields are expected to be tensors with a shared leading
    # (batch) dimension; the iterator tests slice them along dim 0.
    query_tensors: torch.Tensor
    response_tensors: torch.Tensor
    logprobs: torch.Tensor
    values: torch.Tensor
    rewards: torch.Tensor
class DummyDataset(Dataset, DataclassBatch):
    """Synthetic map-style dataset of random rollout tensors for the tests."""

    # Fields produced per sample: two 64-dim "sequence" tensors and three
    # scalar-per-sample statistics (logprobs, values, rewards).
    _FIELDS = ("query_tensors", "response_tensors", "logprobs", "values", "rewards")

    def __init__(self, num_samples):
        self.query_tensors = torch.randn(num_samples, 64)
        self.response_tensors = torch.randn(num_samples, 64)
        self.logprobs = torch.randn(num_samples, 1)
        self.values = torch.randn(num_samples, 1)
        self.rewards = torch.randn(num_samples, 1)

    def __len__(self):
        # Leading dimension of any field is the sample count.
        return self.query_tensors.size(0)

    def __getitem__(self, idx) -> DataclassBatch:
        # Slice every field at the same index and re-wrap in the batch type.
        return DataclassBatch(**{name: getattr(self, name)[idx] for name in self._FIELDS})
def collate_fn(batch):
    """Collate a list of per-sample DataclassBatch items into one batched
    DataclassBatch by stacking each field along a new leading dimension."""

    def stacked(field):
        return torch.stack([getattr(sample, field) for sample in batch])

    return DataclassBatch(
        query_tensors=stacked("query_tensors"),
        response_tensors=stacked("response_tensors"),
        logprobs=stacked("logprobs"),
        values=stacked("values"),
        rewards=stacked("rewards"),
    )
class BaseTestMiniBatchIterator(unittest.TestCase):
    """Shared assertion helpers for the MiniBatchIterator test cases."""

    def check_mini_batch(self, mb, expected_mini_batch_size):
        """Assert every tensor in ``mb`` has the expected leading dimension.

        ``mb`` may be a mapping or a dataclass instance; dataclasses are
        unwrapped via ``__dict__`` so both shapes are checked identically.
        """
        fields = mb.__dict__ if is_dataclass(mb) else mb
        for _, tensor in fields.items():
            self.assertEqual(tensor.size(0), expected_mini_batch_size)
class TestMiniBatchDL(BaseTestMiniBatchIterator):
    """Behavioral tests for MiniBatchIterator over a plain DataLoader."""

    @staticmethod
    def _loader(num_samples, shuffle=True):
        # Every test uses the same loader shape: batches of 8, single-process
        # loading, and the dataclass-aware collate function.
        return DataLoader(
            DummyDataset(num_samples),
            batch_size=8,
            shuffle=shuffle,
            num_workers=0,
            collate_fn=collate_fn,
        )

    def _assert_tensor_batch(self, mb):
        # A minibatch must be a DataclassBatch whose fields are all tensors.
        self.assertIsInstance(mb, DataclassBatch)
        for value in vars(mb).values():
            self.assertIsInstance(value, torch.Tensor)

    def test_batch(self):
        batch = DataclassBatch(
            torch.tensor([1]), torch.tensor([2]), torch.tensor([3]), torch.tensor([4]), torch.tensor([5])
        )
        self.assertTrue(is_dataclass(batch))
        for value in vars(batch).values():
            self.assertIsInstance(value, torch.Tensor)

    def test_minibatch_iterator(self):
        # 32 samples / batch 8 / mb 4 -> every minibatch holds 4 samples.
        iterator = MiniBatchIterator(self._loader(32), mb_size=4, num_mb=2)
        for minibatches in iterator:
            for mb in minibatches:
                self._assert_tensor_batch(mb)
                self.check_mini_batch(mb, 4)

    def test_minibatch_iterator_with_undivisible_mbsize(self):
        # Each 8-sample batch splits 3 + 3 + 2: the tail minibatch is short.
        iterator = MiniBatchIterator(self._loader(32), mb_size=3, num_mb=3)
        for minibatches in iterator:
            *full, last = minibatches
            for mb in full:
                self._assert_tensor_batch(mb)
                self.check_mini_batch(mb, 3)
            self._assert_tensor_batch(last)
            self.check_mini_batch(last, 2)

    def test_minibatch_iterator_with_remainder(self):
        # 36 samples -> four full 8-sample batches plus a trailing 4-sample one.
        iterator = MiniBatchIterator(self._loader(36), mb_size=2, num_mb=4)
        for _ in range(4):
            for mb in next(iterator)[:-1]:
                self._assert_tensor_batch(mb)
                self.check_mini_batch(mb, 2)
        # The final batch only has 4 samples -> just 2 minibatches.
        leftovers = next(iterator)
        self.assertEqual(len(leftovers), 2)
        for mb in leftovers:
            self._assert_tensor_batch(mb)
            self.check_mini_batch(mb, 2)

    def test_minibatch_iterator_with_smaller_dataset(self):
        # Dataset smaller than the batch size: one short batch, then exhaustion.
        iterator = MiniBatchIterator(self._loader(6), mb_size=2, num_mb=4)
        for mb in next(iterator):
            self._assert_tensor_batch(mb)
        with self.assertRaises(StopIteration):
            next(iterator)

    def test_minibatch_content(self):
        # shuffle=False so minibatch k must equal dataset rows [k*mb : (k+1)*mb].
        dataset = DummyDataset(32)
        loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn)
        iterator = MiniBatchIterator(loader, mb_size=4, num_mb=2)
        seen = 0
        for minibatches in iterator:
            for mb in minibatches:
                for field, chunk in vars(mb).items():
                    start = seen * chunk.size(0)
                    expected = getattr(dataset, field)[start : start + chunk.size(0)]
                    self.assertTrue(torch.equal(chunk, expected))
                seen += 1
        # The iterator must have covered every sample exactly once.
        self.assertEqual(seen * iterator.mb_size, len(dataset))
class TestMiniBatchIteratorWithPromptPipeline(BaseTestMiniBatchIterator):
    """MiniBatchIterator over tokenized prompt batches (mapping-style batches)."""

    def test_minibatch_iterator_with_prompt_pipeline(self):
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        # 32 identical prompts -> four 8-sample batches of token tensors.
        pipeline = PromptPipeline(
            ["This is a test prompt."] * 32, max_prompt_length=20, tokenizer=tokenizer
        )
        loader = pipeline.create_loader(batch_size=8, shuffle=True)
        iterator = MiniBatchIterator(loader, mb_size=4, num_mb=2)
        for minibatches in iterator:
            for mb in minibatches:
                for key in ("input_ids", "attention_mask"):
                    self.assertIn(key, mb)
                    self.assertIsInstance(mb[key], torch.Tensor)
                self.check_mini_batch(mb, 4)
class TestMiniBatchIteratorWithILQLRollouts(BaseTestMiniBatchIterator):
    """MiniBatchIterator over ILQL rollout storages (causal and seq2seq)."""

    def create_dummy_tensors(self, num_samples):
        """Return the six tensors an ILQLRolloutStorage is constructed from."""
        return (
            torch.randint(0, 100, (num_samples, 10)),  # input_ids
            torch.randint(0, 2, (num_samples, 10)),  # attention_mask
            torch.randn(num_samples, 1),  # rewards
            torch.randint(0, 100, (num_samples, 1)),  # states_ixs
            torch.randint(0, 100, (num_samples, 1)),  # actions_ixs
            torch.randint(0, 2, (num_samples, 1), dtype=torch.bool),  # dones
        )

    def _check_loader(self, dataloader):
        # Shared check: 8-sample batches split into exactly two 4-sample
        # minibatches, each with consistent leading dimensions.
        iterator = MiniBatchIterator(dataloader, mb_size=4, num_mb=2)
        for minibatches in iterator:
            self.assertEqual(len(minibatches), 2)
            for mb in minibatches:
                self.check_mini_batch(mb, expected_mini_batch_size=4)

    def test_minibatch_iterator_with_ilql_rollout_storage(self):
        storage = ILQLRolloutStorage(*self.create_dummy_tensors(32))
        self._check_loader(storage.create_loader(batch_size=8))

    def test_minibatch_iterator_with_ilql_seq2seq_rollout_storage(self):
        input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones = self.create_dummy_tensors(32)
        decoder_input_ids = torch.randint(0, 100, (32, 10))
        storage = ILQLSeq2SeqRolloutStorage(
            input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, dones
        )
        self._check_loader(storage.create_loader(batch_size=8))
# Allow running this test module directly (e.g. `python this_file.py`).
if __name__ == "__main__":
    unittest.main()