Fix FSDP optim state loading (#1819)

Summary:
### Problem:
- If we consolidate the optimizer state dict on rank 0, ranks 1+ save their local `optimizer.state_dict()`. When they try to load, they call `get_shard(last_optim_state)`, which is wrong since that optimizer state is already sharded. They should instead find the globally consolidated optimizer state dict and load that (see the schematic below).
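
A minimal schematic of the pre-fix checkpoint layout, using illustrative placeholder values (not fairseq's real checkpoint schema):

```python
# Schematic only: what each rank's file roughly holds after consolidating the
# FSDP optimizer state onto rank 0, before this fix.
rank0_checkpoint = {
    # full optimizer state, gathered from every rank
    "last_optimizer_state": {"state": "consolidated across all ranks"},
}
rank1_checkpoint = {
    # rank 1's local optimizer.state_dict(): already only this rank's shard
    "last_optimizer_state": {"state": "rank 1's shard only"},
}

# On load, every rank calls get_shard(last_optim_state).  That is only valid
# for a consolidated state dict (rank 0's file); applying it to rank 1's
# already-sharded state re-shards a shard, which is wrong.
```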

### Possible Solutions:
- If the world size is unchanged, you could simply reuse the local optimizer state dict.
- [this PR] Ranks 1+ load the optimizer state from rank 0's file and then call `get_shard` on it.
- A separate file for the optimizer state that every rank loads (like `shared.pt` on `gshard-azure`). This would save some CPU RAM.

### Note:
- I don't think it's currently possible to pass `--use-sharded-state` from the command line; it probably should be.

### Implementation here
+ If FSDP saves -1 as `state['last_optimizer_state']`, it means that, on load, rank 0's optimizer state must be loaded instead (sketched below).
+ Add a regression test.
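
A condensed sketch of the resulting protocol; the helper functions below are illustrative stand-ins for the logic added to `Trainer.consolidate_optimizer` and `Trainer.load_checkpoint` in the diff that follows:

```python
import re
import torch

def save_optimizer_state(model, optimizer, filename):
    """Save side (sketch): only rank 0 receives the consolidated state."""
    full_osd = model.gather_full_optim_state_dict(optimizer)  # None on ranks 1+
    state = {"last_optimizer_state": full_osd if full_osd is not None else -1}
    torch.save(state, filename)

def load_optimizer_state(filename):
    """Load side (sketch): ranks that stored the -1 sentinel read rank 0's file."""
    last_optim_state = torch.load(filename, map_location="cpu")["last_optimizer_state"]
    if last_optim_state == -1:
        master_path = re.sub("shard[0-9]+", "shard0", filename)
        last_optim_state = torch.load(master_path, map_location="cpu")["last_optimizer_state"]
    return last_optim_state  # consolidated state; each rank then extracts its shard
```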

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1819

Reviewed By: zhengwy888

Differential Revision: D27910281

Pulled By: sshleifer

fbshipit-source-id: d34987008f77ce7e0cb28b7224dd2aabed38a70c
Author: Sam Shleifer, 2021-04-21 15:49:03 -07:00 (committed by Facebook GitHub Bot)
parent 207254bf56
commit 05b86005bc
2 changed files with 21 additions and 14 deletions

fairseq/trainer.py

@@ -26,7 +26,7 @@ from fairseq.nan_detector import NanDetector
 from fairseq.optim import lr_scheduler
 from omegaconf import OmegaConf
+import re

 logger = logging.getLogger(__name__)
@@ -331,14 +331,17 @@ class Trainer(object):
     def consolidate_optimizer(self):
         """For OSS, we need to consolidate the state dict."""
         if self.cfg.checkpoint.no_save_optimizer_state:
             return
+        self._gathered_optim_state = None
         if hasattr(self.optimizer.optimizer, "consolidate_state_dict"):
             self.optimizer.optimizer.consolidate_state_dict()
-        elif self.cfg.distributed_training.ddp_backend == 'fully_sharded':
-            self._gathered_optim_state = self.model.gather_full_optim_state_dict(self.optimizer,
-                                                                                  recipient_rank=0)
+        elif self.cfg.distributed_training.ddp_backend == 'fully_sharded' and not self.model.use_sharded_state:
+            st = self.model.gather_full_optim_state_dict(self.optimizer)  # only returns on rank 0
+            if st is None:
+                st = -1  # sentinel so that workers do not save optimizer.state_dict()
+            self._gathered_optim_state = st

     def state_dict(self):
         state_dict = {
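
The diff does not show how `state_dict()` consumes `self._gathered_optim_state`; the sentinel comment implies behaviour roughly like the sketch below (an assumption about the surrounding code, not part of this change):

```python
# Sketch only: assumed interaction between the sentinel and checkpoint saving.
def last_optimizer_state_for_checkpoint(trainer):
    if trainer._gathered_optim_state is not None:
        # consolidated dict on rank 0, the -1 sentinel on ranks 1+
        return trainer._gathered_optim_state
    # non-FSDP (or sharded-state) path: save the local optimizer state
    return trainer.optimizer.state_dict()
```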
@@ -423,6 +426,9 @@ class Trainer(object):
                     filename, load_on_all_ranks=load_on_all_ranks
                 )
                 last_optim_state = state.get("last_optimizer_state", None)
+                if last_optim_state == -1:
+                    master_path = re.sub("shard[0-9]+", "shard0", filename)
+                    last_optim_state = torch.load(master_path, map_location='cpu')['last_optimizer_state']

                 # If doing zero_sharding, do not broadcast global optimizer
                 # state. Later we will broadcast sharded states to each rank
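
For reference, the `re.sub` above rewrites the current rank's checkpoint path to rank 0's; assuming fairseq's per-rank `-shard<rank>` checkpoint suffix, the mapping looks like:

```python
import re

# Illustrative path; the per-rank "-shard<rank>" suffix is an assumption about
# how fairseq names FSDP checkpoints.
filename = "/checkpoints/checkpoint_1_2-shard3.pt"
master_path = re.sub("shard[0-9]+", "shard0", filename)
print(master_path)  # -> /checkpoints/checkpoint_1_2-shard0.pt
```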

tests/gpu/test_binaries_gpu.py

@@ -56,7 +56,13 @@ class TestTranslationGPU(unittest.TestCase):
                 continue
         return logs

-    def test_resume_training(self):
+    def test_resume_training_fsdp(self):
+        self._test_resume_training(["--ddp-backend", "fully_sharded"])
+
+    def test_resume_training_noc10d(self):
+        self._test_resume_training([])
+
+    def _test_resume_training(self, extra_clargs, arch="fconv_iwslt_de_en"):
         flags = [
             "--fp16",
             "--log-format",
@@ -67,27 +73,22 @@ class TestTranslationGPU(unittest.TestCase):
             "2",
             "--log-interval",
             "1",
-            "--log-file",
-        ]
+        ] + extra_clargs
         world_size = min(torch.cuda.device_count(), 2)
-        arch = "fconv_iwslt_de_en"
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory("test_fp16") as data_dir:
                 log = os.path.join(data_dir, "train.log")
                 create_dummy_data(data_dir)
                 preprocess_translation_data(data_dir)
                 train_translation_model(
-                    data_dir, arch, flags + [log], world_size=world_size,
+                    data_dir, arch, flags + ["--log-file", log], world_size=world_size,
                 )
                 log2 = os.path.join(data_dir, "resume.log")
                 restore_file = os.path.join(data_dir, "checkpoint_1_2.pt")
                 assert os.path.exists(
                     restore_file
                 ), f"{restore_file} not written. Choices: {os.listdir(data_dir)}"
                 train_translation_model(
                     data_dir,
                     arch,
-                    flags + [log2, "--restore-file", restore_file],
+                    flags + ["--log-file", log2, "--restore-file", restore_file],
                     world_size=world_size,
                 )