adding first version of bart code release (#902)

Summary:
This is the first version of BART code / model release.

It still requires lot of clean up, instructions, making sure results are reproducible before we can release it.
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/902

Differential Revision: D18389535

fbshipit-source-id: 77f16800307ce831bd29538fdd34800793210f46
This commit is contained in:
Naman Goyal 2019-11-08 21:00:06 -08:00 committed by Facebook Github Bot
parent e98bf7e64f
commit a92bcdad5a
18 changed files with 1360 additions and 39 deletions

View File

@ -6,6 +6,7 @@ modeling and other text generation tasks.
### What's New:
- November 2019: [BART model and code released](examples/bart/README.md)
- November 2019: [XLM-R models and code released](examples/xlmr/README.md)
- September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
- August 2019: [WMT'19 models released](examples/wmt19/README.md)

View File

@ -0,0 +1,99 @@
# Fine-tuning BART on GLUE tasks
### 1) Download the data from GLUE website (https://gluebenchmark.com/tasks) using following commands:
```bash
wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
python download_glue_data.py --data_dir glue_data --tasks all
```
### 2) Preprocess GLUE task data (same as RoBERTa):
```bash
./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
```
`glue_task_name` is one of the following:
`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
Use `ALL` for preprocessing all the glue tasks.
### 3) Fine-tuning on GLUE task:
Example fine-tuning cmd for `RTE` task
```bash
TOTAL_NUM_UPDATES=2036 # 10 epochs through RTE for bsz 16
WARMUP_UPDATES=61 # 6 percent of the number of updates
LR=1e-05 # Peak LR for polynomial LR scheduler.
NUM_CLASSES=2
MAX_SENTENCES=16 # Batch size.
BART_PATH=/path/to/bart/model.pt
CUDA_VISIBLE_DEVICES=0,1 python train.py RTE-bin/ \
--restore-file $BART_PATH \
--max-sentences $MAX_SENTENCES \
--max-tokens 4400 \
--task sentence_prediction \
--add-prev-output-tokens \
--layernorm-embedding \
--share-all-embeddings \
--share-decoder-input-output-embed \
--reset-optimizer --reset-dataloader --reset-meters \
--required-batch-size-multiple 1 \
--init-token 0 \
--arch bart_large \
--criterion sentence_prediction \
--num-classes $NUM_CLASSES \
--dropout 0.1 --attention-dropout 0.1 \
--weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
--clip-norm 0.0 \
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
--max-epoch 10 \
--find-unused-parameters \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```
For each of the GLUE task, you will need to use following cmd-line arguments:
Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
`--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
`bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
`--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
`--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107
For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.
**Note:**
a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--max-sentences=32/64/128` depending on the task.
b) Above cmd-args and hyperparams are tested on Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--max-sentences`.
### Inference on GLUE task
After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet:
```python
from fairseq.models.bart import BARTModel
bart = BARTModel.from_pretrained(
'checkpoints/',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='RTE-bin'
)
label_fn = lambda label: bart.task.label_dictionary.string(
[label + bart.task.label_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/RTE/dev.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[1], tokens[2], tokens[3]
tokens = bart.encode(sent1, sent2)
prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
prediction_label = label_fn(prediction)
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
```

169
examples/bart/README.md Normal file
View File

@ -0,0 +1,169 @@
# BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
[https://arxiv.org/pdf/1910.13461.pdf]
## Introduction
BART is sequence-to-sequence model trained with denoising as pretraining objective. We show that this pretraining objective is more generic and show that we can match [RoBERTa](../roberta) Results on SQuAD and GLUE and gain state-of-the-art results on summarization (XSum, CNN dataset), long form generative question answering (ELI5) and dialog response genration (ConvAI2). See the associated paper for more details.
## Pre-trained models
Model | Description | # params | Download
---|---|---|---
`bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
`bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)
## Results
**[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
_(dev set, single model, single-task finetuning)_
Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2
**[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
_(dev set, no additional data used)_
Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4
`bart.large` | 88.8/94.6 | 86.1/89.2
**[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
_(dev set, no additional data used)_
Model | R1 | R2 | RL
---|---|---|---
`BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
`bart.large` | 44.16 | 21.28 | 40.90
## Example usage
##### Load BART from torch.hub (PyTorch >= 1.1):
```python
import torch
bart = torch.hub.load('pytorch/fairseq', 'bart.large')
bart.eval() # disable dropout (or leave in train mode to finetune)
```
##### Load BART (for PyTorch 1.0 or custom models):
```python
# Download bart.large model
wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
tar -xzvf bart.large.tar.gz
# Load the model in fairseq
from fairseq.models.bart import BARTModel
bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt')
bart.eval() # disable dropout (or leave in train mode to finetune)
```
##### Apply Byte-Pair Encoding (BPE) to input text:
```python
tokens = bart.encode('Hello world!')
assert tokens.tolist() == [0, 31414, 232, 328, 2]
bart.decode(tokens) # 'Hello world!'
```
##### Extract features from BART:
```python
# Extract the last layer's features
last_layer_features = bart.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])
# Extract all layer's features from decoder (layer 0 is the embedding layer)
all_layers = bart.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 13
assert torch.all(all_layers[-1] == last_layer_features)
```
##### Use BART for sentence-pair classification tasks:
```python
# Download BART already finetuned for MNLI
bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.eval() # disable dropout for evaluation
# Encode a pair of sentences and make a prediction
tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
bart.predict('mnli', tokens).argmax() # 0: contradiction
# Encode another pair of sentences
tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
bart.predict('mnli', tokens).argmax() # 2: entailment
```
##### Register a new (randomly initialized) classification head:
```python
bart.register_classification_head('new_task', num_classes=3)
logprobs = bart.predict('new_task', tokens)
```
##### Batched prediction:
```python
import torch
from fairseq.data.data_utils import collate_tokens
bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.eval()
batch_of_pairs = [
['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
]
batch = collate_tokens(
[bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
)
logprobs = bart.predict('mnli', batch)
print(logprobs.argmax(dim=1))
# tensor([0, 2])
```
##### Using the GPU:
```python
bart.cuda()
bart.predict('new_task', tokens)
```
#### Evaluating the `bart.large.mnli` model:
Example python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
```python
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/MNLI/dev_matched.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
tokens = bart.encode(sent1, sent2)
prediction = bart.predict('mnli', tokens).argmax().item()
prediction_label = label_map[prediction]
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# Expected output: 0.9010
```
## Finetuning
- [Finetuning on GLUE](README.glue.md)
## Citation
```bibtex
@article{lewis2019bart,
title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
Language Generation, Translation, and Comprehension},
author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov
and Luke Zettlemoyer },
journal={arXiv preprint arXiv:1910.13461},
year = {2019},
}
```

View File

@ -79,7 +79,7 @@ roberta = RobertaModel.from_pretrained(
)
label_fn = lambda label: roberta.task.label_dictionary.string(
[label + roberta.task.target_dictionary.nspecial]
[label + roberta.task.label_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
roberta.cuda()

View File

@ -9,11 +9,13 @@ from .fairseq_dataset import FairseqDataset
from .base_wrapper_dataset import BaseWrapperDataset
from .append_token_dataset import AppendTokenDataset
from .audio.raw_audio_dataset import FileAudioDataset
from .backtranslation_dataset import BacktranslationDataset
from .colorize_dataset import ColorizeDataset
from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset
from .denoising_dataset import DenoisingDataset
from .id_dataset import IdDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
from .language_pair_dataset import LanguagePairDataset
@ -33,6 +35,7 @@ from .prepend_token_dataset import PrependTokenDataset
from .raw_label_dataset import RawLabelDataset
from .replace_dataset import ReplaceDataset
from .resampling_dataset import ResamplingDataset
from .roll_dataset import RollDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sharded_dataset import ShardedDataset
from .sort_dataset import SortDataset
@ -42,7 +45,6 @@ from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
from .truncate_dataset import TruncateDataset
from .resampling_dataset import ResamplingDataset
from .iterators import (
CountingIterator,
@ -52,12 +54,14 @@ from .iterators import (
)
__all__ = [
'AppendTokenDataset',
'BacktranslationDataset',
'BaseWrapperDataset',
'ColorizeDataset',
'ConcatDataset',
'ConcatSentencesDataset',
'CountingIterator',
'DenoisingDataset',
'Dictionary',
'EpochBatchIterator',
'FairseqDataset',
@ -83,9 +87,10 @@ __all__ = [
'PrependDataset',
'PrependTokenDataset',
'ReplaceDataset',
'RollDataset',
'FileAudioDataset',
'RawLabelDataset',
'ResamplingDataset'
'ResamplingDataset',
'RightPadDataset',
'RoundRobinZipDatasets',
'ShardedDataset',

View File

@ -0,0 +1,42 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from . import BaseWrapperDataset
class AppendTokenDataset(BaseWrapperDataset):
def __init__(self, dataset, token=None):
super().__init__(dataset)
self.token = token
if token is not None:
self._sizes = np.array(dataset.sizes) + 1
else:
self._sizes = dataset.sizes
def __getitem__(self, idx):
item = self.dataset[idx]
if self.token is not None:
item = torch.cat([item, item.new([self.token])])
return item
@property
def sizes(self):
return self._sizes
def num_tokens(self, index):
n = self.dataset.num_tokens(index)
if self.token is not None:
n += 1
return n
def size(self, index):
n = self.dataset.size(index)
if self.token is not None:
n += 1
return n

View File

@ -0,0 +1,386 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import math
from . import data_utils, FairseqDataset
def collate(
samples,
pad_idx,
eos_idx,
vocab,
left_pad_source=False,
left_pad_target=False,
input_feeding=True,
):
assert input_feeding
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
)
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=left_pad_source)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
if samples[0].get('target', None) is not None:
target = merge('target', left_pad=left_pad_target)
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
if input_feeding:
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=left_pad_target,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
else:
ntokens = sum(len(s['source']) for s in samples)
batch = {
'id': id,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
},
'target': target,
'nsentences': samples[0]['source'].size(0),
}
if prev_output_tokens is not None:
batch['net_input']['prev_output_tokens'] = prev_output_tokens
return batch
class DenoisingDataset(FairseqDataset):
"""
A wrapper around TokenBlockDataset for BART dataset.
Args:
dataset (TokenBlockDataset): dataset to wrap
sizes (List[int]): sentence lengths
vocab (~fairseq.data.Dictionary): vocabulary
mask_idx (int): dictionary index used for masked token
mask_whole_words: only mask whole words. This should be a byte mask
over vocab indices, indicating whether it is the beginning of a
word. We will extend any mask to encompass the whole word.
shuffle (bool, optional): shuffle the elements before batching.
Default: ``True``
seed: Seed for random number generator for reproducibility.
args: argparse arguments.
"""
def __init__(
self,
dataset,
sizes,
vocab,
mask_idx,
mask_whole_words,
shuffle,
seed,
args
):
self.dataset = dataset
self.sizes = sizes
self.vocab = vocab
self.shuffle = shuffle
self.seed = seed
self.mask_idx = mask_idx
self.mask_whole_word = mask_whole_words
self.mask_ratio = args.mask
self.random_ratio = args.mask_random
self.insert_ratio = args.insert
self.rotate_ratio = args.rotate
self.permute_sentence_ratio = args.permute_sentences
if args.bpe != 'gpt2':
self.full_stop_index = self.vocab.index(".")
else:
assert args.bpe == 'gpt2'
self.full_stop_index = self.vocab.index('13')
self.replace_length = args.replace_length
if not self.replace_length in [-1, 0, 1]:
raise (f'invalid arg: replace_length={self.replace_length}')
if not args.mask_length in ['subword', 'word', 'span-poisson']:
raise (f'invalid arg: mask-length={args.mask_length}')
if args.mask_length == 'subword' and not args.replace_length in [0, 1]:
raise (f'if using subwords, use replace-length=1 or 0')
self.mask_span_distribution = None
if args.mask_length == 'span-poisson':
_lambda = args.poisson_lambda
lambda_to_the_k = 1
e_to_the_minus_lambda = math.exp(-_lambda)
k_factorial = 1
ps = []
for k in range(0, 128):
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
lambda_to_the_k *= _lambda
k_factorial *= (k + 1)
if ps[-1] < 0.0000001:
break
ps = torch.FloatTensor(ps)
self.mask_span_distribution = torch.distributions.Categorical(ps)
self.verbose = args.verbose
self.epoch = 0
def set_epoch(self, epoch, **unused):
self.epoch = epoch
def __getitem__(self, index):
with data_utils.numpy_seed(self.seed, self.epoch, index):
tokens = self.dataset[index]
assert tokens[-1] == self.vocab.eos()
source, target = tokens, tokens.clone()
if self.permute_sentence_ratio > 0.0:
source = self.permute_sentences(source, self.permute_sentence_ratio)
if self.mask_ratio > 0:
source = self.add_whole_word_mask(source, self.mask_ratio)
if self.insert_ratio > 0:
source = self.add_insertion_noise(source, self.insert_ratio)
if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
source = self.add_rolling_noise(source)
assert (source >= 0).all()
assert (source[1:-1] >= 1).all()
assert (source <= len(self.vocab)).all()
assert source[0] == self.vocab.bos()
assert source[-1] == self.vocab.eos()
return {
'id': index,
'source': source,
'target': target,
}
def __len__(self):
return len(self.dataset)
def permute_sentences(self, source, p=1.0):
full_stops = (source == self.full_stop_index)
# Pretend it ends with a full stop so last span is a sentence
full_stops[-2] = 1
# Tokens that are full stops, where the previous token is not
sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero() + 2
result = source.clone()
num_sentences = sentence_ends.size(0)
num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
substitutions = torch.randperm(num_sentences)[:num_to_permute]
ordering = torch.arange(0, num_sentences)
ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
# Ignore <bos> at start
index = 1
for i in ordering:
sentence = source[(sentence_ends[i - 1] if i > 0 else 1):sentence_ends[i]]
result[index:index + sentence.size(0)] = sentence
index += sentence.size(0)
return result
def word_starts(self, source):
if self.mask_whole_word is not None:
is_word_start = self.mask_whole_word.gather(0, source)
else:
is_word_start = torch.ones(source.size())
is_word_start[0] = 0
is_word_start[-1] = 0
return is_word_start
def add_whole_word_mask(self, source, p):
is_word_start = self.word_starts(source)
num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
num_inserts = 0
if num_to_mask == 0:
return source
if self.mask_span_distribution is not None:
lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
# Make sure we have enough to mask
cum_length = torch.cumsum(lengths, 0)
while cum_length[-1] < num_to_mask:
lengths = torch.cat([lengths, self.mask_span_distribution.sample(sample_shape=(num_to_mask,))], dim=0)
cum_length = torch.cumsum(lengths, 0)
# Trim to masking budget
i = 0
while cum_length[i] < num_to_mask:
i += 1
lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
num_to_mask = i + 1
lengths = lengths[:num_to_mask]
# Handle 0-length mask (inserts) separately
lengths = lengths[lengths > 0]
num_inserts = num_to_mask - lengths.size(0)
num_to_mask -= num_inserts
if num_to_mask == 0:
return self.add_insertion_noise(source, num_inserts / source.size(0))
assert (lengths > 0).all()
else:
lengths = torch.ones((num_to_mask,)).long()
assert is_word_start[-1] == 0
word_starts = is_word_start.nonzero()
indices = word_starts[torch.randperm(word_starts.size(0))[:num_to_mask]].squeeze(1)
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
source_length = source.size(0)
assert source_length - 1 not in indices
to_keep = torch.ones(source_length, dtype=torch.bool)
is_word_start[-1] = 255 # acts as a long length, so spans don't go over the end of doc
if self.replace_length == 0:
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),))
if self.mask_span_distribution is not None:
assert len(lengths.size()) == 1
assert lengths.size() == indices.size()
lengths -= 1
while indices.size(0) > 0:
assert lengths.size() == indices.size()
lengths -= is_word_start[indices + 1].long()
uncompleted = lengths >= 0
indices = indices[uncompleted] + 1
mask_random = mask_random[uncompleted]
lengths = lengths[uncompleted]
if self.replace_length != -1:
# delete token
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),))
else:
# A bit faster when all lengths are 1
while indices.size(0) > 0:
uncompleted = is_word_start[indices + 1] == 0
indices = indices[uncompleted] + 1
mask_random = mask_random[uncompleted]
if self.replace_length != -1:
# delete token
to_keep[indices] = 0
else:
# keep index, but replace it with [MASK]
source[indices] = self.mask_idx
source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),))
assert source_length - 1 not in indices
source = source[to_keep]
if num_inserts > 0:
source = self.add_insertion_noise(source, num_inserts / source.size(0))
return source
def add_permuted_noise(self, tokens, p):
num_words = len(tokens)
num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
return tokens
def add_rolling_noise(self, tokens):
offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
tokens = torch.cat(
(tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
dim=0,
)
return tokens
def add_insertion_noise(self, tokens, p):
if p == 0.0:
return tokens
num_tokens = len(tokens)
n = int(math.ceil(num_tokens * p))
noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
noise_mask[noise_indices] = 1
result = torch.LongTensor(n + len(tokens)).fill_(-1)
num_random = int(math.ceil(n * self.random_ratio))
result[noise_indices[num_random:]] = self.mask_idx
result[noise_indices[:num_random]] = torch.randint(low=1, high=len(self.vocab), size=(num_random,))
result[~noise_mask] = tokens
assert (result >= 0).all()
return result
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch of data
"""
return collate(samples, self.vocab.pad(), self.vocab.eos(), self.vocab)
def num_tokens(self, index):
"""Return the number of tokens in a sample. This value is used to
enforce ``--max-tokens`` during batching."""
return self.sizes[index]
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return self.sizes[index]
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
if self.shuffle:
indices = np.random.permutation(len(self))
else:
indices = np.arange(len(self))
return indices[np.argsort(self.sizes[indices], kind='mergesort')]
def prefetch(self, indices):
self.src.prefetch(indices)
self.tgt.prefetch(indices)
@property
def supports_prefetch(self):
return (
hasattr(self.src, 'supports_prefetch')
and self.src.supports_prefetch
and hasattr(self.tgt, 'supports_prefetch')
and self.tgt.supports_prefetch
)

View File

@ -0,0 +1,28 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from fairseq.data import encoders
def get_whole_word_mask(args, dictionary):
bpe = encoders.build_bpe(args)
if bpe is not None:
def is_beginning_of_word(i):
if i < dictionary.nspecial:
# special elements are always considered beginnings
return True
tok = dictionary[i]
if tok.startswith('madeupword'):
return True
try:
return bpe.is_beginning_of_word(tok)
except ValueError:
return True
mask_whole_words = torch.ByteTensor(list(
map(is_beginning_of_word, range(len(dictionary)))
))
return mask_whole_words
return None

View File

@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import BaseWrapperDataset
class RollDataset(BaseWrapperDataset):
def __init__(self, dataset, shifts):
super().__init__(dataset)
self.shifts = shifts
def __getitem__(self, index):
item = self.dataset[index]
return torch.roll(item, self.shifts)

View File

@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .hub_interface import * # noqa
from .model import * # noqa

View File

@ -0,0 +1,118 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import encoders
class BARTHubInterface(nn.Module):
"""A simple PyTorch Hub interface to BART.
Usage: https://github.com/pytorch/fairseq/tree/master/examples/BART
"""
def __init__(self, args, task, model):
super().__init__()
self.args = args
self.task = task
self.model = model
self.bpe = encoders.build_bpe(args)
# this is useful for determining the device
self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
@property
def device(self):
return self._float_tensor.device
def encode(self, sentence: str, *addl_sentences, no_separator=True) -> torch.LongTensor:
"""
BPE-encode a sentence (or multiple sentences).
Every sequence begins with a beginning-of-sentence (`<s>`) symbol.
Every sentence ends with an end-of-sentence (`</s>`).
Example (single sentence): `<s> a b c </s>`
Example (sentence pair): `<s> d e f </s> 1 2 3 </s>`
The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE
requires leading spaces. For example::
>>> bart.encode('Hello world').tolist()
[0, 31414, 232, 2]
>>> bart.encode(' world').tolist()
[0, 232, 2]
>>> bart.encode('world').tolist()
[0, 8331, 2]
"""
bpe_sentence = '<s> ' + self.bpe.encode(sentence) + ' </s>'
for s in addl_sentences:
bpe_sentence += (' </s>' if not no_separator else '')
bpe_sentence += ' ' + self.bpe.encode(s) + ' </s>'
tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False)
return tokens.long()
def decode(self, tokens: torch.LongTensor):
assert tokens.dim() == 1
tokens = tokens.numpy()
if tokens[0] == self.task.source_dictionary.bos():
tokens = tokens[1:] # remove <s>
eos_mask = (tokens == self.task.source_dictionary.eos())
doc_mask = eos_mask[1:] & eos_mask[:-1]
sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
sentences = [self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences]
if len(sentences) == 1:
return sentences[0]
return sentences
def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool = False) -> torch.Tensor:
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
if tokens.size(-1) > min(self.model.max_positions()):
raise ValueError('tokens exceeds maximum length: {} > {}'.format(
tokens.size(-1), self.model.max_positions()
))
tokens.to(device=self.device),
prev_output_tokens = tokens.clone()
prev_output_tokens[:, 0] = tokens[:, -1]
prev_output_tokens[:, 1:] = tokens[:, :-1]
features, extra = self.model(
src_tokens=tokens,
src_lengths=None,
prev_output_tokens=prev_output_tokens,
features_only=True,
return_all_hiddens=return_all_hiddens,
)
if return_all_hiddens:
# convert from T x B x C -> B x T x C
inner_states = extra['inner_states']
return [inner_state.transpose(0, 1) for inner_state in inner_states]
else:
return features # just the last layer's features
def register_classification_head(
self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
):
self.model.register_classification_head(
name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
)
def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
features = self.extract_features(tokens.to(device=self.device))
sentence_representation = features[
tokens.eq(self.task.source_dictionary.eos()), :
].view(features.size(0), -1, features.size(-1))[:, -1, :]
logits = self.model.classification_heads[head](sentence_representation)
if return_logits:
return logits
return F.log_softmax(logits, dim=-1)

View File

@ -0,0 +1,246 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""
import torch.nn as nn
from fairseq import utils
from fairseq.models import (
register_model,
register_model_architecture,
)
from fairseq.models.transformer import TransformerModel
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from .hub_interface import BARTHubInterface
@register_model('bart')
class BARTModel(TransformerModel):
@classmethod
def hub_models(cls):
return {
'bart.large': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz',
'bart.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz',
}
def __init__(self, args, encoder, decoder):
super().__init__(args, encoder, decoder)
# We follow BERT's random weight initialization
self.apply(init_bert_params)
self.classification_heads = nn.ModuleDict()
@staticmethod
def add_args(parser):
super(BARTModel, BARTModel).add_args(parser)
parser.add_argument(
'--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence'
)
parser.add_argument(
'--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence'
)
parser.add_argument(
'--pooler-dropout', type=float, metavar='D',
help='dropout probability in the masked_lm pooler layers'
)
parser.add_argument(
'--pooler-activation-fn',
choices=utils.get_available_activation_fns(),
help='activation function to use for pooler layer'
)
@property
def supported_targets(self):
return {'self'}
def forward(
self, src_tokens, src_lengths, prev_output_tokens,
features_only=False, classification_head_name=None, **kwargs
):
if classification_head_name is not None:
features_only = True
encoder_out = self.encoder(
src_tokens,
src_lengths=src_lengths,
**kwargs,
)
x, extra = self.decoder(
prev_output_tokens,
encoder_out=encoder_out,
features_only=features_only,
**kwargs,
)
if classification_head_name is not None:
sentence_representation = x[
src_tokens.eq(self.encoder.dictionary.eos()), :
].view(x.size(0), -1, x.size(-1))[:, -1, :]
x = self.classification_heads[classification_head_name](
sentence_representation
)
return x, extra
@classmethod
def from_pretrained(
cls,
model_name_or_path,
checkpoint_file='model.pt',
data_name_or_path='.',
bpe='gpt2',
**kwargs,
):
from fairseq import hub_utils
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
data_name_or_path,
archive_map=cls.hub_models(),
bpe=bpe,
load_checkpoint_heads=True,
**kwargs,
)
return BARTHubInterface(x['args'], x['task'], x['models'][0])
def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
"""Register a classification head."""
print("Registering classification head: {0}".format(name))
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
prev_inner_dim = self.classification_heads[name].dense.out_features
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
print(
'WARNING: re-registering head "{}" with num_classes {} (prev: {}) '
'and inner_dim {} (prev: {})'.format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
self.classification_heads[name] = BARTClassificationHead(
self.args.encoder_embed_dim,
inner_dim or self.args.encoder_embed_dim,
num_classes,
self.args.pooler_activation_fn,
self.args.pooler_dropout,
)
def upgrade_state_dict_named(self, state_dict, name):
super().upgrade_state_dict_named(state_dict, name)
prefix = name + '.' if name != '' else ''
current_head_names = [] if not hasattr(self, 'classification_heads') else \
self.classification_heads.keys()
# Handle new classification heads present in the state dict.
keys_to_delete = []
for k in state_dict.keys():
if not k.startswith(prefix + 'classification_heads.'):
continue
head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
if getattr(self.args, 'load_checkpoint_heads', False):
if head_name not in current_head_names:
self.register_classification_head(head_name, num_classes, inner_dim)
else:
if head_name not in current_head_names:
print(
'WARNING: deleting classification head ({}) from checkpoint '
'not present in current model: {}'.format(head_name, k)
)
keys_to_delete.append(k)
elif (
num_classes != self.classification_heads[head_name].out_proj.out_features
or inner_dim != self.classification_heads[head_name].dense.out_features
):
print(
'WARNING: deleting classification head ({}) from checkpoint '
'with different dimensions than current model: {}'.format(head_name, k)
)
keys_to_delete.append(k)
for k in keys_to_delete:
del state_dict[k]
# Copy any newly-added classification heads into the state dict
# with their current weights.
if hasattr(self, 'classification_heads'):
cur_state = self.classification_heads.state_dict()
for k, v in cur_state.items():
if prefix + 'classification_heads.' + k not in state_dict:
print('Overwriting', prefix + 'classification_heads.' + k)
state_dict[prefix + 'classification_heads.' + k] = v
class BARTClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(
self,
input_dim,
inner_dim,
num_classes,
activation_fn,
pooler_dropout,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = nn.Linear(inner_dim, num_classes)
def forward(self, features, **kwargs):
x = features
x = self.dropout(x)
x = self.dense(x)
x = self.activation_fn(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
@register_model_architecture('bart', 'bart_large')
def bart_large_architecture(args):
args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4*1024)
args.encoder_layers = getattr(args, 'encoder_layers', 12)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
args.decoder_layers = getattr(args, 'decoder_layers', 12)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', True)
args.attention_dropout = getattr(args, 'attention_dropout', 0.)
args.relu_dropout = getattr(args, 'relu_dropout', 0.)
args.dropout = getattr(args, 'dropout', 0.1)
args.max_target_positions = getattr(args, 'max_target_positions', 1024)
args.max_source_positions = getattr(args, 'max_source_positions', 1024)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', True)
args.share_all_embeddings = getattr(args, 'share_all_embeddings', True)
args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, 'no_scale_embedding', True)
args.layernorm_embedding = getattr(args, 'layernorm_embedding', True)
args.activation_fn = getattr(args, 'activation_fn', 'gelu')
args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh')
args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0)

View File

@ -67,8 +67,9 @@ class TransformerModel(FairseqEncoderDecoderModel):
}
# fmt: on
def __init__(self, encoder, decoder):
def __init__(self, args, encoder, decoder):
super().__init__(encoder, decoder)
self.args = args
self.supports_align_args = True
@staticmethod
@ -140,6 +141,10 @@ class TransformerModel(FairseqEncoderDecoderModel):
help='which layers to *keep* when pruning as a comma-separated list')
parser.add_argument('--decoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
parser.add_argument('--layernorm-embedding', action='store_true',
help='add layernorm to embedding')
parser.add_argument('--no-scale-embedding', action='store_true',
help='if True, dont scale embeddings')
# fmt: on
@classmethod
@ -195,7 +200,7 @@ class TransformerModel(FairseqEncoderDecoderModel):
encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
return cls(encoder, decoder)
return cls(args, encoder, decoder)
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
@ -219,7 +224,7 @@ class TransformerAlignModel(TransformerModel):
"""
def __init__(self, encoder, decoder, args):
super().__init__(encoder, decoder)
super().__init__(args, encoder, decoder)
self.alignment_heads = args.alignment_heads
self.alignment_layer = args.alignment_layer
self.full_context_alignment = args.full_context_alignment
@ -297,7 +302,9 @@ class TransformerEncoder(FairseqEncoder):
self.max_source_positions = args.max_source_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
args.max_source_positions, embed_dim, self.padding_idx,
learned=args.encoder_learned_pos,
@ -315,17 +322,22 @@ class TransformerEncoder(FairseqEncoder):
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
if getattr(args, 'layernorm_embedding', False):
self.layernorm_embedding = LayerNorm(embed_dim)
else:
self.layernorm_embedding = None
def forward_embedding(self, src_tokens):
# embed tokens and positions
embed = self.embed_scale * self.embed_tokens(src_tokens)
if self.embed_positions is not None:
x = embed + self.embed_positions(src_tokens)
if self.layernorm_embedding:
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
return x, embed
def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False):
def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False, **unused):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
@ -422,6 +434,7 @@ class TransformerEncoder(FairseqEncoder):
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = '{}.embed_positions.weights'.format(name)
if weights_key in state_dict:
print('deleting {0}'.format(weights_key))
del state_dict[weights_key]
state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
for i in range(len(self.layers)):
@ -466,7 +479,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None
@ -507,6 +521,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
if getattr(args, 'layernorm_embedding', False):
self.layernorm_embedding = LayerNorm(embed_dim)
else:
self.layernorm_embedding = None
def forward(
self,
@ -590,6 +608,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if positions is not None:
x += positions
if self.layernorm_embedding:
x = self.layernorm_embedding(x)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
@ -758,6 +780,9 @@ def base_architecture(args):
args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
args.no_scale_embedding = getattr(args, 'no_scale_embedding', False)
args.layernorm_embedding = getattr(args, 'layernorm_embedding', False)
@register_model_architecture('transformer', 'transformer_iwslt_de_en')
def transformer_iwslt_de_en(args):

View File

@ -103,6 +103,10 @@ class TransformerLanguageModel(FairseqLanguageModel):
help='LayerDrop probability for decoder')
parser.add_argument('--decoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
parser.add_argument('--layernorm-embedding', action='store_true',
help='add layernorm to embedding')
parser.add_argument('--no-scale-embedding', action='store_true',
help='if True, dont scale embeddings')
# fmt: on
@classmethod
@ -190,6 +194,9 @@ def base_lm_architecture(args):
args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False)
args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False)
args.no_scale_embedding = getattr(args, 'no_scale_embedding', False)
args.layernorm_embedding = getattr(args, 'layernorm_embedding', False)
@register_model_architecture('transformer_lm', 'transformer_lm_big')
def transformer_lm_big(args):

