Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib | |
from io import StringIO | |
import unittest | |
from unittest.mock import MagicMock, patch | |
import torch | |
from fairseq import data, checkpoint_utils | |
def mock_trainer(epoch, num_updates, iterations_in_epoch):
    """Return a ``MagicMock`` standing in for a trainer.

    ``load_checkpoint()`` on the mock returns a state dict whose
    ``'train_iterator'`` entry records *epoch* and *iterations_in_epoch*
    (with ``'shuffle'`` fixed to ``False``), and ``get_num_updates()``
    returns *num_updates*.
    """
    iterator_state = {
        'epoch': epoch,
        'iterations_in_epoch': iterations_in_epoch,
        'shuffle': False,
    }
    fake_trainer = MagicMock()
    # Dotted keys configure the return values of the child mocks.
    fake_trainer.configure_mock(**{
        'load_checkpoint.return_value': {'train_iterator': iterator_state},
        'get_num_updates.return_value': num_updates,
    })
    return fake_trainer
def mock_dict():
    """Return a ``MagicMock`` dictionary stand-in.

    Its ``pad()``, ``eos()`` and ``unk()`` methods return 1, 2 and 3
    respectively.
    """
    special_indices = {'pad': 1, 'eos': 2, 'unk': 3}
    fake_dict = MagicMock()
    for method_name, index in special_indices.items():
        getattr(fake_dict, method_name).return_value = index
    return fake_dict
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    """Build a mocked trainer and a real ``EpochBatchIterator`` for the tests.

    The dataset holds the token values ``0 .. epoch_size - 1`` as a single
    1 x ``epoch_size`` tensor, split by ``TokenBlockDataset`` with
    ``block_size=1``. The batch sampler yields one single-index batch per
    position, so the iterator has ``epoch_size`` batches.

    Returns:
        A ``(trainer, epoch_itr)`` tuple, where ``trainer`` comes from
        ``mock_trainer(epoch, num_updates, iterations_in_epoch)``.
    """
    token_row = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    block_ds = data.TokenBlockDataset(
        token_row,
        sizes=[token_row.size(-1)],
        block_size=1,
        pad=0,
        eos=1,
        include_targets=False,
    )
    fake_trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    pair_ds = data.LanguagePairDataset(block_ds, block_ds.sizes, mock_dict(), shuffle=False)

    # One batch per example, in index order.
    single_item_batches = [[idx] for idx in range(epoch_size)]
    batch_iterator = data.EpochBatchIterator(
        dataset=pair_ds,
        collate_fn=pair_ds.collater,
        batch_sampler=single_item_batches,
    )
    return fake_trainer, batch_iterator
class TestLoadCheckpoint(unittest.TestCase):
    """Exercise ``checkpoint_utils.load_checkpoint`` against a mocked trainer.

    Each test builds a trainer/iterator pair with
    ``get_trainer_and_epoch_itr(epoch, epoch_size, num_updates,
    iterations_in_epoch)`` and checks the epoch iterator returned by
    ``load_checkpoint``.
    """

    def setUp(self):
        """Create the mocked args and patch the ``os`` helpers listed below."""
        self.args_mock = MagicMock()
        self.args_mock.optimizer_overrides = '{}'
        self.args_mock.reset_dataloader = False
        self.args_mock.reset_meters = False
        self.args_mock.reset_optimizer = False

        # Replacement objects keyed by patch target. Tests may tweak these
        # mocks (e.g. flip ``os.path.isfile``'s return value) after setUp.
        self.patches = {
            'os.makedirs': MagicMock(),
            'os.path.join': MagicMock(),
            'os.path.isfile': MagicMock(return_value=True),
            'os.path.isabs': MagicMock(return_value=False),
        }
        self.applied_patches = [
            patch(target, replacement) for target, replacement in self.patches.items()
        ]

        # Register the cleanup *before* starting anything: unittest skips
        # tearDown() when setUp() raises, but cleanups still run, so a
        # partial failure cannot leave the os functions patched.
        self.addCleanup(patch.stopall)
        for patcher in self.applied_patches:
            patcher.start()

    def test_load_partial_checkpoint(self):
        """A checkpoint saved mid-epoch resumes at the saved position."""
        with contextlib.redirect_stdout(StringIO()):
            # epoch=2, epoch_size=150, num_updates=200, iterations_in_epoch=50
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)

            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            # Requesting the iterator must not change the restored position.
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            # The first batch served is the one at index 50.
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

            # Consume 98 more batches (51 -> 149), leaving exactly one.
            for _ in range(150 - 52):
                next(itr)
            self.assertEqual(epoch_itr.iterations_in_epoch, 149)
            self.assertTrue(itr.has_next())
            next(itr)
            self.assertFalse(itr.has_next())

            # The following iterator belongs to epoch 3 and starts at 0.
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertTrue(itr.has_next())
            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)

    def test_load_full_checkpoint(self):
        """A checkpoint saved at the end of epoch 2 resumes at epoch 3."""
        with contextlib.redirect_stdout(StringIO()):
            # iterations_in_epoch equals epoch_size (150): the epoch is done.
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def test_load_no_checkpoint(self):
        """With no checkpoint file, iteration begins at epoch 1, batch 0."""
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            # Make the patched os.path.isfile report a missing file.
            self.patches['os.path.isfile'].return_value = False

            _, epoch_itr = checkpoint_utils.load_checkpoint(self.args_mock, trainer)
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 1)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)

    def tearDown(self):
        """Stop every active patch (harmless if cleanup already did so)."""
        patch.stopall()
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()