fix bmuf restart causing lr to be reset

Summary:
From https://fb.workplace.com/groups/332923290658088/permalink/505568510060231/, we have observed that learning rate on resume is reset to original value. Based on investigation in https://fb.workplace.com/groups/332923290658088/permalink/505568510060231/?comment_id=506348519982230, it seems like this happens 500 iterations after restarting, which coincides with BMUF warmup. After further debugging, ti seems like this is confirmed to be due to bmuf and because after warmup we reset the optimizer to the initial state  when created.

My proposed fix is to reset the initial state of the optimizer whenever we load the state dict.

Reviewed By: zhengwy888

Differential Revision: D19183595

fbshipit-source-id: 4cdc13378817a7e9a6b658010b152a508991971f
This commit is contained in:
Alex Xiao 2019-12-20 15:25:51 -08:00 committed by Facebook Github Bot
parent 58ec8c0f28
commit 9ad6b5a967

View File

@ -89,6 +89,7 @@ class FairseqBMUF(FairseqOptimizer):
def load_state_dict(self, state_dict, optimizer_overrides=None):
self._optimizer.load_state_dict(state_dict, optimizer_overrides)
self.initial_state = self._optimizer.state_dict()
def multiply_grads(self, c):
"""Multiplies grads by a constant *c*."""