fairseq/tests/test_train.py
Omry Yadan dd106d9534 fixes tests/test_train.py to mock checkpoint.save_dir config node (#3675)
Summary:
## What does this PR do?
Some downstream users reported errors when passing a Namespace to load_checkpoint().

A recent change assumed that the passed object is dict-like (a dict or DictConfig) with a `get` method.
This PR removes that assumption and makes sure the mocked config has checkpoint.save_dir so the test can run.

Pull Request resolved: https://github.com/pytorch/fairseq/pull/3675

Reviewed By: omry

Differential Revision: D29564805

Pulled By: lematt1991

fbshipit-source-id: 89308811da382667f6c5d3152ee2d6480416ee62
2021-07-06 15:07:31 -07:00

248 lines
9.1 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import unittest
from io import StringIO
from unittest.mock import MagicMock, patch
import torch
from fairseq import checkpoint_utils, data
from omegaconf import OmegaConf
def mock_trainer(epoch, num_updates, iterations_in_epoch):
    """Build a MagicMock trainer whose checkpoint reports the given progress.

    ``load_checkpoint()`` on the returned mock yields a dict whose
    ``train_iterator`` entry holds the epoch, the position within that epoch
    and ``shuffle=False``; ``get_num_updates()`` yields ``num_updates``.
    """
    iterator_state = {
        "epoch": epoch,
        "iterations_in_epoch": iterations_in_epoch,
        "shuffle": False,
    }
    fake_trainer = MagicMock()
    fake_trainer.load_checkpoint.return_value = {"train_iterator": iterator_state}
    fake_trainer.get_num_updates.return_value = num_updates
    return fake_trainer
def mock_dict():
    """Return a MagicMock standing in for a dictionary's special-symbol API.

    Only three accessors are configured: ``pad()`` -> 1, ``eos()`` -> 2 and
    ``unk()`` -> 3; every other attribute is an ordinary MagicMock.
    """
    vocab = MagicMock()
    for accessor, index in (("pad", 1), ("eos", 2), ("unk", 3)):
        getattr(vocab, accessor).return_value = index
    return vocab
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    """Create a mocked trainer and a real EpochBatchIterator for checkpoint tests.

    The data is a single row of tokens ``0 .. epoch_size - 1`` cut into
    one-token blocks, wrapped in a LanguagePairDataset, and served as one
    example per batch in index order.

    Args:
        epoch: epoch stored in the mocked checkpoint's iterator state.
        epoch_size: number of tokens, hence examples and batches per epoch.
        num_updates: value returned by the mocked ``trainer.get_num_updates()``.
        iterations_in_epoch: position within the epoch stored in the mocked
            checkpoint's iterator state.

    Returns:
        A ``(trainer, epoch_itr)`` pair.
    """
    token_row = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    one_token_blocks = data.TokenBlockDataset(
        token_row,
        sizes=[token_row.size(-1)],
        block_size=1,
        pad=0,
        eos=1,
        include_targets=False,
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    # NOTE(review): the mock dictionary reports pad=1/eos=2, which differs from
    # the pad=0/eos=1 given to TokenBlockDataset above; left as-is.
    pair_dataset = data.LanguagePairDataset(
        one_token_blocks, one_token_blocks.sizes, mock_dict(), shuffle=False
    )
    single_example_batches = [[idx] for idx in range(epoch_size)]
    epoch_itr = data.EpochBatchIterator(
        dataset=pair_dataset,
        collate_fn=pair_dataset.collater,
        batch_sampler=single_example_batches,
    )
    return trainer, epoch_itr
def get_mock_cfg(finetune_from_model):
    """Build a minimal OmegaConf config accepted by ``load_checkpoint``.

    Every ``reset_*`` flag is False, ``restore_file`` is
    ``"checkpoint_last.pt"``, ``save_dir`` is None and both model-parallel
    sizes are 1.

    Args:
        finetune_from_model: value for ``checkpoint.finetune_from_model``
            (None disables fine-tuning).

    Returns:
        A DictConfig with ``checkpoint`` and ``common`` sections.
    """
    checkpoint_section = {
        "save_dir": None,
        "optimizer_overrides": "{}",
        "reset_dataloader": False,
        "reset_meters": False,
        "reset_optimizer": False,
        "reset_lr_scheduler": False,
        "finetune_from_model": finetune_from_model,
        "model_parallel_size": 1,
        "restore_file": "checkpoint_last.pt",
    }
    common_section = {"model_parallel_size": 1}
    return OmegaConf.create(
        {"checkpoint": checkpoint_section, "common": common_section}
    )
class TestLoadCheckpoint(unittest.TestCase):
    """Tests for ``checkpoint_utils.load_checkpoint`` driven by a mocked trainer.

    ``setUp`` patches file-system helpers so no real files are touched:
    ``os.makedirs`` and ``os.path.join`` become MagicMocks, ``os.path.isfile``
    reports True, ``os.path.isabs`` reports False and
    ``fairseq.file_io.PathManager.exists`` reports False. Individual tests
    reconfigure these mocks through ``self.patches``.
    """

    def setUp(self):
        self.cfg_mock = get_mock_cfg(None)
        # Fresh mocks for every test, keyed by patch target so a test can
        # adjust return values / side effects before calling load_checkpoint.
        self.patches = {
            "os.makedirs": MagicMock(),
            "os.path.join": MagicMock(),
            "os.path.isfile": MagicMock(return_value=True),
            "os.path.isabs": MagicMock(return_value=False),
            "fairseq.file_io.PathManager.exists": MagicMock(return_value=False),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        # Registered before any patch starts: if setUp fails part-way,
        # tearDown is skipped but cleanups still run, so no patch leaks into
        # later tests.
        self.addCleanup(patch.stopall)
        for applied_patch in self.applied_patches:
            applied_patch.start()
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        patch.stopall()
        logging.disable(logging.NOTSET)

    def test_load_partial_checkpoint(self):
        # Mocked checkpoint: epoch 2 with 50 of 150 batches already consumed.
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, trainer
            )
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            # Resuming stays in epoch 2 and continues at example 50.
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 50)
            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

            # 51 batches consumed; consume 98 more to reach 149.
            for _ in range(150 - 52):
                next(itr)
            self.assertEqual(epoch_itr.iterations_in_epoch, 149)
            self.assertTrue(itr.has_next())
            next(itr)
            self.assertFalse(itr.has_next())

            # The following epoch starts from the beginning.
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertTrue(itr.has_next())
            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)

    def test_load_full_checkpoint(self):
        # Mocked checkpoint: epoch 2 fully consumed (150 of 150 batches).
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, trainer
            )
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0)

    def test_load_no_checkpoint(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            self.patches["os.path.isfile"].return_value = False

            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, trainer
            )
            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 1)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0)

    def test_finetune_from_model_args_conflict(self):
        # Each reset_* flag combined with finetune_from_model must be rejected.
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            for arg in [
                "reset_optimizer",
                "reset_lr_scheduler",
                "reset_meters",
                "reset_dataloader",
            ]:
                with self.subTest(arg=arg):
                    cfg_mock = get_mock_cfg("/temp/checkpoint_pretrained.pt")
                    cfg_mock["checkpoint"][arg] = True
                    with self.assertRaises(Exception) as context:
                        _, _ = checkpoint_utils.load_checkpoint(
                            cfg_mock.checkpoint, trainer
                        )
                    self.assertIn(
                        "--finetune-from-model can not be set together with either --reset-optimizer"
                        " or reset_lr_scheduler or reset_meters or reset_dataloader",
                        str(context.exception),
                    )

    def test_finetune_from_model(self):
        # First launch: only the pretrained model exists, so loading must
        # reset optimizer, LR scheduler and meters.
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            from_model_path = "/temp/checkpoint_pretrained.pt"

            def mock_finetune_exist(path):
                return path == from_model_path

            self.patches[
                "fairseq.file_io.PathManager.exists"
            ].side_effect = mock_finetune_exist
            cfg_mock = get_mock_cfg(from_model_path)
            cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
            _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)

            (
                _,
                reset_optimizer,
                reset_lr_scheduler,
                _,
            ) = trainer.load_checkpoint.call_args[0]
            reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"]
            self.assertTrue(reset_optimizer)
            self.assertTrue(reset_lr_scheduler)
            self.assertTrue(reset_meters)

    def test_finetune_from_model_resume(self):
        # Second launch: restore_file=checkpoint_last.pt and finetune_from_model
        # are both set and the last checkpoint exists, so nothing is reset.
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            from_model_path = "/temp/checkpoint_pretrained.pt"

            def mock_finetune_exist(path):
                # ``path`` may be a MagicMock rather than a str, since
                # os.path.join is patched in setUp; ``endswith`` exists on both
                # (``str`` has no ``endsWith``).
                return bool(
                    path == from_model_path or path.endswith("checkpoint_last.pt")
                )

            self.patches[
                "fairseq.file_io.PathManager.exists"
            ].side_effect = mock_finetune_exist
            cfg_mock = get_mock_cfg(from_model_path)
            cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
            _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)

            (
                _,
                reset_optimizer,
                reset_lr_scheduler,
                _,
            ) = trainer.load_checkpoint.call_args[0]
            reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"]
            self.assertFalse(reset_optimizer)
            self.assertFalse(reset_lr_scheduler)
            self.assertFalse(reset_meters)
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()