From e23e5eaa321e2aa18330f1d13b70a9851e125901 Mon Sep 17 00:00:00 2001 From: ngoyal2707 Date: Tue, 5 Nov 2019 15:01:26 -0800 Subject: [PATCH] XLM-R code and model release (#900) Summary: TODO: 1) Need to update bibtex entry 2) Need to upload models, spm_vocab and dict.txt to public s3 location. For Future: 1) I will probably add instructions to finetune on XNLI and NER, POS etc. but currently no timeline for that. Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/900 Reviewed By: myleott Differential Revision: D18333076 Pulled By: myleott fbshipit-source-id: 3f3d3716fcc41c78d2dd4525f60b519abbd0459c --- README.md | 1 + examples/roberta/README.md | 1 + examples/xlmr/README.md | 77 +++++++++++++++++++++++++++++++++ fairseq/models/roberta/model.py | 24 ++++++++++ 4 files changed, 103 insertions(+) create mode 100644 examples/xlmr/README.md diff --git a/README.md b/README.md index 2d6705627..92d7edfc4 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ modeling and other text generation tasks. ### What's New: +- November 2019: [XLM-R models and code released](examples/xlmr/README.md) - September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md) - August 2019: [WMT'19 models released](examples/wmt19/README.md) - July 2019: fairseq relicensed under MIT license diff --git a/examples/roberta/README.md b/examples/roberta/README.md index 1661b604f..15844d3e4 100644 --- a/examples/roberta/README.md +++ b/examples/roberta/README.md @@ -8,6 +8,7 @@ RoBERTa iterates on BERT's pretraining procedure, including training the model l ### What's New: +- November 2019: Multilingual encoder (XLM-RoBERTa) is available [XLM-R](https://github.com/pytorch/fairseq/examples/xlmr). - September 2019: TensorFlow and TPU support via the [transformers library](https://github.com/huggingface/transformers). - August 2019: RoBERTa is now supported in the [pytorch-transformers library](https://github.com/huggingface/pytorch-transformers). - August 2019: Added [tutorial for finetuning on WinoGrande](https://github.com/pytorch/fairseq/tree/master/examples/roberta/wsc#roberta-training-on-winogrande-dataset). diff --git a/examples/xlmr/README.md b/examples/xlmr/README.md new file mode 100644 index 000000000..6aaeb5072 --- /dev/null +++ b/examples/xlmr/README.md @@ -0,0 +1,77 @@ +# Unsupervised Cross-lingual Representation Learning at Scale (XLM-RoBERTa) + +## Introduction + +XLM-R (XLM-RoBERTa) is scaled cross lingual sentence encoder. It is trained on `2.5T` of data across `100` languages data filtered from Common Crawl. XLM-R achieves state-of-the-arts results on multiple cross lingual benchmarks. + +## Pre-trained models + +Model | Description | # params | Download +---|---|---|--- +`xlmr.base.v0` | XLM-R using the BERT-base architecture | 250M | [xlm.base.v0.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz) +`xlmr.large.v0` | XLM-R using the BERT-large architecture | 560M | [xlm.large.v0.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz) + +(Note: The above models are still under training, we will update the weights, once fully trained, the results are based on the above checkpoints.) + +## Results + +**[XNLI (Conneau et al., 2018)](https://arxiv.org/abs/1809.05053)** + +Model | en | fr | es | de | el | bg | ru | tr | ar | vi | th | zh | hi | sw | ur +---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|--- +`roberta.large.mnli` _(TRANSLATE-TEST)_ | 91.3 | 82.9 | 84.3 | 81.24 | 81.74 | 83.13 | 78.28 | 76.79 | 76.64 | 74.17 | 74.05 | 77.5 | 70.9 | 66.65 | 66.81 +`xlmr.large.v0` _(TRANSLATE-TRAIN-ALL)_ | 88.7 | 85.2 | 85.6 | 84.6 | 83.6 | 85.5 | 82.4 | 81.6 | 80.9 | 83.4 | 80.9 | 83.3 | 79.8 | 75.9 | 74.3 + +## Example usage + +##### Load XLM-R from torch.hub (PyTorch >= 1.1): +```python +import torch +xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.large.v0') +xlmr.eval() # disable dropout (or leave in train mode to finetune) +``` + +##### Load XLM-R (for PyTorch 1.0 or custom models): +```python +# Download xlmr.large model +wget https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz +tar -xzvf xlmr.large.v0.tar.gz + +# Load the model in fairseq +from fairseq.models.roberta import XLMRModel +xlmr = XLMRModel.from_pretrained('/path/to/xlmr.large.v0', checkpoint_file='model.pt') +xlmr.eval() # disable dropout (or leave in train mode to finetune) +``` + +##### Apply Byte-Pair Encoding (BPE) to input text: +```python +tokens = xlmr.encode('Hello world!') +assert tokens.tolist() == [ 0, 35378, 8999, 38, 2] +xlmr.decode(tokens) # 'Hello world!' +``` + +##### Extract features from XLM-R: +```python +# Extract the last layer's features +last_layer_features = xlmr.extract_features(tokens) +assert last_layer_features.size() == torch.Size([1, 5, 1024]) + +# Extract all layer's features (layer 0 is the embedding layer) +all_layers = xlmr.extract_features(tokens, return_all_hiddens=True) +assert len(all_layers) == 25 +assert torch.all(all_layers[-1] == last_layer_features) +``` + +## Citation + +```bibtex +@article{, + title = {Unsupervised Cross-lingual Representation Learning at Scale}, + author = {Alexis Conneau and Kartikay Khandelwal and Naman Goyal + and Vishrav Chaudhary and Guillaume Wenzek and Francisco Guzm\'an + and Edouard Grave and Myle Ott and Luke Zettlemoyer and Veselin Stoyanov + }, + journal={}, + year = {2019}, +} +``` diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index a6ff42d3e..0ce605932 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -194,6 +194,30 @@ class RobertaModel(FairseqLanguageModel): state_dict[prefix + 'classification_heads.' + k] = v +@register_model('xlmr') +class XLMRModel(RobertaModel): + @classmethod + def hub_models(cls): + return { + 'xlmr.base.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz', + 'xlmr.large.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz', + } + + @classmethod + def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='sentencepiece', **kwargs): + from fairseq import hub_utils + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + bpe=bpe, + load_checkpoint_heads=True, + **kwargs, + ) + return RobertaHubInterface(x['args'], x['task'], x['models'][0]) + + class RobertaLMHead(nn.Module): """Head for masked language modeling."""