XLM-R code and model release (#900)

Summary:
TODO:
1) Need to update bibtex entry
2) Need to upload models, spm_vocab and dict.txt to public s3 location.

For Future:

1) I will probably add instructions to finetune on XNLI and NER, POS etc. but currently no timeline for that.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/900

Reviewed By: myleott

Differential Revision: D18333076

Pulled By: myleott

fbshipit-source-id: 3f3d3716fcc41c78d2dd4525f60b519abbd0459c
This commit is contained in:
ngoyal2707 2019-11-05 15:01:26 -08:00 committed by Facebook Github Bot
parent 68dd3e171b
commit e23e5eaa32
4 changed files with 103 additions and 0 deletions

View File

@ -6,6 +6,7 @@ modeling and other text generation tasks.
### What's New:
- November 2019: [XLM-R models and code released](examples/xlmr/README.md)
- September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
- August 2019: [WMT'19 models released](examples/wmt19/README.md)
- July 2019: fairseq relicensed under MIT license

View File

@ -8,6 +8,7 @@ RoBERTa iterates on BERT's pretraining procedure, including training the model l
### What's New:
- November 2019: Multilingual encoder (XLM-RoBERTa) is available [XLM-R](https://github.com/pytorch/fairseq/examples/xlmr).
- September 2019: TensorFlow and TPU support via the [transformers library](https://github.com/huggingface/transformers).
- August 2019: RoBERTa is now supported in the [pytorch-transformers library](https://github.com/huggingface/pytorch-transformers).
- August 2019: Added [tutorial for finetuning on WinoGrande](https://github.com/pytorch/fairseq/tree/master/examples/roberta/wsc#roberta-training-on-winogrande-dataset).

77
examples/xlmr/README.md Normal file
View File

@ -0,0 +1,77 @@
# Unsupervised Cross-lingual Representation Learning at Scale (XLM-RoBERTa)
## Introduction
XLM-R (XLM-RoBERTa) is scaled cross lingual sentence encoder. It is trained on `2.5T` of data across `100` languages data filtered from Common Crawl. XLM-R achieves state-of-the-arts results on multiple cross lingual benchmarks.
## Pre-trained models
Model | Description | # params | Download
---|---|---|---
`xlmr.base.v0` | XLM-R using the BERT-base architecture | 250M | [xlm.base.v0.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz)
`xlmr.large.v0` | XLM-R using the BERT-large architecture | 560M | [xlm.large.v0.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz)
(Note: The above models are still under training, we will update the weights, once fully trained, the results are based on the above checkpoints.)
## Results
**[XNLI (Conneau et al., 2018)](https://arxiv.org/abs/1809.05053)**
Model | en | fr | es | de | el | bg | ru | tr | ar | vi | th | zh | hi | sw | ur
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
`roberta.large.mnli` _(TRANSLATE-TEST)_ | 91.3 | 82.9 | 84.3 | 81.24 | 81.74 | 83.13 | 78.28 | 76.79 | 76.64 | 74.17 | 74.05 | 77.5 | 70.9 | 66.65 | 66.81
`xlmr.large.v0` _(TRANSLATE-TRAIN-ALL)_ | 88.7 | 85.2 | 85.6 | 84.6 | 83.6 | 85.5 | 82.4 | 81.6 | 80.9 | 83.4 | 80.9 | 83.3 | 79.8 | 75.9 | 74.3
## Example usage
##### Load XLM-R from torch.hub (PyTorch >= 1.1):
```python
import torch
xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.large.v0')
xlmr.eval() # disable dropout (or leave in train mode to finetune)
```
##### Load XLM-R (for PyTorch 1.0 or custom models):
```python
# Download xlmr.large model
wget https://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz
tar -xzvf xlmr.large.v0.tar.gz
# Load the model in fairseq
from fairseq.models.roberta import XLMRModel
xlmr = XLMRModel.from_pretrained('/path/to/xlmr.large.v0', checkpoint_file='model.pt')
xlmr.eval() # disable dropout (or leave in train mode to finetune)
```
##### Apply Byte-Pair Encoding (BPE) to input text:
```python
tokens = xlmr.encode('Hello world!')
assert tokens.tolist() == [ 0, 35378, 8999, 38, 2]
xlmr.decode(tokens) # 'Hello world!'
```
##### Extract features from XLM-R:
```python
# Extract the last layer's features
last_layer_features = xlmr.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])
# Extract all layer's features (layer 0 is the embedding layer)
all_layers = xlmr.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 25
assert torch.all(all_layers[-1] == last_layer_features)
```
## Citation
```bibtex
@article{,
title = {Unsupervised Cross-lingual Representation Learning at Scale},
author = {Alexis Conneau and Kartikay Khandelwal and Naman Goyal
and Vishrav Chaudhary and Guillaume Wenzek and Francisco Guzm\'an
and Edouard Grave and Myle Ott and Luke Zettlemoyer and Veselin Stoyanov
},
journal={},
year = {2019},
}
```

View File

@ -194,6 +194,30 @@ class RobertaModel(FairseqLanguageModel):
state_dict[prefix + 'classification_heads.' + k] = v
@register_model('xlmr')
class XLMRModel(RobertaModel):
@classmethod
def hub_models(cls):
return {
'xlmr.base.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.base.v0.tar.gz',
'xlmr.large.v0': 'http://dl.fbaipublicfiles.com/fairseq/models/xlmr.large.v0.tar.gz',
}
@classmethod
def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', bpe='sentencepiece', **kwargs):
from fairseq import hub_utils
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
data_name_or_path,
archive_map=cls.hub_models(),
bpe=bpe,
load_checkpoint_heads=True,
**kwargs,
)
return RobertaHubInterface(x['args'], x['task'], x['models'][0])
class RobertaLMHead(nn.Module):
"""Head for masked language modeling."""