Fix NAT code (#1454)

Summary:
D23752010 (add65adcc5) broke some GPU-only tests for NAT: it switched the encoder output from a NamedTuple (attribute access, `_replace`) to a dictionary of tensor lists, but the NAT ensemble wrappers and the Levenshtein generation code still used the old attribute-style access. This updates them to the dictionary format and extends the GPU test to also cover ensemble generation.

Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/1454

Test Plan: Imported from OSS

Reviewed By: jmp84

Differential Revision: D25108461

Pulled By: myleott

fbshipit-source-id: f32b890221578c421944d6f9a49f06ef1dc075c6
Myle Ott 2020-11-20 12:40:49 -08:00 committed by Facebook GitHub Bot
parent fa113ff1de
commit d464af2feb
5 changed files with 51 additions and 31 deletions
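
The breakage is a format mismatch: the old NAT helpers read encoder outputs as NamedTuple attributes (`encoder_out.encoder_out`, `._replace(...)`), while the new encoder returns a dictionary whose values are lists of tensors, which is why the fixed code below indexes with keys and `[0]`. A minimal sketch of the new shape, with key names taken from the diffs and made-up tensor sizes purely for orientation:

import torch

# Illustrative only -- not fairseq code. Key names match the diffs below;
# sizes (src_len=7, batch=2, dim=16) are invented for the example.
T, B, C = 7, 2, 16
encoder_out = {
    "encoder_out": [torch.randn(T, B, C)],       # list holding one T x B x C tensor
    "encoder_padding_mask": [],                  # empty list stands in for "no padding"
    "encoder_embedding": [torch.randn(B, T, C)],
    "encoder_states": [],                        # per-layer states, filled when requested
}

# Attribute access (old style) fails on a dict; the fixed code reads values like this:
hidden = encoder_out["encoder_out"][0]           # T x B x C tensor for this model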

View File

@@ -18,18 +18,23 @@ def ensemble_encoder(func):
     def wrapper(self, *args, **kwargs):
         if self.ensemble_models is None or len(self.ensemble_models) == 1:
             return func(self, *args, **kwargs)
-        encoder_outs = [func(model, *args, **kwargs) for model in self.ensemble_models]
-        _encoder_out = encoder_outs[0]
+        encoder_outs = [func(model, *args, **kwargs, return_all_hiddens=True) for model in self.ensemble_models]
+        _encoder_out = encoder_outs[0].copy()
 
         def stack(key):
-            outs = [getattr(e, key) for e in encoder_outs]
-            return torch.stack(outs, -1) if outs[0] is not None else None
+            outs = [e[key][0] for e in encoder_outs]
+            return [torch.stack(outs, -1) if outs[0] is not None else None]
 
-        return _encoder_out._replace(
-            encoder_out=stack("encoder_out"),
-            encoder_embedding=stack("encoder_embedding"),
-            encoder_states=stack("encoder_states"),
-        )
+        _encoder_out["encoder_out"] = stack("encoder_out")
+        _encoder_out["encoder_embedding"] = stack("encoder_embedding")
+
+        num_layers = len(_encoder_out["encoder_states"])
+        if num_layers > 0:
+            _encoder_out["encoder_states"] = [
+                torch.stack([e["encoder_states"][i] for e in encoder_outs], -1)
+                for i in range(num_layers)
+            ]
+        return _encoder_out
 
     return wrapper
@@ -41,12 +46,18 @@ def ensemble_decoder(func):
                 self, normalize=normalize, encoder_out=encoder_out, *args, **kwargs
             )
 
+        def _replace(encoder_out, new_val):
+            new_encoder_out = encoder_out.copy()
+            new_encoder_out["encoder_out"] = [new_val]
+            return new_encoder_out
+
         action_outs = [
             func(
                 model,
                 normalize=normalize,
-                encoder_out=encoder_out._replace(
-                    encoder_out=encoder_out.encoder_out[:, :, :, i]
+                encoder_out=_replace(
+                    encoder_out,
+                    encoder_out["encoder_out"][0][:, :, :, i]
                 ),
                 *args,
                 **kwargs
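
How the two decorators fit together after the change (a rough sketch of the mechanism, not the fairseq implementation): `ensemble_encoder` stacks each model's `encoder_out` tensor along a new trailing dimension, and `ensemble_decoder`'s `_replace` helper hands model `i` its own slice back via `[:, :, :, i]`. A toy round trip with invented shapes:

import torch

T, B, C, N = 5, 2, 8, 3          # src_len, batch, dim, ensemble size -- invented
per_model = [torch.randn(T, B, C) for _ in range(N)]

stacked = torch.stack(per_model, -1)             # (T, B, C, N), as in stack("encoder_out")
for i in range(N):
    # the slice handed to model i equals its original encoder output
    assert torch.equal(stacked[:, :, :, i], per_model[i])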

View File

@@ -149,11 +149,11 @@ class LevenshteinTransformerModel(FairseqNATModel):
         if max_ratio is None:
             max_lens = torch.zeros_like(output_tokens).fill_(255)
         else:
-            if encoder_out.encoder_padding_mask is None:
-                max_src_len = encoder_out.encoder_out.size(0)
-                src_lens = encoder_out.encoder_out.new(bsz).fill_(max_src_len)
+            if not encoder_out["encoder_padding_mask"]:
+                max_src_len = encoder_out["encoder_out"].size(0)
+                src_lens = encoder_out["encoder_out"].new(bsz).fill_(max_src_len)
             else:
-                src_lens = (~encoder_out.encoder_padding_mask).sum(1)
+                src_lens = (~encoder_out["encoder_padding_mask"][0]).sum(1)
             max_lens = (src_lens * max_ratio).clamp(min=10).long()
 
         # delete words
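
The maximum-length logic above is unchanged in spirit: source lengths come from the (now list-wrapped) padding mask when one exists, and the permitted output length is `src_lens * max_ratio`, floored at 10. A small numeric sketch with invented values (batch of 2, source lengths 4 and 12, `max_ratio=2.0`):

import torch

# True marks padding positions; row 0 has 4 real tokens, row 1 has 12.
encoder_padding_mask = torch.tensor([
    [False] * 4 + [True] * 8,
    [False] * 12,
])
src_lens = (~encoder_padding_mask).sum(1)             # tensor([ 4, 12])
max_lens = (src_lens * 2.0).clamp(min=10).long()      # tensor([10, 24])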

View File

@@ -83,14 +83,13 @@ class EnsembleLevT(BasicEnsembleModel):
         if max_ratio is None:
             max_lens = output_tokens.new().fill_(255)
         else:
-            if encoder_outs[0].encoder_padding_mask is None:
+            if not encoder_outs[0]["encoder_padding_mask"]:
                 src_lens = (
-                    encoder_outs[0]
-                    .encoder_out.new(bsz)
-                    .fill_(encoder_outs[0].encoder_out.size(1))
+                    encoder_outs[0]["encoder_out"][0].new(bsz)
+                    .fill_(encoder_outs[0]["encoder_out"][0].size(1))
                 )
             else:
-                src_lens = (~encoder_outs[0].encoder_padding_mask).sum(1)
+                src_lens = (~encoder_outs[0]["encoder_padding_mask"][0]).sum(1)
             max_lens = (src_lens * max_ratio).clamp(min=10).long()
 
         # delete words
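
One detail shared by this hunk and the previous one: in the dictionary format the padding mask is stored as a (possibly empty) list rather than an optional tensor, so the old `... is None` test becomes a truthiness test. Roughly:

# Sketch of the two conventions for "no padding mask" (not fairseq code):
old_style = None    # NamedTuple field was simply None
new_style = []      # dict value is a list; an empty list means "absent"
assert (old_style is None) and (not new_style)   # both take the "no mask" branch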

View File

@@ -93,17 +93,25 @@ class TestTranslationGPU(unittest.TestCase):
                     ],
                     task="translation_lev",
                 )
+                gen_config = [
+                    "--task",
+                    "translation_lev",
+                    "--iter-decode-max-iter",
+                    "9",
+                    "--iter-decode-eos-penalty",
+                    "0",
+                    "--print-step",
+                ]
+                # non-ensemble generation
+                generate_main(data_dir, gen_config)
+                # ensemble generation
                 generate_main(
                     data_dir,
-                    [
-                        "--task",
-                        "translation_lev",
-                        "--iter-decode-max-iter",
-                        "9",
-                        "--iter-decode-eos-penalty",
-                        "0",
-                        "--print-step",
-                    ],
+                    gen_config,
+                    path=os.pathsep.join([
+                        os.path.join(data_dir, "checkpoint_last.pt"),
+                        os.path.join(data_dir, "checkpoint_last.pt"),
+                    ]),
                 )
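
The ensemble call above reuses the same checkpoint twice, joined with `os.pathsep`; fairseq's `--path` flag accepts such a separator-joined list of checkpoints and decodes with them as an ensemble, so a duplicated checkpoint is enough to exercise the ensemble code paths fixed earlier in this commit. What the constructed argument looks like (illustrative directory name, POSIX `os.pathsep` of ':'):

import os

data_dir = "/tmp/levenshtein_test"   # invented path for the example
path = os.pathsep.join([
    os.path.join(data_dir, "checkpoint_last.pt"),
    os.path.join(data_dir, "checkpoint_last.pt"),
])
# On Linux/macOS this yields:
# "/tmp/levenshtein_test/checkpoint_last.pt:/tmp/levenshtein_test/checkpoint_last.pt"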

View File

@@ -345,18 +345,20 @@ def train_translation_model(
     validate.main(validate_args)
 
 
-def generate_main(data_dir, extra_flags=None):
+def generate_main(data_dir, extra_flags=None, path=None):
     if extra_flags is None:
         extra_flags = [
             "--print-alignment",
         ]
+    if path is None:
+        path = os.path.join(data_dir, "checkpoint_last.pt")
 
     generate_parser = options.get_generation_parser()
     generate_args = options.parse_args_and_arch(
         generate_parser,
         [
             data_dir,
             "--path",
-            os.path.join(data_dir, "checkpoint_last.pt"),
+            path,
             "--beam",
             "3",
             "--batch-size",