# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import contextlib from io import StringIO import unittest from unittest.mock import MagicMock, patch import torch from fairseq import data, checkpoint_utils def mock_trainer(epoch, num_updates, iterations_in_epoch): trainer = MagicMock() trainer.load_checkpoint.return_value = { 'train_iterator': { 'epoch': epoch, 'iterations_in_epoch': iterations_in_epoch, 'shuffle': False, }, } trainer.get_num_updates.return_value = num_updates return trainer def mock_dict(): d = MagicMock() d.pad.return_value = 1 d.eos.return_value = 2 d.unk.return_value = 3 return d def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch): tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1) tokens_ds = data.TokenBlockDataset( tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False, ) trainer = mock_trainer(epoch, num_updates, iterations_in_epoch) dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False) epoch_itr = data.EpochBatchIterator( dataset=dataset, collate_fn=dataset.collater, batch_sampler=[[i] for i in range(epoch_size)], ) return trainer, epoch_itr class TestLoadCheckpoint(unittest.TestCase): def setUp(self): self.args_mock = MagicMock() self.args_mock.optimizer_overrides = '{}' self.args_mock.reset_dataloader = False self.args_mock.reset_meters = False self.args_mock.reset_optimizer = False self.patches = { 'os.makedirs': MagicMock(), 'os.path.join': MagicMock(), 'os.path.isfile': MagicMock(return_value=True), 'os.path.isabs': MagicMock(return_value=False), } self.applied_patches = [patch(p, d) for p, d in self.patches.items()] [p.start() for p in self.applied_patches] def test_load_partial_checkpoint(self): with contextlib.redirect_stdout(StringIO()): trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50) trainer.get_train_iterator = MagicMock(return_value=epoch_itr) _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer) self.assertEqual(epoch_itr.epoch, 2) self.assertEqual(epoch_itr.iterations_in_epoch, 50) itr = epoch_itr.next_epoch_itr(shuffle=False) self.assertEqual(epoch_itr.epoch, 2) self.assertEqual(epoch_itr.iterations_in_epoch, 50) self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50) self.assertEqual(epoch_itr.iterations_in_epoch, 51) for _ in range(150 - 52): next(itr) self.assertEqual(epoch_itr.iterations_in_epoch, 149) self.assertTrue(itr.has_next()) next(itr) self.assertFalse(itr.has_next()) itr = epoch_itr.next_epoch_itr(shuffle=False) self.assertTrue(itr.has_next()) self.assertEqual(epoch_itr.epoch, 3) self.assertEqual(epoch_itr.iterations_in_epoch, 0) def test_load_full_checkpoint(self): with contextlib.redirect_stdout(StringIO()): trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150) trainer.get_train_iterator = MagicMock(return_value=epoch_itr) _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer) itr = epoch_itr.next_epoch_itr(shuffle=False) self.assertEqual(epoch_itr.epoch, 3) self.assertEqual(epoch_itr.iterations_in_epoch, 0) self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0) def test_load_no_checkpoint(self): with contextlib.redirect_stdout(StringIO()): trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0) trainer.get_train_iterator = MagicMock(return_value=epoch_itr) self.patches['os.path.isfile'].return_value = False _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer) itr = epoch_itr.next_epoch_itr(shuffle=False) self.assertEqual(epoch_itr.epoch, 1) self.assertEqual(epoch_itr.iterations_in_epoch, 0) self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0) def tearDown(self): patch.stopall() if __name__ == '__main__': unittest.main()