FSDP uses new optimizer gathering to save optimizer state (#1744)

Summary:
- The full, unflattened optimizer state dict is saved only to `checkpoints/shard_0.pt`; the other checkpoint files do not carry the `last_optimizer_state` key (see the sketch below).
- Requires the master version of fairscale (eventually fairscale>=0.3.3).
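
As a rough illustration of that layout (not part of the change itself), loading the shard checkpoints shows where the optimizer state ends up; the shard_1.pt name below is an assumption that simply follows the shard_0.pt pattern:

# Hedged sketch: only the rank-0 shard is expected to carry the consolidated optimizer state.
import torch

shard0 = torch.load("checkpoints/shard_0.pt", map_location="cpu")
print("last_optimizer_state" in shard0)  # True: the full, unflattened optimizer state

shard1 = torch.load("checkpoints/shard_1.pt", map_location="cpu")  # assumed sibling shard name
print("last_optimizer_state" in shard1)  # False: other shards skip the optimizer state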

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1744

Reviewed By: myleott

Differential Revision: D27342305

Pulled By: sshleifer

fbshipit-source-id: 7442b8c6ed01599d8ab0050213e84051f4e98acd
Author: Sam Shleifer, 2021-03-26 07:18:08 -07:00 (committed by Facebook GitHub Bot)
Parent: a28511d43d
Commit: be1d186fa5
4 changed files with 32 additions and 9 deletions

fairseq/checkpoint_utils.py

@@ -42,7 +42,7 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
     if cfg.no_save:
         return
 
-    trainer.consolidate_optimizer()
+    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state
     if not trainer.should_save_checkpoint_on_current_rank:
         return

fairseq/data/data_utils.py

@@ -312,13 +312,12 @@ def batch_by_size(
         )
     except ImportError:
         raise ImportError(
-            "Please build Cython components with: `pip install --editable .` "
-            "or `python setup.py build_ext --inplace`"
+            "Please build Cython components with: "
+            "`python setup.py build_ext --inplace`"
         )
     except ValueError:
         raise ValueError(
-            "Please build (or rebuild) Cython components with: `pip install "
-            " --editable .` or `python setup.py build_ext --inplace`."
+            "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
         )
     # added int() to avoid TypeError: an integer is required

fairseq/trainer.py

@@ -331,9 +331,15 @@ class Trainer(object):
 
     def consolidate_optimizer(self):
         """For OSS, we need to consolidate the state dict."""
+        self._gathered_optim_state = None
         if hasattr(self.optimizer.optimizer, "consolidate_state_dict"):
             self.optimizer.optimizer.consolidate_state_dict()
+        elif self.cfg.distributed_training.ddp_backend == 'fully_sharded':
+            self._gathered_optim_state = self.model.gather_full_optim_state_dict(
+                self.optimizer, recipient_rank=0
+            )
 
     def state_dict(self):
         state_dict = {
             "args": None,  # legacy
@@ -362,7 +368,11 @@ class Trainer(object):
             }
         }
         if not self.cfg.checkpoint.no_save_optimizer_state:
-            state_dict["last_optimizer_state"] = self.optimizer.state_dict()
+            if self._gathered_optim_state is not None:
+                state_dict["last_optimizer_state"] = self._gathered_optim_state
+                self._gathered_optim_state = None
+            else:
+                state_dict["last_optimizer_state"] = self.optimizer.state_dict()
         return state_dict
 
     def save_checkpoint(self, filename, extra_state):
@@ -478,6 +488,9 @@ class Trainer(object):
                 last_optim_state = self.optimizer.broadcast_global_state_dict(
                     last_optim_state
                 )
+            elif self.cfg.distributed_training.ddp_backend == 'fully_sharded':
+                last_optim_state = self.model.get_shard_from_optim_state_dict(last_optim_state)
 
             self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
             self.set_num_updates(last_optim["num_updates"])
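
The Trainer changes above lean on two fairscale FSDP methods, gather_full_optim_state_dict and get_shard_from_optim_state_dict. A minimal stand-alone sketch of the same save/load round trip, assuming an initialized process group, one CUDA device per rank, and the fairscale master branch (eventually fairscale>=0.3.3):

# Hedged sketch, not part of this commit: the gather/shard round trip in isolation.
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = FSDP(torch.nn.Linear(16, 16).cuda())
optimizer = torch.optim.Adam(model.parameters())

# ... a few training steps ...

# Saving: every rank enters the gather; per the diff above, only recipient_rank
# receives the full, unflattened optimizer state dict and the other ranks get None.
full_osd = model.gather_full_optim_state_dict(optimizer, recipient_rank=0)
if full_osd is not None:
    torch.save({"last_optimizer_state": full_osd}, "shard_0.pt")

# Loading: each rank slices its own shard back out of the full state dict
# before handing it to its local optimizer.
full_osd = torch.load("shard_0.pt", map_location="cpu")["last_optimizer_state"]
sharded_osd = model.get_shard_from_optim_state_dict(full_osd)
optimizer.load_state_dict(sharded_osd)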


@@ -1,13 +1,24 @@
 #!/usr/bin/env bash
 rm -rf fsdp_dummy
 mkdir -p fsdp_dummy
-fairseq-train /private/home/sshleifer/data-bin/stories_mmap \
+CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \
     --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
     --cpu-offload --checkpoint-activations \
     --task language_modeling --tokens-per-sample 256 --batch-size 8 \
     --arch transformer_lm_gpt2_tiny \
     --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
     --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
-    --max-update 10 --log-format json --log-interval 1 \
-    --save-interval-updates 10 --save-dir fsdp_dummy \
+    --max-update 5 --log-format json --log-interval 1 \
+    --save-interval-updates 5 --save-dir fsdp_dummy --disable-validation \
     --restore-file x.pt "$@"
+
+# Now we try to load the checkpoint
+CUDA_VISIBLE_DEVICES=0,1 fairseq-train /private/home/sshleifer/data-bin/stories_mmap \
+    --ddp-backend fully_sharded --fp16 --fp16-init-scale 4 \
+    --cpu-offload --checkpoint-activations \
+    --task language_modeling --tokens-per-sample 256 --batch-size 8 \
+    --arch transformer_lm_gpt2_tiny \
+    --optimizer cpu_adam --adam-betas "(0.9,0.98)" \
+    --lr 0.0001 --lr-scheduler polynomial_decay --warmup-updates 5 --total-num-update 10 \
+    --max-update 2 --log-format json --log-interval 1 \
+    --save-interval-updates 2 --save-dir fsdp_dummy
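
A quick, hedged way to see the effect after the first invocation above: the rank-0 checkpoint in fsdp_dummy should be noticeably larger than its siblings, since it alone carries last_optimizer_state. The checkpoint_last* file pattern is an assumption about fairseq's sharded checkpoint naming, not something this diff spells out:

# Hedged sanity check under the assumptions stated above: compare shard checkpoint sizes.
import glob
import os

for path in sorted(glob.glob("fsdp_dummy/checkpoint_last*.pt")):
    print(path, os.path.getsize(path) // 1024, "KiB")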