Load an XLM model into the transformer encoder / decoder for MT training (#629)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/629

Use GeLU as an alternative activation function to ReLU.
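
For orientation, training with the new architecture and task can be driven the same way the regression test added in this commit does. The snippet below is only a sketch: the data directory, checkpoint path, and save directory are placeholders, and the remaining hyperparameters are illustrative.

from fairseq import options
import train  # the repo's top-level train.py (assumed importable, as in the tests below)

parser = options.get_training_parser()
args = options.parse_args_and_arch(
    parser,
    [
        "--task", "translation_from_pretrained_xlm",
        "data-bin/example",  # placeholder: data binarized with a MaskedLMDictionary
        "--arch", "transformer_from_pretrained_xlm",
        "--pretrained-xlm-checkpoint", "checkpoints/xlm/checkpoint_last.pt",  # placeholder path
        "--activation-fn", "gelu",  # GeLU instead of the default ReLU
        "--encoder-learned-pos",
        "--decoder-learned-pos",
        "--optimizer", "adam",
        "--lr", "0.0001",
        "--max-tokens", "4000",
        "--save-dir", "checkpoints/mt",  # placeholder path
    ],
)
train.main(args)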

Reviewed By: lematt1991

Differential Revision: D14689851

fbshipit-source-id: 7ec81fa34bc7bd0e1e43b337847ae932dcbf8b15
Authored by Liezl Puzon on 2019-04-25 05:52:36 -07:00; committed by Facebook Github Bot
parent 8500bdd0c8
commit 8da9b1c530
3 changed files with 292 additions and 2 deletions

@@ -0,0 +1,137 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
from typing import Any, Dict
from fairseq import utils
from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
from fairseq.models.transformer import (
TransformerDecoder,
TransformerEncoder,
TransformerModel,
base_architecture as transformer_base_architecture,
)
from . import register_model, register_model_architecture
@register_model("transformer_from_pretrained_xlm")
class TransformerFromPretrainedXLMModel(TransformerModel):
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
TransformerModel.add_args(parser)
parser.add_argument(
"--pretrained-xlm-checkpoint",
type=str,
metavar="STR",
help="XLM model to use for initializing transformer encoder and/or decoder",
)
@classmethod
def build_model(cls, args, task):
assert hasattr(args, "pretrained_xlm_checkpoint"), (
"You must specify a path for --pretrained-xlm-checkpoint to use "
"--arch transformer_from_pretrained_xlm"
)
assert isinstance(task.source_dictionary, MaskedLMDictionary) and isinstance(
task.target_dictionary, MaskedLMDictionary
), (
"You should use a MaskedLMDictionary when using --arch "
"transformer_from_pretrained_xlm because the pretrained XLM model "
"was trained using data binarized with MaskedLMDictionary. "
"For translation, you may want to use --task "
"translation_from_pretrained_xlm"
)
return super().build_model(args, task)
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return TransformerEncoderFromPretrainedXLM(args, src_dict, embed_tokens)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TransformerDecoderFromPretrainedXLM(args, tgt_dict, embed_tokens)
def upgrade_state_dict_with_xlm_weights(
state_dict: Dict[str, Any], pretrained_xlm_checkpoint: str
) -> Dict[str, Any]:
"""
Load XLM weights into a Transformer encoder or decoder model.
Args:
state_dict: state dict for either TransformerEncoder or
TransformerDecoder
pretrained_xlm_checkpoint: checkpoint to load XLM weights from
Raises:
AssertionError: If architecture (num layers, attention heads, etc.)
does not match between the current Transformer encoder or
decoder and the pretrained_xlm_checkpoint
"""
if not os.path.exists(pretrained_xlm_checkpoint):
raise IOError(f"Model file not found: {pretrained_xlm_checkpoint}")
state = utils.load_checkpoint_to_cpu(pretrained_xlm_checkpoint)
xlm_state_dict = state["model"]
for key in xlm_state_dict.keys():
for search_key in ["embed_tokens", "embed_positions", "layers"]:
if search_key in key:
subkey = key[key.find(search_key):]
assert subkey in state_dict, (
f"{str(state_dict.keys())} Transformer encoder / decoder "
f"state_dict does not contain {subkey}. Cannot "
f"load {key} from pretrained XLM checkpoint "
f"{pretrained_xlm_checkpoint} into Transformer."
)
state_dict[subkey] = xlm_state_dict[key]
return state_dict
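
# Illustrative sketch (not part of this diff): how the suffix matching in
# upgrade_state_dict_with_xlm_weights maps XLM checkpoint keys onto a freshly
# initialized encoder/decoder state dict. The key names and values below are
# invented placeholders, not real fairseq parameter names.
def _toy_upgrade_example():
    xlm_state_dict = {
        "sentence_encoder.embed_tokens.weight": "xlm_embedding",  # hypothetical key
        "sentence_encoder.layers.0.fc1.weight": "xlm_fc1",        # hypothetical key
    }
    state_dict = {
        "embed_tokens.weight": "random_init",
        "layers.0.fc1.weight": "random_init",
    }
    for key in xlm_state_dict.keys():
        for search_key in ["embed_tokens", "embed_positions", "layers"]:
            if search_key in key:
                # strip the XLM-specific prefix and copy the weight over
                subkey = key[key.find(search_key):]
                assert subkey in state_dict  # fails if the architectures differ
                state_dict[subkey] = xlm_state_dict[key]
    # state_dict is now {"embed_tokens.weight": "xlm_embedding",
    #                    "layers.0.fc1.weight": "xlm_fc1"}
    return state_dict
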
class TransformerEncoderFromPretrainedXLM(TransformerEncoder):
def __init__(self, args, dictionary, embed_tokens):
super().__init__(args, dictionary, embed_tokens)
assert hasattr(args, "pretrained_xlm_checkpoint"), (
"--pretrained-xlm-checkpoint must be specified to load Transformer "
"encoder from pretrained XLM"
)
xlm_loaded_state_dict = upgrade_state_dict_with_xlm_weights(
state_dict=self.state_dict(),
pretrained_xlm_checkpoint=args.pretrained_xlm_checkpoint,
)
self.load_state_dict(xlm_loaded_state_dict, strict=True)
class TransformerDecoderFromPretrainedXLM(TransformerDecoder):
def __init__(
self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True
):
super().__init__(
args, dictionary, embed_tokens, no_encoder_attn, final_norm
)
assert hasattr(args, "pretrained_xlm_checkpoint"), (
"--pretrained-xlm-checkpoint must be specified to load Transformer "
"decoder from pretrained XLM"
)
xlm_loaded_state_dict = upgrade_state_dict_with_xlm_weights(
state_dict=self.state_dict(),
pretrained_xlm_checkpoint=args.pretrained_xlm_checkpoint,
)
self.load_state_dict(xlm_loaded_state_dict, strict=True)
@register_model_architecture(
"transformer_from_pretrained_xlm", "transformer_from_pretrained_xlm"
)
def base_architecture(args):
transformer_base_architecture(args)

@@ -0,0 +1,33 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
from fairseq.tasks.translation import TranslationTask
from . import register_task
@register_task("translation_from_pretrained_xlm")
class TranslationFromPretrainedXLMTask(TranslationTask):
"""
Same as TranslationTask except use the MaskedLMDictionary class so that
we can load data that was binarized with the MaskedLMDictionary class.
This task should be used for the entire training pipeline when we want to
train an NMT model from a pretrained XLM checkpoint: binarizing NMT data,
training NMT with the pretrained XLM checkpoint, and subsequent evaluation
of that trained model.
"""
@classmethod
def load_dictionary(cls, filename):
"""Load the masked LM dictionary from the filename
Args:
filename (str): the filename
"""
return MaskedLMDictionary.load(filename)
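
# Illustrative usage (not part of this diff); the dictionary path below is a
# placeholder. Loading through this task yields a MaskedLMDictionary, so token
# indices line up with the ones the pretrained XLM checkpoint was trained on.
def _example_load_dictionary():
    d = TranslationFromPretrainedXLMTask.load_dictionary("data-bin/example/dict.in.txt")
    assert isinstance(d, MaskedLMDictionary)
    return d
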

@@ -220,6 +220,126 @@ class TestLanguageModeling(unittest.TestCase):
eval_lm_main(data_dir)
class TestMaskedLanguageModel(unittest.TestCase):
def test_masked_lm(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory("test_mlm") as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_masked_language_model(data_dir, "xlm_base")
def test_pretrained_masked_lm_for_translation(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory("test_mlm") as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_masked_language_model(data_dir, arch="xlm_base")
with tempfile.TemporaryDirectory(
"test_mlm_translation"
) as translation_dir:
create_dummy_data(translation_dir)
preprocess_translation_data(
translation_dir, extra_flags=["--joined-dictionary"]
)
# Train transformer with data_dir/checkpoint_last.pt
train_translation_model(
translation_dir,
arch="transformer_from_pretrained_xlm",
extra_flags=[
"--decoder-layers",
"1",
"--decoder-embed-dim",
"32",
"--decoder-attention-heads",
"1",
"--decoder-ffn-embed-dim",
"32",
"--encoder-layers",
"1",
"--encoder-embed-dim",
"32",
"--encoder-attention-heads",
"1",
"--encoder-ffn-embed-dim",
"32",
"--pretrained-xlm-checkpoint",
f"{data_dir}/checkpoint_last.pt",
"--encoder-learned-pos",
"--decoder-learned-pos",
"--activation-fn",
"gelu",
"--max-source-positions",
"500",
"--max-target-positions",
"500",
],
task="translation_from_pretrained_xlm",
)
def train_masked_language_model(data_dir, arch):
train_parser = options.get_training_parser()
# TODO: langs should be in and out right?
train_args = options.parse_args_and_arch(
train_parser,
[
"--task",
"cross_lingual_lm",
data_dir,
"--arch",
arch,
# Optimizer args
"--optimizer",
"adam",
"--lr-scheduler",
"reduce_lr_on_plateau",
"--lr-shrink",
"0.5",
"--lr",
"0.0001",
"--min-lr",
"1e-09",
# dropout, attention args
"--dropout",
"0.1",
"--no-bias-kv",
"--attention-dropout",
"0.1",
# MLM args
"--criterion",
"masked_lm_loss",
"--masked-lm-only",
"--monolingual-langs",
"in,out",
"--num-segment",
"5",
# Transformer args: use a small transformer model for fast training
"--encoder-layers",
"1",
"--encoder-embed-dim",
"32",
"--encoder-attention-heads",
"1",
"--encoder-ffn-embed-dim",
"32",
# Other training args
"--max-tokens",
"500",
"--tokens-per-sample",
"500",
"--save-dir",
data_dir,
"--max-epoch",
"1",
"--no-progress-bar",
"--distributed-world-size",
"1",
"--raw-text",
],
)
train.main(train_args)
class TestCommonOptions(unittest.TestCase):
def test_optimizers(self):
@@ -281,12 +401,12 @@ def preprocess_translation_data(data_dir, extra_flags=None):
preprocess.main(preprocess_args)
-def train_translation_model(data_dir, arch, extra_flags=None):
+def train_translation_model(data_dir, arch, extra_flags=None, task='translation'):
train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch(
train_parser,
[
-        '--task', 'translation',
+        '--task', task,
data_dir,
'--save-dir', data_dir,
'--arch', arch,