support use_sharded_state on the command line

Summary:
We wanted to use sharded_state because:
1. it saves memory, and
2. it supports sharded state loading, which allows MoE models' weights to live on their respective shards.
I added use_sharded_state as a config option and added a unit test to make sure it runs fine.
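For illustration, here is a rough sketch of what "sharded" vs. "full" state means, using fairscale's FSDP API directly rather than fairseq internals. It assumes a single CUDA device and a fairscale version that exposes local_state_dict()/load_local_state_dict(); the snippet is not part of this commit's diff.

# Rough sketch, not part of this commit. Each rank's local_state_dict() holds only
# its own shard of the flattened parameters, which is what use_sharded_state saves.
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# single-process group, just so FSDP can be constructed in this sketch
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)

model = FSDP(torch.nn.Linear(8, 8).cuda())

full_sd = model.state_dict()           # all-gathers the full, unsharded parameters
sharded_sd = model.local_state_dict()  # only this rank's shard

model.load_local_state_dict(sharded_sd, strict=True)  # each rank loads its own shard back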

old revision's comment:
fairseq.FSDP has a use_sharded_state flag, but I had to address a couple of problems before being able to use it:
1. fairscale FSDP (FSDP for short) calls self.state_dict/self.load_state_dict internally, and these have been overridden by fairseq.FSDP; this is not the desired behavior (a rough sketch of the wrapper appears after this list).
2. the optimizer state should not be sharded again when use_sharded_state is True.
3. the option needs to be exposed on the command line.
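For context on points 1 and 2, here is an approximation of the wrapper pattern involved; this is a sketch, not the exact fairseq.FSDP code:

# Approximation of the fairseq wrapper, for context only (not the exact code).
# The override delegates to fairscale's local (per-shard) variants when
# use_sharded_state is True, so fairscale's internal calls to self.state_dict()
# and the optimizer-state handling have to agree with this flag.
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


class FullyShardedDataParallel(FSDP):
    def __init__(self, *args, use_sharded_state: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_sharded_state = use_sharded_state

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        if self.use_sharded_state:
            # each rank returns only its own shard
            return super().local_state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        # otherwise gather the full, unsharded parameters
        return super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )

    def load_state_dict(self, state_dict, strict=True):
        if self.use_sharded_state:
            return super().load_local_state_dict(state_dict, strict=strict)
        return super().load_state_dict(state_dict, strict=strict)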

Reviewed By: sshleifer

Differential Revision: D28375035

fbshipit-source-id: c2f59a9c62163405033f34ed595ba78528aea850
Weiyi Zheng 2021-05-14 18:52:06 -07:00 committed by Facebook GitHub Bot
parent d151f27872
commit 425c36eaff
4 changed files with 10 additions and 3 deletions


@@ -379,6 +379,9 @@ class DistributedTrainingConfig(FairseqDataclass):
     cpu_offload: bool = field(
         default=False, metadata={"help": "offload FP32 params to CPU"}
     )
+    use_sharded_state: bool = field(
+        default=False, metadata={"help": "use sharded checkpoint files"},
+    )
 
 
 @dataclass
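The new field follows the existing FairseqDataclass convention, so its underscored name surfaces as a dashed command-line flag. A small sketch of that mapping, assuming fairseq's gen_parser_from_dataclass helper (the snippet itself is not part of this commit):

# Sketch: how the dataclass field becomes a CLI flag (use_sharded_state -> --use-sharded-state).
import argparse

from fairseq.dataclass.configs import DistributedTrainingConfig
from fairseq.dataclass.utils import gen_parser_from_dataclass

parser = argparse.ArgumentParser()
gen_parser_from_dataclass(parser, DistributedTrainingConfig())

args = parser.parse_args(["--use-sharded-state"])
assert args.use_sharded_state is True  # bool field with default=False -> store_true flag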


@@ -77,7 +77,7 @@ class FullyShardedDataParallel(FSDP):
 
 
 @contextlib.contextmanager
-def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False):
+def fsdp_enable_wrap(cfg: DistributedTrainingConfig):
     try:
         from fairscale.nn import enable_wrap
     except ImportError:
@@ -105,7 +105,7 @@ def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False):
     }
     with enable_wrap(
         wrapper_cls=FullyShardedDataParallel,
-        use_sharded_state=use_sharded_state,
+        use_sharded_state=cfg.use_sharded_state,
         **fsdp_config,
     ):
         yield
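After this change the context manager reads use_sharded_state from cfg itself, so call sites only pass the config object. A rough sketch of a call site; the import path is assumed to be fairseq.distributed, and build_model is a placeholder argument rather than a real fairseq function:

# Rough sketch of a call site after this change; not part of the commit's diff.
from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap


def build_fsdp_model(cfg, build_model):
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        # flag is picked up from cfg.distributed_training.use_sharded_state
        with fsdp_enable_wrap(cfg.distributed_training):
            return fsdp_wrap(build_model())
    return build_model()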


@@ -496,7 +496,8 @@ class Trainer(object):
                 last_optim_state = self.optimizer.broadcast_global_state_dict(
                     last_optim_state
                 )
-            elif self.cfg.distributed_training.ddp_backend == 'fully_sharded':
+            elif self.cfg.distributed_training.ddp_backend == 'fully_sharded' and not self.model.use_sharded_state:
+                # if use_sharded_state, the last_optim_state is already sharded, skip this
                 last_optim_state = self.model.get_shard_from_optim_state_dict(last_optim_state)
 
             self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
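For context on why the extra sharding step is skipped: with a full checkpoint every rank has to carve its own shard out of the loaded optimizer state, while a sharded checkpoint was written per rank and already matches the local shard. A rough sketch, not the actual Trainer code:

# Rough sketch, not the actual Trainer code; model/optimizer follow the fairseq
# wrapper/optimizer interfaces shown in the diff above.
def load_fsdp_optim_state(model, optimizer, last_optim_state, optimizer_overrides=None):
    if not getattr(model, "use_sharded_state", False):
        # full checkpoint: extract this rank's shard before loading
        last_optim_state = model.get_shard_from_optim_state_dict(last_optim_state)
    # sharded checkpoint: the loaded state already matches this rank's shard
    optimizer.load_state_dict(last_optim_state, optimizer_overrides)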


@@ -59,6 +59,9 @@ class TestTranslationGPU(unittest.TestCase):
     def test_resume_training_fsdp(self):
         self._test_resume_training(["--ddp-backend", "fully_sharded"])
 
+    def test_resume_training_fsdp_sharded_state(self):
+        self._test_resume_training(["--ddp-backend", "fully_sharded", "--use-sharded-state"])
+
     def test_resume_training_noc10d(self):
         self._test_resume_training([])