Mirror of https://github.com/facebookresearch/fairseq.git, synced 2024-10-05 13:17:39 +03:00
dd106d9534
Summary: ## What does this PR do? Some downstream users reported errors when passing a Namespace to load_checkpoint(). A recent change assumed that the passed object is dict-like (a dict or DictConfig) with a get() function. This change removes that assumption and ensures the mocked config has checkpoint.save_dir so the test can run. Pull Request resolved: https://github.com/pytorch/fairseq/pull/3675 Reviewed By: omry Differential Revision: D29564805 Pulled By: lematt1991 fbshipit-source-id: 89308811da382667f6c5d3152ee2d6480416ee62
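For context, a minimal sketch of one plausible way to remove the dict-like assumption the summary describes: normalizing a plain argparse.Namespace into an OmegaConf DictConfig before any dict-style access. This is illustrative, not the diff that actually landed; the helper name _ensure_dict_like is made up here.

    from argparse import Namespace
    from omegaconf import OmegaConf

    def _ensure_dict_like(cfg):
        # Namespace has no .get(); wrap it so downstream code can call
        # cfg.get("save_dir") uniformly on dict, DictConfig, or Namespace input.
        if isinstance(cfg, Namespace):
            return OmegaConf.create(vars(cfg))
        return cfg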
248 lines
9.1 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import logging
import unittest
from io import StringIO
from unittest.mock import MagicMock, patch

import torch
from fairseq import checkpoint_utils, data
from omegaconf import OmegaConf


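# mock_trainer builds a MagicMock trainer whose load_checkpoint() returns a
# canned "train_iterator" state and whose get_num_updates() reports the given
# update count, so no real model or optimizer is needed.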
def mock_trainer(epoch, num_updates, iterations_in_epoch):
    trainer = MagicMock()
    trainer.load_checkpoint.return_value = {
        "train_iterator": {
            "epoch": epoch,
            "iterations_in_epoch": iterations_in_epoch,
            "shuffle": False,
        },
    }
    trainer.get_num_updates.return_value = num_updates
    return trainer


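# mock_dict fakes a fairseq Dictionary with fixed pad/eos/unk indices.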
def mock_dict():
    d = MagicMock()
    d.pad.return_value = 1
    d.eos.return_value = 2
    d.unk.return_value = 3
    return d


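# get_trainer_and_epoch_itr pairs a mocked trainer with a real
# EpochBatchIterator over epoch_size one-token batches (tokens
# 0..epoch_size-1), so the iterator state restored by load_checkpoint()
# can be asserted on concretely.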
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    tokens_ds = data.TokenBlockDataset(
        tokens,
        sizes=[tokens.size(-1)],
        block_size=1,
        pad=0,
        eos=1,
        include_targets=False,
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    dataset = data.LanguagePairDataset(
        tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False
    )
    epoch_itr = data.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,
        batch_sampler=[[i] for i in range(epoch_size)],
    )
    return trainer, epoch_itr


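# get_mock_cfg builds the minimal OmegaConf config that load_checkpoint()
# reads, including checkpoint.save_dir (see the commit summary above).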
def get_mock_cfg(finetune_from_model):
    cfg_mock = OmegaConf.create(
        {
            "checkpoint": {
                "save_dir": None,
                "optimizer_overrides": "{}",
                "reset_dataloader": False,
                "reset_meters": False,
                "reset_optimizer": False,
                "reset_lr_scheduler": False,
                "finetune_from_model": finetune_from_model,
                "model_parallel_size": 1,
                "restore_file": "checkpoint_last.pt",
            },
            "common": {
                "model_parallel_size": 1,
            },
        }
    )
    return cfg_mock


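# setUp patches out all filesystem access (os.path.*, os.makedirs,
# PathManager.exists), so these tests never touch real checkpoint files.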
class TestLoadCheckpoint(unittest.TestCase):
    def setUp(self):
        self.cfg_mock = get_mock_cfg(None)
        self.patches = {
            "os.makedirs": MagicMock(),
            "os.path.join": MagicMock(),
            "os.path.isfile": MagicMock(return_value=True),
            "os.path.isabs": MagicMock(return_value=False),
            "fairseq.file_io.PathManager.exists": MagicMock(return_value=False),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        [p.start() for p in self.applied_patches]
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        patch.stopall()
        logging.disable(logging.NOTSET)

    def test_load_partial_checkpoint(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, trainer
            )

            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertEqual(epoch_itr.epoch, 2)
            self.assertEqual(epoch_itr.iterations_in_epoch, 50)

            self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 50)
            self.assertEqual(epoch_itr.iterations_in_epoch, 51)

            for _ in range(150 - 52):
                next(itr)
            self.assertEqual(epoch_itr.iterations_in_epoch, 149)
            self.assertTrue(itr.has_next())
            next(itr)
            self.assertFalse(itr.has_next())

            itr = epoch_itr.next_epoch_itr(shuffle=False)
            self.assertTrue(itr.has_next())
            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)

    def test_load_full_checkpoint(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, trainer
            )
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 3)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0)

    def test_load_no_checkpoint(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            self.patches["os.path.isfile"].return_value = False

            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, trainer
            )
            itr = epoch_itr.next_epoch_itr(shuffle=False)

            self.assertEqual(epoch_itr.epoch, 1)
            self.assertEqual(epoch_itr.iterations_in_epoch, 0)
            self.assertEqual(next(itr)["net_input"]["src_tokens"][0].item(), 0)

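    # finetune_from_model must not be combined with any of the reset_* flags;
    # load_checkpoint() is expected to raise in that case.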
    def test_finetune_from_model_args_conflict(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)

            for arg in [
                "reset_optimizer",
                "reset_lr_scheduler",
                "reset_meters",
                "reset_dataloader",
            ]:
                with self.subTest(arg=arg):
                    cfg_mock = get_mock_cfg("/temp/checkpoint_pretrained.pt")
                    cfg_mock["checkpoint"][arg] = True
                    with self.assertRaises(Exception) as context:
                        _, _ = checkpoint_utils.load_checkpoint(
                            cfg_mock.checkpoint, trainer
                        )

                    self.assertTrue(
                        "--finetune-from-model can not be set together with either --reset-optimizer"
                        " or reset_lr_scheduler or reset_meters or reset_dataloader"
                        in str(context.exception)
                    )

    def test_finetune_from_model(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            from_model_path = "/temp/checkpoint_pretrained.pt"

            def mock_finetune_exist(path):
                if path == from_model_path:
                    return True
                else:
                    return False

            self.patches[
                "fairseq.file_io.PathManager.exists"
            ].side_effect = mock_finetune_exist
            cfg_mock = get_mock_cfg(from_model_path)
            cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
            _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
            (
                checkpoint_path,
                reset_optimizer,
                reset_lr_scheduler,
                optimizer_overrides,
            ) = trainer.load_checkpoint.call_args[0]
            reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"]
            self.assertTrue(reset_optimizer)
            self.assertTrue(reset_lr_scheduler)
            self.assertTrue(reset_meters)

    def test_finetune_from_model_resume(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(1, 150, 0, 0)
            trainer.get_train_iterator = MagicMock(return_value=epoch_itr)
            from_model_path = "/temp/checkpoint_pretrained.pt"

            # launch second time
            # both restore_file=checkpoint_last.pt and finetune_from_model are set
            def mock_finetune_exist(path):
                if path == from_model_path or path.endswith("checkpoint_last.pt"):
                    return True
                else:
                    return False

            self.patches[
                "fairseq.file_io.PathManager.exists"
            ].side_effect = mock_finetune_exist
            cfg_mock = get_mock_cfg(from_model_path)
            cfg_mock.checkpoint.restore_file = "checkpoint_last.pt"
            _, _ = checkpoint_utils.load_checkpoint(cfg_mock.checkpoint, trainer)
            (
                checkpoint_path,
                reset_optimizer,
                reset_lr_scheduler,
                optimizer_overrides,
            ) = trainer.load_checkpoint.call_args[0]
            reset_meters = trainer.load_checkpoint.call_args[1]["reset_meters"]
            self.assertFalse(reset_optimizer)
            self.assertFalse(reset_lr_scheduler)
            self.assertFalse(reset_meters)


if __name__ == "__main__":
    unittest.main()