Mirror of https://github.com/facebookresearch/fairseq.git (synced 2024-10-26 17:32:57 +03:00)
support use_sharded_state on command line
Summary: We wanted to use sharded_state because (1) it saves memory and (2) it supports sharded state loading, which lets an MoE model's weights live on their respective shards. This diff adds use_sharded_state as a config option and adds a unit test to make sure it runs fine.

Old revision's comment: fairseq's FSDP wrapper has a use_sharded_state flag, but a couple of problems had to be addressed before it could be used:
1. fairscale FSDP (FSDP for short) calls self.state_dict/load_state_dict internally, and fairseq's FSDP wrapper overrides those methods; this is not the desired behavior.
2. The optimizer state should not be sharded again when use_sharded_state is True.
3. The option needs to be exposed on the command line.

Reviewed By: sshleifer

Differential Revision: D28375035

fbshipit-source-id: c2f59a9c62163405033f34ed595ba78528aea850
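As a toy illustration of problem 1 (these are made-up classes, not the actual fairseq/fairscale ones): when a subclass overrides state_dict, the parent class's internal calls to self.state_dict also hit the override, so the parent's own bookkeeping goes through the subclass's version.

class Parent:
    def internal_snapshot(self):
        # The parent's own code calls self.state_dict() for internal bookkeeping.
        return self.state_dict()

    def state_dict(self):
        return {"weight": [1.0, 2.0]}


class Child(Parent):
    def state_dict(self):
        # The subclass override also intercepts the parent's internal call.
        sd = super().state_dict()
        sd["extra"] = "added by the override"
        return sd


assert "extra" in Child().internal_snapshot()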
This commit is contained in:
parent: d151f27872
commit: 425c36eaff
@@ -379,6 +379,9 @@ class DistributedTrainingConfig(FairseqDataclass):
     cpu_offload: bool = field(
         default=False, metadata={"help": "offload FP32 params to CPU"}
     )
+    use_sharded_state: bool = field(
+        default=False, metadata={"help": "use sharded checkpoint files"},
+    )


 @dataclass
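For context, here is a minimal, self-contained sketch of how a boolean dataclass field like this typically surfaces as a --use-sharded-state flag. It uses plain argparse rather than fairseq's actual dataclass-to-parser machinery, and DistributedTrainingStub and add_bool_flags are hypothetical names for this example only.

import argparse
from dataclasses import dataclass, field, fields


@dataclass
class DistributedTrainingStub:
    # Mirrors the new config entry: a boolean with its help text in metadata.
    use_sharded_state: bool = field(
        default=False, metadata={"help": "use sharded checkpoint files"}
    )


def add_bool_flags(parser, cls):
    # Each boolean field becomes a --kebab-case store_true flag.
    for f in fields(cls):
        if f.type in (bool, "bool"):
            parser.add_argument(
                "--" + f.name.replace("_", "-"),
                action="store_true",
                default=f.default,
                help=f.metadata.get("help", ""),
            )


parser = argparse.ArgumentParser()
add_bool_flags(parser, DistributedTrainingStub)
args = parser.parse_args(["--use-sharded-state"])
assert args.use_sharded_state is True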
@@ -77,7 +77,7 @@ class FullyShardedDataParallel(FSDP):


 @contextlib.contextmanager
-def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False):
+def fsdp_enable_wrap(cfg: DistributedTrainingConfig):
     try:
         from fairscale.nn import enable_wrap
     except ImportError:
@@ -105,7 +105,7 @@ def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = F
     }
     with enable_wrap(
         wrapper_cls=FullyShardedDataParallel,
-        use_sharded_state=use_sharded_state,
+        use_sharded_state=cfg.use_sharded_state,
         **fsdp_config,
     ):
         yield
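The signature change means callers only pass cfg, and the flag travels through the config object into enable_wrap. A minimal sketch of that pattern, using stand-in names (Cfg, record_wrap, fsdp_enable_wrap_sketch) rather than the real fairscale API:

import contextlib
from dataclasses import dataclass


@dataclass
class Cfg:
    # Stand-in for DistributedTrainingConfig; only the field used here.
    use_sharded_state: bool = False


@contextlib.contextmanager
def record_wrap(**wrapper_kwargs):
    # Stand-in for fairscale.nn.enable_wrap: just exposes the kwargs it receives.
    yield wrapper_kwargs


@contextlib.contextmanager
def fsdp_enable_wrap_sketch(cfg: Cfg):
    # After this change the flag is read from cfg instead of a separate argument.
    with record_wrap(use_sharded_state=cfg.use_sharded_state) as kwargs:
        yield kwargs


with fsdp_enable_wrap_sketch(Cfg(use_sharded_state=True)) as kw:
    assert kw["use_sharded_state"] is True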
@@ -496,7 +496,8 @@ class Trainer(object):
                last_optim_state = self.optimizer.broadcast_global_state_dict(
                    last_optim_state
                )
-            elif self.cfg.distributed_training.ddp_backend == 'fully_sharded':
+            elif self.cfg.distributed_training.ddp_backend == 'fully_sharded' and not self.model.use_sharded_state:
+                # if use_sharded_state, the last_optim_state is already sharded, skip this
                last_optim_state = self.model.get_shard_from_optim_state_dict(last_optim_state)

            self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
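The subtlety here is in the new condition: with fully_sharded training, a consolidated optimizer state loaded from a checkpoint must still be cut down to this rank's shard, but a checkpoint written with use_sharded_state already holds per-rank shards and must not be sharded again. A hypothetical helper (the name is made up for illustration) condensing that decision:

def needs_optimizer_resharding(ddp_backend, use_sharded_state):
    # True when the loaded optimizer state must still be reduced to this
    # rank's shard; a sharded-state checkpoint is already per-rank.
    return ddp_backend == "fully_sharded" and not use_sharded_state


assert needs_optimizer_resharding("fully_sharded", use_sharded_state=False)
assert not needs_optimizer_resharding("fully_sharded", use_sharded_state=True)
assert not needs_optimizer_resharding("no_c10d", use_sharded_state=False)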
@@ -59,6 +59,9 @@ class TestTranslationGPU(unittest.TestCase):
     def test_resume_training_fsdp(self):
         self._test_resume_training(["--ddp-backend", "fully_sharded"])

+    def test_resume_training_fsdp_sharded_state(self):
+        self._test_resume_training(["--ddp-backend", "fully_sharded", "--use-sharded-state"])
+
     def test_resume_training_noc10d(self):
         self._test_resume_training([])
