Make error message for trying to train after make_generation_fast work correctly

Summary: https://github.com/pytorch/fairseq/blob/master/fairseq/trainer.py#L164 calls `train()` without any argument

Reviewed By: myleott

Differential Revision: D13599203

fbshipit-source-id: 3a096a6dd35a7a3f8309fbda3b54a36f606475e3
This commit is contained in:
Wei Ho 2019-01-09 16:01:17 -08:00 committed by Facebook Github Bot
parent 4b1f4788d8
commit 315fa5cbd9

View File

@ -117,7 +117,7 @@ class BaseFairseqModel(nn.Module):
self.apply(apply_make_generation_fast_)
def train(mode):
def train(mode=True):
if mode:
raise RuntimeError('cannot train after make_generation_fast')