mirror of
https://github.com/facebookresearch/fairseq.git
synced 2024-07-14 18:50:22 +03:00
initial revision (#5328)
This commit is contained in:
parent
b5d89cddc9
commit
e29f53bfea
@ -104,7 +104,20 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
|
||||
"checkpoint_last{}.pt".format(suffix)
|
||||
] = not cfg.no_last_checkpoints
|
||||
|
||||
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
|
||||
extra_state = {
|
||||
"train_iterator": epoch_itr.state_dict(),
|
||||
"val_loss": val_loss,
|
||||
}
|
||||
|
||||
# Going forward, different tasks could expose an API like this to dump all
|
||||
# the checkpoint worthy attributes in a dictionary which then will be
|
||||
# merged with the parent dictionary to create the "extra_state". This
|
||||
# allows for an extensible yet simple design to checkpoint task level
|
||||
# attributes
|
||||
if hasattr(trainer.task, "get_checkpoint_dict"):
|
||||
extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
|
||||
logger.info(f"{trainer.task.__class__} checkpoint worthy attributes are ready to be persisted with the checkpoint")
|
||||
|
||||
if hasattr(save_checkpoint, "best"):
|
||||
extra_state.update({"best": save_checkpoint.best})
|
||||
|
||||
@ -275,6 +288,11 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
|
||||
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
|
||||
)
|
||||
epoch_itr.load_state_dict(itr_state)
|
||||
|
||||
# Preload the observer stats for Supernet
|
||||
supernet_cp_dict = extra_state.get("supernet", {})
|
||||
if supernet_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
|
||||
trainer.task.set_checkpoint_dict(supernet_cp_dict)
|
||||
else:
|
||||
epoch_itr = trainer.get_train_iterator(
|
||||
epoch=1, load_dataset=True, **passthrough_args
|
||||
|
@ -11,8 +11,6 @@ import unittest
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from fairseq import checkpoint_utils
|
||||
from tests.utils import (
|
||||
create_dummy_data,
|
||||
|
172
tests/test_checkpoint_utils_for_task_level_attributes.py
Normal file
172
tests/test_checkpoint_utils_for_task_level_attributes.py
Normal file
@ -0,0 +1,172 @@
|
||||
#!/usr/bin/env fbpython
|
||||
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import unittest
|
||||
from io import StringIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from fairseq import checkpoint_utils, data
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def mock_trainer(epoch, num_updates, iterations_in_epoch):
    """Build a MagicMock trainer with canned checkpoint state.

    The mocked ``load_checkpoint`` returns an iterator state built from
    *epoch* / *iterations_in_epoch* plus the Supernet extra state, and the
    mocked task exposes the ``get_checkpoint_dict`` / ``set_checkpoint_dict``
    hooks exercised by the checkpoint round-trip tests.
    """
    iterator_state = {
        "epoch": epoch,
        "iterations_in_epoch": iterations_in_epoch,
        "shuffle": False,
    }

    trainer = MagicMock()
    trainer.load_checkpoint.return_value = {
        "train_iterator": iterator_state,
        "supernet": checkpoint_dict()["supernet"],
    }
    trainer.get_num_updates.return_value = num_updates
    trainer.task.get_checkpoint_dict.return_value = checkpoint_dict()
    trainer.task.set_checkpoint_dict = MagicMock()

    return trainer
|
||||
|
||||
|
||||
def checkpoint_dict():
    """Return the fixture of task-level checkpoint state for a Supernet task.

    Shape: {"supernet": {"observer_stats": {config_key: per-module stats}}},
    where the key is a (4, 16, observer, observer) quantization-config tuple.
    """
    config_key = (
        4,
        16,
        "MovingAveragePerChannelMinMax",
        "MovingAveragePerChannelMinMax",
    )
    observer_stats = {config_key: {"mod1": 1, "mod2": 2, "mod3": 3}}
    return {"supernet": {"observer_stats": observer_stats}}
|
||||
|
||||
|
||||
def mock_dict():
    """Return a MagicMock standing in for a fairseq Dictionary.

    Only the special-symbol accessors used by dataset construction are
    stubbed: pad -> 1, eos -> 2, unk -> 3.
    """
    vocab = MagicMock()
    vocab.pad.return_value = 1
    vocab.eos.return_value = 2
    vocab.unk.return_value = 3
    return vocab
|
||||
|
||||
|
||||
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    """Create a mocked trainer plus a real EpochBatchIterator over dummy data.

    The iterator is backed by a single row of *epoch_size* consecutive token
    ids, batched one sample at a time so an epoch contains *epoch_size*
    batches; the trainer is a MagicMock whose checkpoint state reflects
    *epoch* / *num_updates* / *iterations_in_epoch*.
    """
    token_row = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    block_ds = data.TokenBlockDataset(
        token_row,
        sizes=[token_row.size(-1)],
        block_size=1,
        pad=0,
        eos=1,
        include_targets=False,
    )
    pair_ds = data.LanguagePairDataset(
        block_ds, block_ds.sizes, mock_dict(), shuffle=False
    )
    epoch_itr = data.EpochBatchIterator(
        dataset=pair_ds,
        collate_fn=pair_ds.collater,
        batch_sampler=[[i] for i in range(epoch_size)],
    )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    return trainer, epoch_itr
|
||||
|
||||
|
||||
def get_mock_cfg(finetune_from_model):
    """Build an OmegaConf config carrying the checkpoint fields the tests use.

    *finetune_from_model* is threaded into ``checkpoint.finetune_from_model``;
    every other field is a fixed default that keeps save/load on the plain
    "checkpoint_last.pt" path.
    """
    checkpoint_section = {
        "save_dir": None,
        "optimizer_overrides": "{}",
        "reset_dataloader": False,
        "reset_meters": False,
        "reset_optimizer": False,
        "reset_lr_scheduler": False,
        "finetune_from_model": finetune_from_model,
        "model_parallel_size": 1,
        "restore_file": "checkpoint_last.pt",
        "no_save": False,
        "save_interval_updates": 0,
        "no_last_checkpoints": False,
        "keep_interval_updates": 0,
        "keep_last_epochs": 0,
        "keep_best_checkpoints": 0,
    }
    return OmegaConf.create(
        {
            "checkpoint": checkpoint_section,
            "common": {"model_parallel_size": 1},
        }
    )
|
||||
|
||||
|
||||
class TestCheckpointsForTaskLevelAttributes(unittest.TestCase):
    """Verify that task-level attributes (here: Supernet observer stats) are
    persisted by ``save_checkpoint`` via ``task.get_checkpoint_dict`` and
    restored by ``load_checkpoint`` via ``task.set_checkpoint_dict``.
    """

    def setUp(self) -> None:
        self.cfg_mock = get_mock_cfg(None)
        # Patch all filesystem touch-points so save_checkpoint never hits disk.
        self.patches = {
            "os.makedirs": MagicMock(),
            "os.path.join": MagicMock(),
            "os.path.isfile": MagicMock(return_value=True),
            "os.path.isabs": MagicMock(return_value=False),
            "fairseq.file_io.PathManager.exists": MagicMock(return_value=False),
        }
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        [p.start() for p in self.applied_patches]
        # Silence fairseq's checkpoint logging during the tests.
        logging.disable(logging.CRITICAL)

        self.trainer, self.epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
        self.trainer.get_train_iterator = MagicMock(return_value=self.epoch_itr)
        # Advance into an epoch so a realistic iterator state gets saved.
        self.epoch_itr.next_epoch_itr(shuffle=False)

        # Persist a checkpoint up front; the test methods inspect/load it.
        checkpoint_utils.save_checkpoint(
            self.cfg_mock.checkpoint, self.trainer, self.epoch_itr, None
        )

    def tearDown(self):
        # Undo every patch and restore logging for subsequent test cases.
        patch.stopall()
        logging.disable(logging.NOTSET)

    def test_verify_checkpoint(self) -> None:
        # The task-provided checkpoint dict must match the fixture from
        # checkpoint_dict(): one "supernet" entry holding one observer-stats
        # record keyed by the quantization-config tuple.
        cp_dict = self.trainer.task.get_checkpoint_dict()
        self.assertTrue(len(cp_dict) == 1)
        self.assertTrue("supernet" in cp_dict)
        self.assertTrue("observer_stats" in cp_dict["supernet"])
        self.assertTrue(len(cp_dict["supernet"]["observer_stats"]) == 1)
        self.assertTrue(
            (
                4,
                16,
                "MovingAveragePerChannelMinMax",
                "MovingAveragePerChannelMinMax",
            )
            in cp_dict["supernet"]["observer_stats"]
        )
        self.assertTrue(
            cp_dict["supernet"]["observer_stats"][
                (
                    4,
                    16,
                    "MovingAveragePerChannelMinMax",
                    "MovingAveragePerChannelMinMax",
                )
            ]
            == {"mod1": 1, "mod2": 2, "mod3": 3}
        )

    def test_load_checkpoint(self) -> None:
        with contextlib.redirect_stdout(StringIO()):
            # Now, load checkpoint to ensure the respective logic works as expected
            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, self.trainer
            )

            # load_checkpoint must hand the saved "supernet" state back to the
            # task exactly once, unmodified.
            self.trainer.task.set_checkpoint_dict.assert_called_once_with(
                checkpoint_dict()["supernet"]
            )
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user