162
fairseq/tasks/denoising.py Normal file
View File

@ -0,0 +1,162 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from fairseq.data import (
data_utils,
Dictionary,
AppendTokenDataset,
DenoisingDataset,
PrependTokenDataset,
StripTokenDataset,
TokenBlockDataset,
)
from fairseq.data.encoders.utils import get_whole_word_mask
from . import FairseqTask, register_task
@register_task('denoising')
class DenoisingTask(FairseqTask):
"""
Denoising task for applying sequence to sequence denoising. (ie. BART)
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', help='path to data directory')
parser.add_argument('--tokens-per-sample', default=512, type=int,
help='max number of total tokens over all segments'
' per sample for dataset')
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
parser.add_argument(
'--sample-break-mode', default="complete_doc", type=str,
help='mode for breaking sentence',
)
parser.add_argument(
'--mask', default=0.0, type=float,
help='fraction of words/subwords that will be masked',
)
parser.add_argument(
'--mask-random', default=0.0, type=float,
help='instead of using [MASK], use random token this often'
)
parser.add_argument(
'--insert', default=0.0, type=float,
help='insert this percentage of additional random tokens',
)
parser.add_argument(
'--permute', default=0.0, type=float,
help='take this proportion of subwords and permute them',
)
parser.add_argument(
'--rotate', default=0.5, type=float,
help='rotate this proportion of inputs',
)
parser.add_argument(
'--poisson-lambda', default=3.0, type=float,
help='randomly shuffle sentences for this proportion of inputs'
)
parser.add_argument(
'--permute-sentences', default=0.0, type=float,
help='shuffle this proportion of sentences in all inputs'
)
parser.add_argument(
'--mask-length', default="subword", type=str,
choices=['subword', 'word', 'span-possion'],
help='mask length to choose'
)
parser.add_argument(
'--replace-length', default=-1, type=int,
help='when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)'
)
def __init__(self, args, dictionary):
super().__init__(args)
self.dictionary = dictionary
self.seed = args.seed
# add mask token
self.mask_idx = self.dictionary.add_symbol('<mask>')
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task.
"""
dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(dictionary)))
if not hasattr(args, 'shuffle_instance'):
args.shuffle_instance = False
return cls(args, dictionary)
def load_dataset(self, split, epoch=0, combine=False):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
paths = self.args.data.split(':')
assert len(paths) > 0
data_path = paths[epoch % len(paths)]
split_path = os.path.join(data_path, split)
dataset = data_utils.load_indexed_dataset(
split_path,
self.dictionary,
self.args.dataset_impl,
combine=combine,
)
if dataset is None:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))
dataset = StripTokenDataset(dataset, self.dictionary.eos())
# create continuous blocks of tokens
dataset = TokenBlockDataset(
dataset,
dataset.sizes,
self.args.tokens_per_sample - 2, # one less for <s> and one for </s>
pad=self.dictionary.pad(),
eos=self.dictionary.eos(),
break_mode=self.args.sample_break_mode,
document_sep_len=0
)
# prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
if self.args.mask_length != 'subword' else None
self.datasets[split] = DenoisingDataset(
dataset, dataset.sizes, self.dictionary, self.mask_idx,
mask_whole_words, shuffle=self.args.shuffle_instance,
seed=self.seed, args=self.args
)
print(
"| Split: {0}, Loaded {1} samples of denoising_dataset".format(
split,
len(self.datasets[split]),
)
)
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)
@property
def source_dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary`."""
return self.dictionary
@property
def target_dictionary(self):
"""Return the target :class:`~fairseq.data.Dictionary`."""
return self.dictionary

View File

@ -7,7 +7,12 @@ import numpy as np
import torch
from fairseq import tokenizer
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
from fairseq.data import (
data_utils,
FairseqDataset,
iterators,
Dictionary,
)
class FairseqTask(object):

View File

@ -6,12 +6,10 @@
import os
import numpy as np
import torch
from fairseq.data import (
data_utils,
Dictionary,
encoders,
IdDataset,
MaskTokensDataset,
NestedDictionaryDataset,
@ -23,6 +21,7 @@ from fairseq.data import (
TokenBlockDataset,
)
from fairseq.tasks import FairseqTask, register_task
from fairseq.data.encoders.utils import get_whole_word_mask
@register_task('masked_lm')
@ -106,27 +105,8 @@ class MaskedLMTask(FairseqTask):
dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
# create masked input and targets
if self.args.mask_whole_words:
bpe = encoders.build_bpe(self.args)
assert bpe is not None
def is_beginning_of_word(i):
if i < self.source_dictionary.nspecial:
# special elements are always considered beginnings
return True
tok = self.source_dictionary[i]
if tok.startswith('madeupword'):
return True
try:
return bpe.is_beginning_of_word(tok)
except ValueError:
return True
mask_whole_words = torch.ByteTensor(list(
map(is_beginning_of_word, range(len(self.source_dictionary)))
))
else:
mask_whole_words = None
mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
if self.args.mask_whole_words else None
src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
dataset,

View File

@ -19,6 +19,7 @@ from fairseq.data import (
PrependTokenDataset,
RawLabelDataset,
RightPadDataset,
RollDataset,
SortDataset,
StripTokenDataset,
TruncateDataset,
@ -51,11 +52,21 @@ class SentencePredictionTask(FairseqTask):
parser.add_argument('--no-shuffle', action='store_true', default=False)
parser.add_argument('--truncate-sequence', action='store_true', default=False,
help='Truncate sequence to max_sequence_length')
parser.add_argument('--add-prev-output-tokens', action='store_true', default=False,
help='Add prev_output_tokens to sample, used for encoder-decoder arch')
def __init__(self, args, data_dictionary, label_dictionary):
super().__init__(args)
self.dictionary = data_dictionary
self.label_dictionary = label_dictionary
self._label_dictionary = label_dictionary
if not hasattr(args, 'max_positions'):
self._max_positions = (
args.max_source_positions,
args.max_target_positions,
)
else:
self._max_positions = args.max_positions
args.tokens_per_sample = self._max_positions
@classmethod
def load_dictionary(cls, args, filename, source=True):
@ -72,8 +83,6 @@ class SentencePredictionTask(FairseqTask):
def setup_task(cls, args, **kwargs):
assert args.num_classes > 0, 'Must set --num-classes'
args.tokens_per_sample = args.max_positions
# load data dictionary
data_dict = cls.load_dictionary(
args,
@ -145,6 +154,15 @@ class SentencePredictionTask(FairseqTask):
'ntokens': NumelDataset(src_tokens, reduce=True),
}
if self.args.add_prev_output_tokens:
prev_tokens_dataset = RightPadDataset(
RollDataset(src_tokens, 1),
pad_idx=self.dictionary.pad(),
)
dataset['net_input'].update(
prev_output_tokens=prev_tokens_dataset,
)
if not self.args.regression_target:
label_dataset = make_dataset('label', self.target_dictionary)
if label_dataset is not None:
@ -197,7 +215,7 @@ class SentencePredictionTask(FairseqTask):
return model
def max_positions(self):
return self.args.max_positions
return self._max_positions
@property
def source_dictionary(self):
@ -205,4 +223,8 @@ class SentencePredictionTask(FairseqTask):
@property
def target_dictionary(self):
return self.label_dictionary
return self.dictionary
@property
def label_dictionary(self):
return self._label_dictionary