Add RoBERTa

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/916

Differential Revision: D16537774

Pulled By: myleott

fbshipit-source-id: 86bb7b1913a428ee4a21674cc3fc7b39264067ec
Authored by Myle Ott on 2019-07-28 18:40:05 -07:00; committed by Facebook Github Bot
parent a80cade964
commit 8d036c2fe0
4 changed files with 25 additions and 121 deletions

View File: README.md

@@ -1,9 +1,18 @@
# Introduction <img src="fairseq_logo.png" width="50">
# <img src="fairseq_logo.png" width="30"> Introduction
Fairseq(-py) is a sequence modeling toolkit that allows researchers and
developers to train custom models for translation, summarization, language
modeling and other text generation tasks. It provides reference implementations
of various sequence-to-sequence models, including:
modeling and other text generation tasks.
### What's New:
- July 2019: [RoBERTa models and code release](examples/roberta/README.md)
- June 2019: [wav2vec models and code release](examples/wav2vec/README.md)
- April 2019: [fairseq demo paper @ NAACL 2019](https://arxiv.org/abs/1904.01038)
### Features:
Fairseq provides reference implementations of various sequence-to-sequence models, including:
- **Convolutional Neural Networks (CNN)**
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/language_model/conv_lm/README.md)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
@@ -11,18 +20,18 @@ of various sequence-to-sequence models, including:
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
- **_New_** [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
- **LightConv and DynamicConv models**
- **_New_** [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
- **Long Short-Term Memory (LSTM) networks**
- [Luong et al. (2015): Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)
- [Wiseman and Rush (2016): Sequence-to-Sequence Learning as Beam-Search Optimization](https://arxiv.org/abs/1606.02960)
- Luong et al. (2015): Effective Approaches to Attention-based Neural Machine Translation
- **Transformer (self-attention) networks**
- [Vaswani et al. (2017): Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- Vaswani et al. (2017): Attention Is All You Need
- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
- [Edunov et al. (2018): Understanding Back-Translation at Scale](examples/backtranslation/README.md)
- **_New_** [Baevski and Auli (2018): Adaptive Input Representations for Neural Language Modeling](examples/language_model/transformer_lm/README.md)
- **_New_** [Shen et al. (2019): Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- [Baevski and Auli (2018): Adaptive Input Representations for Neural Language Modeling](examples/language_model/transformer_lm/README.md)
- [Shen et al. (2019): Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- **_New_** [Liu et al. (2019): RoBERTa: A Robustly Optimized BERT Pretraining Approach](examples/roberta/README.md)
Fairseq features:
**Additionally:**
- multi-GPU (distributed) training on one machine or across multiple machines
- fast generation on both CPU and GPU with multiple search algorithms implemented:
- beam search
@@ -83,6 +92,7 @@ as well as example training and evaluation commands.
- [Language Modeling](examples/language_model/README.md): convolutional models are available
We also have more detailed READMEs to reproduce results from specific papers:
- [Liu et al. (2019): RoBERTa: A Robustly Optimized BERT Pretraining Approach](examples/roberta/README.md)
- [Schneider et al. (2019): wav2vec: Unsupervised Pre-training for Speech Recognition](examples/wav2vec/README.md)
- [Shen et al. (2019): Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)

View File: examples/roberta/README.md

@@ -1,6 +1,6 @@
# RoBERTa: A Robustly Optimized BERT Pretraining Approach
*Pre-print coming 7/28*
https://arxiv.org/abs/1907.11692
## Introduction
@@ -144,6 +144,7 @@ A more detailed tutorial is coming soon.
author = {Yinhan Liu and Myle Ott and Naman Goyal and Jingfei Du and
Mandar Joshi and Danqi Chen and Omer Levy and Mike Lewis and
Luke Zettlemoyer and Veselin Stoyanov},
    journal = {arXiv preprint arXiv:1907.11692},
year = {2019},
}
```
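For reference, here is a minimal usage sketch of the released checkpoints, mirroring the hub-interface docstring further down in this commit; it assumes torch.hub can reach the `pytorch/fairseq` entry point and download the `roberta.large` weights.

```python
# Minimal sketch, mirroring the RobertaHubInterface docstring in this commit.
# Assumes network access for torch.hub to fetch the 'roberta.large' checkpoint.
import torch

roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout for deterministic feature extraction

tokens = roberta.encode('Hello world!')       # BPE + dictionary encoding
features = roberta.extract_features(tokens)   # last layer features, B x T x C

print(tokens)           # tensor([ 0, 31414, 232, 328, 2])
print(features.size())  # torch.Size([1, 5, 1024])
```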

View File: fairseq/models/__init__.py

@@ -123,8 +123,8 @@ def register_model_architecture(model_name, arch_name):
# automatically import any Python files in the models/ directory
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        model_name = file[:file.find('.py')]
    if not file.startswith('_'):
        model_name = file[:file.find('.py')] if file.endswith('.py') else file
        module = importlib.import_module('fairseq.models.' + model_name)

        # extra `model_parser` for sphinx
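The relaxed condition above is what allows a model to live in its own package directory (such as `roberta/`) instead of a single `.py` module. Below is a small, self-contained sketch of that name-mapping rule; the helper name `to_model_name` is hypothetical, for illustration only.

```python
# Sketch of the revised auto-import rule: '.py' modules have their suffix
# stripped, package directories are imported by name, and anything starting
# with an underscore is skipped. 'to_model_name' is a hypothetical helper.
def to_model_name(entry):
    """Map a fairseq/models/ directory entry to the submodule to import,
    or return None if the entry should be skipped."""
    if entry.startswith('_'):
        return None  # e.g. __init__.py, __pycache__
    return entry[:entry.find('.py')] if entry.endswith('.py') else entry

assert to_model_name('transformer.py') == 'transformer'  # plain module (old behaviour)
assert to_model_name('roberta') == 'roberta'              # package directory (new behaviour)
assert to_model_name('__init__.py') is None               # skipped either way
```

Keeping the mapping permissive lets `importlib.import_module('fairseq.models.' + model_name)` pick up a model packaged as a directory without any change to the registration mechanism itself.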

View File: fairseq/models/roberta/model.py

@@ -13,7 +13,6 @@ import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import encoders
from fairseq.models import (
    FairseqDecoder,
    FairseqLanguageModel,
@@ -26,113 +25,7 @@ from fairseq.modules import (
)
from fairseq.modules.transformer_sentence_encoder import init_bert_params

class RobertaHubInterface(nn.Module):
    """A simple PyTorch Hub interface to RoBERTa.

    Load RoBERTa::

        >>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
        >>> roberta.eval()  # disable dropout (or leave in train mode to finetune)

    Apply Byte-Pair Encoding (BPE) to input text::

        >>> tokens = roberta.encode('Hello world!')
        >>> tokens
        tensor([ 0, 31414, 232, 328, 2])

    Extract features from RoBERTa::

        >>> last_layer_features = roberta.extract_features(tokens)
        >>> last_layer_features.size()
        torch.Size([1, 5, 1024])

        >>> all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
        >>> len(all_layers)
        25
        >>> torch.all(all_layers[-1] == last_layer_features)
        tensor(1, dtype=torch.uint8)

    Use RoBERTa for sentence-pair classification tasks::

        >>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')  # already finetuned
        >>> roberta.eval()  # disable dropout for evaluation

        >>> tokens = roberta.encode(
        ...     'Roberta is a heavily optimized version of BERT.',
        ...     'Roberta is not very optimized.'
        ... )
        >>> roberta.predict('mnli', tokens).argmax()
        tensor(0)  # contradiction

        >>> tokens = roberta.encode(
        ...     'Roberta is a heavily optimized version of BERT.',
        ...     'Roberta is based on BERT.'
        ... )
        >>> roberta.predict('mnli', tokens).argmax()
        tensor(2)  # entailment

    Register a new (randomly initialized) classification head::

        >>> roberta.register_classification_head('new_task', num_classes=3)
        >>> roberta.predict('new_task', tokens)
        tensor([[-1.1050, -1.0672, -1.1245]], grad_fn=<LogSoftmaxBackward>)

    Using the GPU::

        >>> roberta.cuda()
        >>> roberta.predict('new_task', tokens)
        tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
    """

    def __init__(self, args, task, model):
        super().__init__()
        self.args = args
        self.task = task
        self.model = model

        self.bpe = encoders.build_bpe(args)

        # this is useful for determining the device
        self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))

    @property
    def device(self):
        return self._float_tensor.device

    def encode(self, sentence: str, *addl_sentences) -> torch.LongTensor:
        bpe_sentence = '<s> ' + self.bpe.encode(sentence) + ' </s>'
        for s in addl_sentences:
            bpe_sentence += ' </s> ' + self.bpe.encode(s)
        tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=True)
        return tokens.long()

    def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor:
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        features, extra = self.model(
            tokens.to(device=self.device),
            features_only=True,
            return_all_hiddens=return_all_hiddens,
        )
        if return_all_hiddens:
            # convert from T x B x C -> B x T x C
            inner_states = extra['inner_states']
            return [inner_state.transpose(0, 1) for inner_state in inner_states]
        else:
            return features  # just the last layer's features

    def register_classification_head(
        self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
    ):
        self.model.register_classification_head(
            name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
        )

    def predict(self, head: str, tokens: torch.LongTensor):
        features = self.extract_features(tokens)
        logits = self.model.classification_heads[head](features)
        return F.log_softmax(logits, dim=-1)
from .hub_interface import RobertaHubInterface
@register_model('roberta')
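A small implementation detail from the hub interface above: the `_float_tensor` buffer registered in `__init__` is what the `device` property reads, so the wrapper always knows where the wrapped model currently lives. A self-contained sketch of that pattern follows; the class name `DeviceAware` is illustrative and not part of fairseq.

```python
# Sketch of the device-tracking trick used by RobertaHubInterface: a registered
# buffer is moved by .to()/.cuda() along with the rest of the module, so its
# .device attribute always reflects the module's current device.
import torch
import torch.nn as nn


class DeviceAware(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))

    @property
    def device(self):
        return self._float_tensor.device


m = DeviceAware()
print(m.device)      # cpu
if torch.cuda.is_available():
    m.cuda()
    print(m.device)  # cuda:0
```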