XLM for NMT: option to only load encoder or decoder (#666)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/666

Adds an option to load the pretrained XLM weights into only the encoder or only the decoder.

Reviewed By: pipibjc

Differential Revision: D14881004

fbshipit-source-id: 6d0d598ea9c445ec468f71b8e855712de89a5dac
Author: Liezl Puzon, 2019-04-25 05:52:36 -07:00, committed by Facebook Github Bot
Parent: 8da9b1c530
Commit: 5008fd4e5a
2 changed files with 69 additions and 1 deletion
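
At a glance, the diff gates the pretrained-weight loading on two new, mutually exclusive flags: --init-encoder-only and --init-decoder-only. Below is a simplified, standalone sketch of that gating; should_init_from_xlm is an illustrative helper for this note, not a fairseq function.

    from argparse import Namespace

    def should_init_from_xlm(args, side):
        """Return True if `side` ("encoder" or "decoder") should be initialized
        from the pretrained XLM checkpoint, given the two new flags."""
        if side == "encoder" and getattr(args, "init_decoder_only", False):
            return False
        if side == "decoder" and getattr(args, "init_encoder_only", False):
            return False
        return True

    args = Namespace(init_encoder_only=True, init_decoder_only=False)
    print(should_init_from_xlm(args, "encoder"))  # True  -> encoder gets XLM weights
    print(should_init_from_xlm(args, "decoder"))  # False -> decoder stays randomly initialized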


@@ -32,6 +32,16 @@ class TransformerFromPretrainedXLMModel(TransformerModel):
            metavar="STR",
            help="XLM model to use for initializing transformer encoder and/or decoder",
        )
        parser.add_argument(
            "--init-encoder-only",
            action="store_true",
            help="if set, don't load the XLM weights and embeddings into decoder",
        )
        parser.add_argument(
            "--init-decoder-only",
            action="store_true",
            help="if set, don't load the XLM weights and embeddings into encoder",
        )

    @classmethod
    def build_model(cls, args, task):
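
With the new arguments registered, a typical configuration passes the existing checkpoint flag plus at most one of the new toggles. A sketch in the extra_flags list style used by the test further down; the checkpoint path here is a placeholder.

    extra_flags = [
        "--pretrained-xlm-checkpoint",
        "/path/to/xlm/checkpoint_last.pt",  # placeholder path
        "--init-encoder-only",              # or "--init-decoder-only", never both
    ]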
@@ -48,7 +58,10 @@ class TransformerFromPretrainedXLMModel(TransformerModel):
            "For translation, you may want to use --task "
            "translation_from_pretrained_xlm"
        )
        assert not (
            getattr(args, "init_encoder_only", False)
            and getattr(args, "init_decoder_only", False)
        ), "Only one of --init-encoder-only and --init-decoder-only can be set."
        return super().build_model(args, task)

    @classmethod
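
The assertion above can be exercised in isolation; a small sketch using argparse.Namespace in place of fairseq's parsed args shows that setting both flags is rejected, since then neither side would be initialized from the XLM checkpoint.

    from argparse import Namespace

    args = Namespace(init_encoder_only=True, init_decoder_only=True)
    try:
        assert not (
            getattr(args, "init_encoder_only", False)
            and getattr(args, "init_decoder_only", False)
        ), "Only one of --init-encoder-only and --init-decoder-only can be set."
    except AssertionError as err:
        print(err)  # both flags set -> rejected with the message above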
@@ -100,6 +113,10 @@ def upgrade_state_dict_with_xlm_weights(
class TransformerEncoderFromPretrainedXLM(TransformerEncoder):
    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)
        if getattr(args, 'init_decoder_only', False):
            # Don't load XLM weights for encoder if --init-decoder-only
            return

        assert hasattr(args, "pretrained_xlm_checkpoint"), (
            "--pretrained-xlm-checkpoint must be specified to load Transformer "
            "encoder from pretrained XLM"
@@ -118,6 +135,9 @@ class TransformerDecoderFromPretrainedXLM(TransformerDecoder):
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn, final_norm
        )
        if getattr(args, 'init_encoder_only', False):
            # Don't load XLM weights for decoder if --init-encoder-only
            return

        assert hasattr(args, "pretrained_xlm_checkpoint"), (
            "--pretrained-xlm-checkpoint must be specified to load Transformer "
            "decoder from pretrained XLM"


@@ -276,6 +276,54 @@ class TestMaskedLanguageModel(unittest.TestCase):
                        task="translation_from_pretrained_xlm",
                    )

    def test_pretrained_masked_lm_for_translation_encoder_only(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory("test_mlm") as data_dir:
                create_dummy_data(data_dir)
                preprocess_lm_data(data_dir)
                train_masked_language_model(data_dir, arch="xlm_base")
                with tempfile.TemporaryDirectory(
                    "test_mlm_translation"
                ) as translation_dir:
                    create_dummy_data(translation_dir)
                    preprocess_translation_data(
                        translation_dir, extra_flags=["--joined-dictionary"]
                    )
                    # Train transformer with data_dir/checkpoint_last.pt
                    train_translation_model(
                        translation_dir,
                        arch="transformer_from_pretrained_xlm",
                        extra_flags=[
                            "--decoder-layers",
                            "1",
                            "--decoder-embed-dim",
                            "32",
                            "--decoder-attention-heads",
                            "1",
                            "--decoder-ffn-embed-dim",
                            "32",
                            "--encoder-layers",
                            "1",
                            "--encoder-embed-dim",
                            "32",
                            "--encoder-attention-heads",
                            "1",
                            "--encoder-ffn-embed-dim",
                            "32",
                            "--pretrained-xlm-checkpoint",
                            f"{data_dir}/checkpoint_last.pt",
                            "--encoder-learned-pos",
                            "--decoder-learned-pos",
                            "--activation-fn",
                            "gelu",
                            "--max-source-positions",
                            "500",
                            "--max-target-positions",
                            "500",
                            "--init-encoder-only",
                        ],
                        task="translation_from_pretrained_xlm",
                    )


def train_masked_language_model(data_dir, arch):
    train_parser = options.get_training_parser()
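
To run just the new test case, one option is unittest's programmatic loader; the module path tests.test_binaries is an assumption about where this test class lives.

    import unittest

    # Load and run only the encoder-only translation test added in this commit.
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "tests.test_binaries.TestMaskedLanguageModel"
        ".test_pretrained_masked_lm_for_translation_encoder_only"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)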