diff --git a/fairseq/optim/bmuf.py b/fairseq/optim/bmuf.py
index c5a588803..2f824b19a 100644
--- a/fairseq/optim/bmuf.py
+++ b/fairseq/optim/bmuf.py
@@ -89,6 +89,7 @@ class FairseqBMUF(FairseqOptimizer):
 
     def load_state_dict(self, state_dict, optimizer_overrides=None):
         self._optimizer.load_state_dict(state_dict, optimizer_overrides)
+        self.initial_state = self._optimizer.state_dict()
 
     def multiply_grads(self, c):
         """Multiplies grads by a constant *c*."""