diff --git a/README.md b/README.md
index 92d7edfc4..07f47f91b 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,7 @@ modeling and other text generation tasks.
 
 ### What's New:
 
+- November 2019: [BART model and code released](examples/bart/README.md)
 - November 2019: [XLM-R models and code released](examples/xlmr/README.md)
 - September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
 - August 2019: [WMT'19 models released](examples/wmt19/README.md)
diff --git a/examples/bart/README.glue.md b/examples/bart/README.glue.md
new file mode 100644
index 000000000..797fdee31
--- /dev/null
+++ b/examples/bart/README.glue.md
@@ -0,0 +1,99 @@
+# Fine-tuning BART on GLUE tasks
+
+### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
+```bash
+wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
+python download_glue_data.py --data_dir glue_data --tasks all
+```
+
+### 2) Preprocess GLUE task data (same as RoBERTa):
+```bash
+./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
+```
+`glue_task_name` is one of the following:
+`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
+Use `ALL` to preprocess all the GLUE tasks.
+
+### 3) Fine-tuning on GLUE task:
+Example fine-tuning command for the `RTE` task:
+```bash
+TOTAL_NUM_UPDATES=2036  # 10 epochs through RTE for bsz 16
+WARMUP_UPDATES=61       # 6 percent of the number of updates
+LR=1e-05                # Peak LR for polynomial LR scheduler.
+NUM_CLASSES=2
+MAX_SENTENCES=16        # Batch size.
+BART_PATH=/path/to/bart/model.pt
+
+CUDA_VISIBLE_DEVICES=0,1 python train.py RTE-bin/ \
+    --restore-file $BART_PATH \
+    --max-sentences $MAX_SENTENCES \
+    --max-tokens 4400 \
+    --task sentence_prediction \
+    --add-prev-output-tokens \
+    --layernorm-embedding \
+    --share-all-embeddings \
+    --share-decoder-input-output-embed \
+    --reset-optimizer --reset-dataloader --reset-meters \
+    --required-batch-size-multiple 1 \
+    --init-token 0 \
+    --arch bart_large \
+    --criterion sentence_prediction \
+    --num-classes $NUM_CLASSES \
+    --dropout 0.1 --attention-dropout 0.1 \
+    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
+    --clip-norm 0.0 \
+    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
+    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
+    --max-epoch 10 \
+    --find-unused-parameters \
+    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
+```
+
+For each GLUE task, you will need the following command-line arguments:
+
+Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
+---|---|---|---|---|---|---|---|---
+`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
+`--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
+`bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
+`--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
+`--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107
+
+For `STS-B`, additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.
+
+**Note:**
+
+a) `--total-num-update` is used by the `polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--max-sentences=32/64/128` depending on the task.
+
+b) The above command-line arguments and hyperparameters were tested on an NVIDIA `V100` GPU with `32GB` of memory for each task.
Depending on the GPU memory resources available to you, you can increase `--update-freq` and reduce `--max-sentences`.
+
+### Inference on GLUE task
+After training the model as described in the previous step, you can perform inference with checkpoints in the `checkpoints/` directory using the following Python code snippet:
+
+```python
+from fairseq.models.bart import BARTModel
+
+bart = BARTModel.from_pretrained(
+    'checkpoints/',
+    checkpoint_file='checkpoint_best.pt',
+    data_name_or_path='RTE-bin'
+)
+
+label_fn = lambda label: bart.task.label_dictionary.string(
+    [label + bart.task.label_dictionary.nspecial]
+)
+ncorrect, nsamples = 0, 0
+bart.cuda()
+bart.eval()
+with open('glue_data/RTE/dev.tsv') as fin:
+    fin.readline()
+    for index, line in enumerate(fin):
+        tokens = line.strip().split('\t')
+        sent1, sent2, target = tokens[1], tokens[2], tokens[3]
+        tokens = bart.encode(sent1, sent2)
+        prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
+        prediction_label = label_fn(prediction)
+        ncorrect += int(prediction_label == target)
+        nsamples += 1
+print('| Accuracy: ', float(ncorrect)/float(nsamples))
+```
diff --git a/examples/bart/README.md b/examples/bart/README.md
new file mode 100644
index 000000000..858616257
--- /dev/null
+++ b/examples/bart/README.md
@@ -0,0 +1,169 @@
+# BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
+
+[https://arxiv.org/pdf/1910.13461.pdf]
+
+## Introduction
+
+BART is a sequence-to-sequence model trained with a denoising pretraining objective. We show that this pretraining objective is more generic: BART matches [RoBERTa](../roberta) results on SQuAD and GLUE and achieves state-of-the-art results on summarization (XSum, CNN/DailyMail), long-form generative question answering (ELI5) and dialog response generation (ConvAI2). See the associated paper for more details.
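+
+To make the objective concrete, here is a deliberately simplified sketch of the corruption step. The `corrupt` helper below is invented purely for illustration; the actual noising used for pretraining lives in `fairseq/data/denoising_dataset.py` in this patch:
+
+```python
+import random
+
+def corrupt(sentences, mask_token='<mask>'):
+    """Toy BART-style noising: permute sentence order, then mask one word span."""
+    sentences = list(sentences)
+    random.shuffle(sentences)            # sentence permutation
+    words = ' '.join(sentences).split()
+    start = random.randrange(len(words))
+    span = random.randint(1, 3)          # text infilling: a whole span becomes one mask token
+    words[start:start + span] = [mask_token]
+    return ' '.join(words)
+
+original = ['BART is a denoising autoencoder.', 'It reconstructs corrupted text.']
+print(corrupt(original))
+# e.g. "It reconstructs corrupted text. BART is a <mask> autoencoder."
+# The seq2seq model is trained to map this corrupted input back to ' '.join(original).
+```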
+
+## Pre-trained models
+
+Model | Description | # params | Download
+---|---|---|---
+`bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
+`bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)
+
+## Results
+
+**[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
+_(dev set, single model, single-task finetuning)_
+
+Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
+---|---|---|---|---|---|---|---|---
+`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
+`bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2
+
+**[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
+_(dev set, no additional data used)_
+
+Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
+---|---|---
+`roberta.large` | 88.9/94.6 | 86.5/89.4
+`bart.large` | 88.8/94.6 | 86.1/89.2
+
+**[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
+_(dev set, no additional data used)_
+
+Model | R1 | R2 | RL
+---|---|---|---
+`BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
+`bart.large` | 44.16 | 21.28 | 40.90
+
+## Example usage
+
+##### Load BART from torch.hub (PyTorch >= 1.1):
+```python
+import torch
+bart = torch.hub.load('pytorch/fairseq', 'bart.large')
+bart.eval()  # disable dropout (or leave in train mode to finetune)
+```
+
+##### Load BART (for PyTorch 1.0 or custom models):
+```bash
+# Download bart.large model
+wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
+tar -xzvf bart.large.tar.gz
+```
+```python
+# Load the model in fairseq
+from fairseq.models.bart import BARTModel
+bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt')
+bart.eval()  # disable dropout (or leave in train mode to finetune)
+```
+
+##### Apply Byte-Pair Encoding (BPE) to input text:
+```python
+tokens = bart.encode('Hello world!')
+assert tokens.tolist() == [0, 31414, 232, 328, 2]
+bart.decode(tokens)  # 'Hello world!'
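+
+# Sentence pairs can be passed as additional arguments; per the
+# BARTHubInterface docstring this yields '<s> sent1 </s> sent2 </s>'.
+# (A sketch: the example sentence and the middle token ids depend on the
+# vocabulary, but the sequence always starts with <s> (id 0) and ends
+# with </s> (id 2).)
+pair = bart.encode('Hello world!', 'Nice to meet you.')
+assert pair[0].item() == 0 and pair[-1].item() == 2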
+```
+
+##### Extract features from BART:
+```python
+# Extract the last layer's features
+last_layer_features = bart.extract_features(tokens)
+assert last_layer_features.size() == torch.Size([1, 5, 1024])
+
+# Extract all layers' features from the decoder (layer 0 is the embedding layer)
+all_layers = bart.extract_features(tokens, return_all_hiddens=True)
+assert len(all_layers) == 13
+assert torch.all(all_layers[-1] == last_layer_features)
+```
+
+##### Use BART for sentence-pair classification tasks:
+```python
+# Download BART already finetuned for MNLI
+bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
+bart.eval()  # disable dropout for evaluation
+
+# Encode a pair of sentences and make a prediction
+tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
+bart.predict('mnli', tokens).argmax()  # 0: contradiction
+
+# Encode another pair of sentences
+tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
+bart.predict('mnli', tokens).argmax()  # 2: entailment
+```
+
+##### Register a new (randomly initialized) classification head:
+```python
+bart.register_classification_head('new_task', num_classes=3)
+logprobs = bart.predict('new_task', tokens)
+```
+
+##### Batched prediction:
+```python
+import torch
+from fairseq.data.data_utils import collate_tokens
+
+bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
+bart.eval()
+
+batch_of_pairs = [
+    ['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
+    ['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
+]
+
+batch = collate_tokens(
+    [bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
+)
+
+logprobs = bart.predict('mnli', batch)
+print(logprobs.argmax(dim=1))
+# tensor([0, 2])
+```
+
+##### Using the GPU:
+```python
+bart.cuda()
+bart.predict('new_task', tokens)
+```
+
+##### Evaluating the `bart.large.mnli` model:
+
+Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
+```python +label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'} +ncorrect, nsamples = 0, 0 +bart.cuda() +bart.eval() +with open('glue_data/MNLI/dev_matched.tsv') as fin: + fin.readline() + for index, line in enumerate(fin): + tokens = line.strip().split('\t') + sent1, sent2, target = tokens[8], tokens[9], tokens[-1] + tokens = bart.encode(sent1, sent2) + prediction = bart.predict('mnli', tokens).argmax().item() + prediction_label = label_map[prediction] + ncorrect += int(prediction_label == target) + nsamples += 1 + print('| Accuracy: ', float(ncorrect)/float(nsamples)) +# Expected output: 0.9010 +``` + +## Finetuning + +- [Finetuning on GLUE](README.glue.md) + +## Citation + +```bibtex +@article{lewis2019bart, + title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural +Language Generation, Translation, and Comprehension}, + author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and + Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov + and Luke Zettlemoyer }, + journal={arXiv preprint arXiv:1910.13461}, + year = {2019}, +} +``` diff --git a/examples/roberta/README.glue.md b/examples/roberta/README.glue.md index 52a974de6..d0a266b86 100644 --- a/examples/roberta/README.glue.md +++ b/examples/roberta/README.glue.md @@ -79,7 +79,7 @@ roberta = RobertaModel.from_pretrained( ) label_fn = lambda label: roberta.task.label_dictionary.string( - [label + roberta.task.target_dictionary.nspecial] + [label + roberta.task.label_dictionary.nspecial] ) ncorrect, nsamples = 0, 0 roberta.cuda() diff --git a/fairseq/data/__init__.py b/fairseq/data/__init__.py index aec8b819e..685059baa 100644 --- a/fairseq/data/__init__.py +++ b/fairseq/data/__init__.py @@ -9,11 +9,13 @@ from .fairseq_dataset import FairseqDataset from .base_wrapper_dataset import BaseWrapperDataset +from .append_token_dataset import AppendTokenDataset from .audio.raw_audio_dataset import FileAudioDataset from .backtranslation_dataset import BacktranslationDataset from .colorize_dataset import ColorizeDataset from .concat_dataset import ConcatDataset from .concat_sentences_dataset import ConcatSentencesDataset +from .denoising_dataset import DenoisingDataset from .id_dataset import IdDataset from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset from .language_pair_dataset import LanguagePairDataset @@ -33,6 +35,7 @@ from .prepend_token_dataset import PrependTokenDataset from .raw_label_dataset import RawLabelDataset from .replace_dataset import ReplaceDataset from .resampling_dataset import ResamplingDataset +from .roll_dataset import RollDataset from .round_robin_zip_datasets import RoundRobinZipDatasets from .sharded_dataset import ShardedDataset from .sort_dataset import SortDataset @@ -42,7 +45,6 @@ from .token_block_dataset import TokenBlockDataset from .transform_eos_dataset import TransformEosDataset from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset from .truncate_dataset import TruncateDataset -from .resampling_dataset import ResamplingDataset from .iterators import ( CountingIterator, @@ -52,12 +54,14 @@ from .iterators import ( ) __all__ = [ + 'AppendTokenDataset', 'BacktranslationDataset', 'BaseWrapperDataset', 'ColorizeDataset', 'ConcatDataset', 'ConcatSentencesDataset', 'CountingIterator', + 'DenoisingDataset', 'Dictionary', 'EpochBatchIterator', 'FairseqDataset', @@ -83,9 +87,10 @@ __all__ = [ 'PrependDataset', 'PrependTokenDataset', 'ReplaceDataset', + 'RollDataset', 'FileAudioDataset', 
     'RawLabelDataset',
-    'ResamplingDataset'
+    'ResamplingDataset',
     'RightPadDataset',
     'RoundRobinZipDatasets',
     'ShardedDataset',
diff --git a/fairseq/data/append_token_dataset.py b/fairseq/data/append_token_dataset.py
new file mode 100644
index 000000000..7298129f6
--- /dev/null
+++ b/fairseq/data/append_token_dataset.py
@@ -0,0 +1,42 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+from . import BaseWrapperDataset
+
+
+class AppendTokenDataset(BaseWrapperDataset):
+
+    def __init__(self, dataset, token=None):
+        super().__init__(dataset)
+        self.token = token
+        if token is not None:
+            self._sizes = np.array(dataset.sizes) + 1
+        else:
+            self._sizes = dataset.sizes
+
+    def __getitem__(self, idx):
+        item = self.dataset[idx]
+        if self.token is not None:
+            item = torch.cat([item, item.new([self.token])])
+        return item
+
+    @property
+    def sizes(self):
+        return self._sizes
+
+    def num_tokens(self, index):
+        n = self.dataset.num_tokens(index)
+        if self.token is not None:
+            n += 1
+        return n
+
+    def size(self, index):
+        n = self.dataset.size(index)
+        if self.token is not None:
+            n += 1
+        return n
diff --git a/fairseq/data/denoising_dataset.py b/fairseq/data/denoising_dataset.py
new file mode 100644
index 000000000..345969530
--- /dev/null
+++ b/fairseq/data/denoising_dataset.py
@@ -0,0 +1,386 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+import math
+
+from . import data_utils, FairseqDataset
+
+
+def collate(
+    samples,
+    pad_idx,
+    eos_idx,
+    vocab,
+    left_pad_source=False,
+    left_pad_target=False,
+    input_feeding=True,
+):
+    assert input_feeding
+    if len(samples) == 0:
+        return {}
+
+    def merge(key, left_pad, move_eos_to_beginning=False):
+        return data_utils.collate_tokens(
+            [s[key] for s in samples],
+            pad_idx, eos_idx, left_pad, move_eos_to_beginning,
+        )
+
+    id = torch.LongTensor([s['id'] for s in samples])
+    src_tokens = merge('source', left_pad=left_pad_source)
+    # sort by descending source length
+    src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
+    src_lengths, sort_order = src_lengths.sort(descending=True)
+    id = id.index_select(0, sort_order)
+    src_tokens = src_tokens.index_select(0, sort_order)
+
+    prev_output_tokens = None
+    target = None
+    if samples[0].get('target', None) is not None:
+        target = merge('target', left_pad=left_pad_target)
+        target = target.index_select(0, sort_order)
+        ntokens = sum(len(s['target']) for s in samples)
+
+        if input_feeding:
+            # we create a shifted version of targets for feeding the
+            # previous output token(s) into the next decoder step
+            prev_output_tokens = merge(
+                'target',
+                left_pad=left_pad_target,
+                move_eos_to_beginning=True,
+            )
+            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
+    else:
+        ntokens = sum(len(s['source']) for s in samples)
+
+    batch = {
+        'id': id,
+        'ntokens': ntokens,
+        'net_input': {
+            'src_tokens': src_tokens,
+            'src_lengths': src_lengths,
+        },
+        'target': target,
+        'nsentences': len(samples),
+    }
+    if prev_output_tokens is not None:
+        batch['net_input']['prev_output_tokens'] = prev_output_tokens
+
+    return batch
+
+
+class DenoisingDataset(FairseqDataset):
+    """
+    A wrapper around TokenBlockDataset that produces denoising examples for BART.
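+
+    Each sample is noised on the fly: whole-word spans are masked or
+    deleted (``args.mask``, ``args.replace_length``), random tokens are
+    inserted (``args.insert``), sentences are permuted
+    (``args.permute_sentences``) and the document may be rotated
+    (``args.rotate``); the original token sequence is kept as the target.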
+
+    Args:
+        dataset (TokenBlockDataset): dataset to wrap
+        sizes (List[int]): sentence lengths
+        vocab (~fairseq.data.Dictionary): vocabulary
+        mask_idx (int): dictionary index used for masked token
+        mask_whole_words: only mask whole words. This should be a byte mask
+            over vocab indices, indicating whether it is the beginning of a
+            word. We will extend any mask to encompass the whole word.
+        shuffle (bool, optional): shuffle the elements before batching.
+            Default: ``True``
+        seed: Seed for random number generator for reproducibility.
+        args: argparse arguments.
+    """
+
+    def __init__(
+        self,
+        dataset,
+        sizes,
+        vocab,
+        mask_idx,
+        mask_whole_words,
+        shuffle,
+        seed,
+        args
+    ):
+        self.dataset = dataset
+
+        self.sizes = sizes
+
+        self.vocab = vocab
+        self.shuffle = shuffle
+        self.seed = seed
+        self.mask_idx = mask_idx
+        self.mask_whole_word = mask_whole_words
+        self.mask_ratio = args.mask
+        self.random_ratio = args.mask_random
+        self.insert_ratio = args.insert
+        self.rotate_ratio = args.rotate
+        self.permute_sentence_ratio = args.permute_sentences
+
+        if args.bpe != 'gpt2':
+            self.full_stop_index = self.vocab.index(".")
+        else:
+            assert args.bpe == 'gpt2'
+            self.full_stop_index = self.vocab.index('13')  # GPT-2 BPE id for '.'
+
+        self.replace_length = args.replace_length
+        if self.replace_length not in [-1, 0, 1]:
+            raise ValueError(f'invalid arg: replace_length={self.replace_length}')
+        if args.mask_length not in ['subword', 'word', 'span-poisson']:
+            raise ValueError(f'invalid arg: mask-length={args.mask_length}')
+        if args.mask_length == 'subword' and args.replace_length not in [0, 1]:
+            raise ValueError('if using subwords, use replace-length=1 or 0')
+
+        self.mask_span_distribution = None
+        if args.mask_length == 'span-poisson':
+            _lambda = args.poisson_lambda
+
+            lambda_to_the_k = 1
+            e_to_the_minus_lambda = math.exp(-_lambda)
+            k_factorial = 1
+            ps = []
+            for k in range(0, 128):
+                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
+                lambda_to_the_k *= _lambda
+                k_factorial *= (k + 1)
+                if ps[-1] < 0.0000001:
+                    break
+            ps = torch.FloatTensor(ps)
+            self.mask_span_distribution = torch.distributions.Categorical(ps)
+
+        self.epoch = 0
+
+    def set_epoch(self, epoch, **unused):
+        self.epoch = epoch
+
+    def __getitem__(self, index):
+        with data_utils.numpy_seed(self.seed, self.epoch, index):
+            tokens = self.dataset[index]
+            assert tokens[-1] == self.vocab.eos()
+            source, target = tokens, tokens.clone()
+
+            if self.permute_sentence_ratio > 0.0:
+                source = self.permute_sentences(source, self.permute_sentence_ratio)
+
+            if self.mask_ratio > 0:
+                source = self.add_whole_word_mask(source, self.mask_ratio)
+
+            if self.insert_ratio > 0:
+                source = self.add_insertion_noise(source, self.insert_ratio)
+
+            if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
+                source = self.add_rolling_noise(source)
+
+        assert (source >= 0).all()
+        assert (source[1:-1] >= 1).all()
+        assert (source <= len(self.vocab)).all()
+        assert source[0] == self.vocab.bos()
+        assert source[-1] == self.vocab.eos()
+        return {
+            'id': index,
+            'source': source,
+            'target': target,
+        }
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def permute_sentences(self, source, p=1.0):
+        full_stops = (source == self.full_stop_index)
+        # Pretend it ends with a full stop so last span is a sentence
+        full_stops[-2] = 1
+
+        # Tokens that are full stops, where the previous token is not
+        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero() + 2
+        result = source.clone()
+
+        num_sentences = sentence_ends.size(0)
+        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
+        substitutions = torch.randperm(num_sentences)[:num_to_permute]
+        ordering = torch.arange(0, num_sentences)
+        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
+
+        # Ignore <bos> at start
+        index = 1
+        for i in ordering:
+            sentence = source[(sentence_ends[i - 1] if i > 0 else 1):sentence_ends[i]]
+            result[index:index + sentence.size(0)] = sentence
+            index += sentence.size(0)
+        return result
+
+    def word_starts(self, source):
+        if self.mask_whole_word is not None:
+            is_word_start = self.mask_whole_word.gather(0, source)
+        else:
+            is_word_start = torch.ones(source.size())
+        is_word_start[0] = 0
+        is_word_start[-1] = 0
+        return is_word_start
+
+    def add_whole_word_mask(self, source, p):
+        is_word_start = self.word_starts(source)
+        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
+        num_inserts = 0
+        if num_to_mask == 0:
+            return source
+
+        if self.mask_span_distribution is not None:
+            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
+
+            # Make sure we have enough to mask
+            cum_length = torch.cumsum(lengths, 0)
+            while cum_length[-1] < num_to_mask:
+                lengths = torch.cat([lengths, self.mask_span_distribution.sample(sample_shape=(num_to_mask,))], dim=0)
+                cum_length = torch.cumsum(lengths, 0)
+
+            # Trim to masking budget
+            i = 0
+            while cum_length[i] < num_to_mask:
+                i += 1
+            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
+            num_to_mask = i + 1
+            lengths = lengths[:num_to_mask]
+
+            # Handle 0-length mask (inserts) separately
+            lengths = lengths[lengths > 0]
+            num_inserts = num_to_mask - lengths.size(0)
+            num_to_mask -= num_inserts
+            if num_to_mask == 0:
+                return self.add_insertion_noise(source, num_inserts / source.size(0))
+
+            assert (lengths > 0).all()
+        else:
+            lengths = torch.ones((num_to_mask,)).long()
+        assert is_word_start[-1] == 0
+        word_starts = is_word_start.nonzero()
+        indices = word_starts[torch.randperm(word_starts.size(0))[:num_to_mask]].squeeze(1)
+        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
+
+        source_length = source.size(0)
+        assert source_length - 1 not in indices
+        to_keep = torch.ones(source_length, dtype=torch.bool)
+        is_word_start[-1] = 255  # acts as a long length, so spans don't go over the end of doc
+        if self.replace_length == 0:
+            to_keep[indices] = 0
+        else:
+            # keep index, but replace it with [MASK]
+            source[indices] = self.mask_idx
+            source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),))
+
+        if self.mask_span_distribution is not None:
+            assert len(lengths.size()) == 1
+            assert lengths.size() == indices.size()
+            lengths -= 1
+            while indices.size(0) > 0:
+                assert lengths.size() == indices.size()
+                lengths -= is_word_start[indices + 1].long()
+                uncompleted = lengths >= 0
+                indices = indices[uncompleted] + 1
+                mask_random = mask_random[uncompleted]
+                lengths = lengths[uncompleted]
+                if self.replace_length != -1:
+                    # delete token
+                    to_keep[indices] = 0
+                else:
+                    # keep index, but replace it with [MASK]
+                    source[indices] = self.mask_idx
+                    source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),))
+        else:
+            # A bit faster when all lengths are 1
+            while indices.size(0) > 0:
+                uncompleted = is_word_start[indices + 1] == 0
+                indices = indices[uncompleted] + 1
+                mask_random = mask_random[uncompleted]
+                if self.replace_length != -1:
+                    # delete token
+                    to_keep[indices] = 0
+                else:
+                    # keep index, but replace it with [MASK]
+                    source[indices] = self.mask_idx
+                    source[indices[mask_random]] = torch.randint(1, len(self.vocab), size=(mask_random.sum(),))
+
+            assert source_length - 1 not in indices
+
+        source = source[to_keep]
+
+        if num_inserts > 0:
+            source = self.add_insertion_noise(source, num_inserts / source.size(0))
+
+        return source
+
+    def add_permuted_noise(self, tokens, p):
+        num_words = len(tokens)
+        num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
+        substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
+        tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
+        return tokens
+
+    def add_rolling_noise(self, tokens):
+        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
+        tokens = torch.cat(
+            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
+            dim=0,
+        )
+        return tokens
+
+    def add_insertion_noise(self, tokens, p):
+        if p == 0.0:
+            return tokens
+
+        num_tokens = len(tokens)
+        n = int(math.ceil(num_tokens * p))
+
+        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
+        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
+        noise_mask[noise_indices] = 1
+        result = torch.LongTensor(n + len(tokens)).fill_(-1)
+
+        num_random = int(math.ceil(n * self.random_ratio))
+        result[noise_indices[num_random:]] = self.mask_idx
+        result[noise_indices[:num_random]] = torch.randint(low=1, high=len(self.vocab), size=(num_random,))
+
+        result[~noise_mask] = tokens
+
+        assert (result >= 0).all()
+        return result
+
+    def collater(self, samples):
+        """Merge a list of samples to form a mini-batch.
+
+        Args:
+            samples (List[dict]): samples to collate
+
+        Returns:
+            dict: a mini-batch of data
+        """
+        return collate(samples, self.vocab.pad(), self.vocab.eos(), self.vocab)
+
+    def num_tokens(self, index):
+        """Return the number of tokens in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
+        return self.sizes[index]
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
+        return self.sizes[index]
+
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
+        if self.shuffle:
+            indices = np.random.permutation(len(self))
+        else:
+            indices = np.arange(len(self))
+        return indices[np.argsort(self.sizes[indices], kind='mergesort')]
+
+    def prefetch(self, indices):
+        self.dataset.prefetch(indices)
+
+    @property
+    def supports_prefetch(self):
+        return (
+            hasattr(self.dataset, 'supports_prefetch')
+            and self.dataset.supports_prefetch
+        )
diff --git a/fairseq/data/encoders/utils.py b/fairseq/data/encoders/utils.py
new file mode 100644
index 000000000..a0e491c14
--- /dev/null
+++ b/fairseq/data/encoders/utils.py
@@ -0,0 +1,28 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
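+"""Whole-word masking helpers shared by the masked LM and denoising tasks."""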
+
+import torch
+from fairseq.data import encoders
+
+
+def get_whole_word_mask(args, dictionary):
+    bpe = encoders.build_bpe(args)
+    if bpe is not None:
+        def is_beginning_of_word(i):
+            if i < dictionary.nspecial:
+                # special elements are always considered beginnings
+                return True
+            tok = dictionary[i]
+            if tok.startswith('madeupword'):
+                return True
+            try:
+                return bpe.is_beginning_of_word(tok)
+            except ValueError:
+                return True
+        mask_whole_words = torch.ByteTensor(list(
+            map(is_beginning_of_word, range(len(dictionary)))
+        ))
+        return mask_whole_words
+    return None
diff --git a/fairseq/data/roll_dataset.py b/fairseq/data/roll_dataset.py
new file mode 100644
index 000000000..d07800d0f
--- /dev/null
+++ b/fairseq/data/roll_dataset.py
@@ -0,0 +1,19 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from . import BaseWrapperDataset
+
+
+class RollDataset(BaseWrapperDataset):
+
+    def __init__(self, dataset, shifts):
+        super().__init__(dataset)
+        self.shifts = shifts
+
+    def __getitem__(self, index):
+        item = self.dataset[index]
+        return torch.roll(item, self.shifts)
diff --git a/fairseq/models/bart/__init__.py b/fairseq/models/bart/__init__.py
new file mode 100644
index 000000000..a701923f7
--- /dev/null
+++ b/fairseq/models/bart/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .hub_interface import *  # noqa
+from .model import *  # noqa
diff --git a/fairseq/models/bart/hub_interface.py b/fairseq/models/bart/hub_interface.py
new file mode 100644
index 000000000..6c572ecd7
--- /dev/null
+++ b/fairseq/models/bart/hub_interface.py
@@ -0,0 +1,118 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fairseq.data import encoders
+
+
+class BARTHubInterface(nn.Module):
+    """A simple PyTorch Hub interface to BART.
+
+    Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart
+    """
+
+    def __init__(self, args, task, model):
+        super().__init__()
+        self.args = args
+        self.task = task
+        self.model = model
+
+        self.bpe = encoders.build_bpe(args)
+
+        # this is useful for determining the device
+        self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
+
+    @property
+    def device(self):
+        return self._float_tensor.device
+
+    def encode(self, sentence: str, *addl_sentences, no_separator=True) -> torch.LongTensor:
+        """
+        BPE-encode a sentence (or multiple sentences).
+
+        Every sequence begins with a beginning-of-sentence (`<s>`) symbol.
+        Every sentence ends with an end-of-sentence (`</s>`).
+
+        Example (single sentence): `<s> a b c </s>`
+        Example (sentence pair): `<s> d e f </s> 1 2 3 </s>`
+
+        The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE
+        requires leading spaces. For example::
+
+            >>> bart.encode('Hello world').tolist()
+            [0, 31414, 232, 2]
+            >>> bart.encode(' world').tolist()
+            [0, 232, 2]
+            >>> bart.encode('world').tolist()
+            [0, 8331, 2]
+        """
+        bpe_sentence = '<s> ' + self.bpe.encode(sentence) + ' </s>'
+        for s in addl_sentences:
+            bpe_sentence += (' </s>' if not no_separator else '')
+            bpe_sentence += ' ' + self.bpe.encode(s) + ' </s>'
+        tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False)
+        return tokens.long()
+
+    def decode(self, tokens: torch.LongTensor):
+        assert tokens.dim() == 1
+        tokens = tokens.numpy()
+        if tokens[0] == self.task.source_dictionary.bos():
+            tokens = tokens[1:]  # remove <s>
+        eos_mask = (tokens == self.task.source_dictionary.eos())
+        doc_mask = eos_mask[1:] & eos_mask[:-1]
+        sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
+        sentences = [self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences]
+        if len(sentences) == 1:
+            return sentences[0]
+        return sentences
+
+    def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool = False) -> torch.Tensor:
+        if tokens.dim() == 1:
+            tokens = tokens.unsqueeze(0)
+        if tokens.size(-1) > min(self.model.max_positions()):
+            raise ValueError('tokens exceeds maximum length: {} > {}'.format(
+                tokens.size(-1), self.model.max_positions()
+            ))
+        tokens = tokens.to(device=self.device)
+        prev_output_tokens = tokens.clone()
+        prev_output_tokens[:, 0] = tokens[:, -1]
+        prev_output_tokens[:, 1:] = tokens[:, :-1]
+        features, extra = self.model(
+            src_tokens=tokens,
+            src_lengths=None,
+            prev_output_tokens=prev_output_tokens,
+            features_only=True,
+            return_all_hiddens=return_all_hiddens,
+        )
+        if return_all_hiddens:
+            # convert from T x B x C -> B x T x C
+            inner_states = extra['inner_states']
+            return [inner_state.transpose(0, 1) for inner_state in inner_states]
+        else:
+            return features  # just the last layer's features
+
+    def register_classification_head(
+        self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
+    ):
+        self.model.register_classification_head(
+            name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
+        )
+
+    def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
+        if tokens.dim() == 1:
+            tokens = tokens.unsqueeze(0)
+        features = self.extract_features(tokens.to(device=self.device))
+        sentence_representation = features[
+            tokens.eq(self.task.source_dictionary.eos()), :
+        ].view(features.size(0), -1, features.size(-1))[:, -1, :]
+
+        logits = self.model.classification_heads[head](sentence_representation)
+        if return_logits:
+            return logits
+        return F.log_softmax(logits, dim=-1)
diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py
new file mode 100644
index 000000000..6ab3a62e5
--- /dev/null
+++ b/fairseq/models/bart/model.py
@@ -0,0 +1,246 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+""" +BART: Denoising Sequence-to-Sequence Pre-training for +Natural Language Generation, Translation, and Comprehension +""" + +import torch.nn as nn + +from fairseq import utils +from fairseq.models import ( + register_model, + register_model_architecture, +) +from fairseq.models.transformer import TransformerModel +from fairseq.modules.transformer_sentence_encoder import init_bert_params + +from .hub_interface import BARTHubInterface + + +@register_model('bart') +class BARTModel(TransformerModel): + + @classmethod + def hub_models(cls): + return { + 'bart.large': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz', + 'bart.large.mnli': 'http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz', + } + + def __init__(self, args, encoder, decoder): + super().__init__(args, encoder, decoder) + + # We follow BERT's random weight initialization + self.apply(init_bert_params) + + self.classification_heads = nn.ModuleDict() + + @staticmethod + def add_args(parser): + super(BARTModel, BARTModel).add_args(parser) + parser.add_argument( + '--max-source-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the source sequence' + ) + parser.add_argument( + '--max-target-positions', default=1024, type=int, metavar='N', + help='max number of tokens in the target sequence' + ) + parser.add_argument( + '--pooler-dropout', type=float, metavar='D', + help='dropout probability in the masked_lm pooler layers' + ) + parser.add_argument( + '--pooler-activation-fn', + choices=utils.get_available_activation_fns(), + help='activation function to use for pooler layer' + ) + + @property + def supported_targets(self): + return {'self'} + + def forward( + self, src_tokens, src_lengths, prev_output_tokens, + features_only=False, classification_head_name=None, **kwargs + ): + if classification_head_name is not None: + features_only = True + + encoder_out = self.encoder( + src_tokens, + src_lengths=src_lengths, + **kwargs, + ) + x, extra = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + features_only=features_only, + **kwargs, + ) + + if classification_head_name is not None: + sentence_representation = x[ + src_tokens.eq(self.encoder.dictionary.eos()), : + ].view(x.size(0), -1, x.size(-1))[:, -1, :] + x = self.classification_heads[classification_head_name]( + sentence_representation + ) + return x, extra + + @classmethod + def from_pretrained( + cls, + model_name_or_path, + checkpoint_file='model.pt', + data_name_or_path='.', + bpe='gpt2', + **kwargs, + ): + from fairseq import hub_utils + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + bpe=bpe, + load_checkpoint_heads=True, + **kwargs, + ) + return BARTHubInterface(x['args'], x['task'], x['models'][0]) + + def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs): + """Register a classification head.""" + print("Registering classification head: {0}".format(name)) + if name in self.classification_heads: + prev_num_classes = self.classification_heads[name].out_proj.out_features + prev_inner_dim = self.classification_heads[name].dense.out_features + if num_classes != prev_num_classes or inner_dim != prev_inner_dim: + print( + 'WARNING: re-registering head "{}" with num_classes {} (prev: {}) ' + 'and inner_dim {} (prev: {})'.format( + name, num_classes, prev_num_classes, inner_dim, prev_inner_dim + ) + ) + self.classification_heads[name] = BARTClassificationHead( + self.args.encoder_embed_dim, + 
inner_dim or self.args.encoder_embed_dim, + num_classes, + self.args.pooler_activation_fn, + self.args.pooler_dropout, + ) + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + + prefix = name + '.' if name != '' else '' + current_head_names = [] if not hasattr(self, 'classification_heads') else \ + self.classification_heads.keys() + + # Handle new classification heads present in the state dict. + keys_to_delete = [] + for k in state_dict.keys(): + if not k.startswith(prefix + 'classification_heads.'): + continue + + head_name = k[len(prefix + 'classification_heads.'):].split('.')[0] + num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0) + inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0) + + if getattr(self.args, 'load_checkpoint_heads', False): + if head_name not in current_head_names: + self.register_classification_head(head_name, num_classes, inner_dim) + else: + if head_name not in current_head_names: + print( + 'WARNING: deleting classification head ({}) from checkpoint ' + 'not present in current model: {}'.format(head_name, k) + ) + keys_to_delete.append(k) + elif ( + num_classes != self.classification_heads[head_name].out_proj.out_features + or inner_dim != self.classification_heads[head_name].dense.out_features + ): + print( + 'WARNING: deleting classification head ({}) from checkpoint ' + 'with different dimensions than current model: {}'.format(head_name, k) + ) + keys_to_delete.append(k) + for k in keys_to_delete: + del state_dict[k] + + # Copy any newly-added classification heads into the state dict + # with their current weights. + if hasattr(self, 'classification_heads'): + cur_state = self.classification_heads.state_dict() + for k, v in cur_state.items(): + if prefix + 'classification_heads.' + k not in state_dict: + print('Overwriting', prefix + 'classification_heads.' + k) + state_dict[prefix + 'classification_heads.' 
+ k] = v + + +class BARTClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim, + inner_dim, + num_classes, + activation_fn, + pooler_dropout, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, features, **kwargs): + x = features + x = self.dropout(x) + x = self.dense(x) + x = self.activation_fn(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@register_model_architecture('bart', 'bart_large') +def bart_large_architecture(args): + args.encoder_embed_path = getattr(args, 'encoder_embed_path', None) + args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024) + args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4*1024) + args.encoder_layers = getattr(args, 'encoder_layers', 12) + args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16) + args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False) + args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True) + args.decoder_embed_path = getattr(args, 'decoder_embed_path', None) + args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim) + args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim) + args.decoder_layers = getattr(args, 'decoder_layers', 12) + args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16) + args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False) + args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', True) + args.attention_dropout = getattr(args, 'attention_dropout', 0.) + args.relu_dropout = getattr(args, 'relu_dropout', 0.) 
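+    # BART-specific defaults below: all embeddings are tied and
+    # layer-normalized, embedding scaling is disabled, and GELU is used
+    # for activations.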
+ args.dropout = getattr(args, 'dropout', 0.1) + args.max_target_positions = getattr(args, 'max_target_positions', 1024) + args.max_source_positions = getattr(args, 'max_source_positions', 1024) + args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None) + args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0) + args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', True) + args.share_all_embeddings = getattr(args, 'share_all_embeddings', True) + + args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim) + args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim) + + args.no_scale_embedding = getattr(args, 'no_scale_embedding', True) + args.layernorm_embedding = getattr(args, 'layernorm_embedding', True) + + args.activation_fn = getattr(args, 'activation_fn', 'gelu') + args.pooler_activation_fn = getattr(args, 'pooler_activation_fn', 'tanh') + args.pooler_dropout = getattr(args, 'pooler_dropout', 0.0) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 16c8ec5bc..3c2b607a5 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -67,8 +67,9 @@ class TransformerModel(FairseqEncoderDecoderModel): } # fmt: on - def __init__(self, encoder, decoder): + def __init__(self, args, encoder, decoder): super().__init__(encoder, decoder) + self.args = args self.supports_align_args = True @staticmethod @@ -140,6 +141,10 @@ class TransformerModel(FairseqEncoderDecoderModel): help='which layers to *keep* when pruning as a comma-separated list') parser.add_argument('--decoder-layers-to-keep', default=None, help='which layers to *keep* when pruning as a comma-separated list') + parser.add_argument('--layernorm-embedding', action='store_true', + help='add layernorm to embedding') + parser.add_argument('--no-scale-embedding', action='store_true', + help='if True, dont scale embeddings') # fmt: on @classmethod @@ -195,7 +200,7 @@ class TransformerModel(FairseqEncoderDecoderModel): encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) - return cls(encoder, decoder) + return cls(args, encoder, decoder) @classmethod def build_encoder(cls, args, src_dict, embed_tokens): @@ -219,7 +224,7 @@ class TransformerAlignModel(TransformerModel): """ def __init__(self, encoder, decoder, args): - super().__init__(encoder, decoder) + super().__init__(args, encoder, decoder) self.alignment_heads = args.alignment_heads self.alignment_layer = args.alignment_layer self.full_context_alignment = args.full_context_alignment @@ -297,7 +302,9 @@ class TransformerEncoder(FairseqEncoder): self.max_source_positions = args.max_source_positions self.embed_tokens = embed_tokens - self.embed_scale = math.sqrt(embed_dim) + + self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) + self.embed_positions = PositionalEmbedding( args.max_source_positions, embed_dim, self.padding_idx, learned=args.encoder_learned_pos, @@ -315,17 +322,22 @@ class TransformerEncoder(FairseqEncoder): self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None - + if getattr(args, 'layernorm_embedding', False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None def forward_embedding(self, src_tokens): # embed tokens and positions embed = self.embed_scale * self.embed_tokens(src_tokens) if self.embed_positions is not None: x = embed + 
self.embed_positions(src_tokens) + if self.layernorm_embedding: + x = self.layernorm_embedding(x) x = F.dropout(x, p=self.dropout, training=self.training) return x, embed - def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False): + def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False, **unused): """ Args: src_tokens (LongTensor): tokens in the source language of shape @@ -422,6 +434,7 @@ class TransformerEncoder(FairseqEncoder): if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): weights_key = '{}.embed_positions.weights'.format(name) if weights_key in state_dict: + print('deleting {0}'.format(weights_key)) del state_dict[weights_key] state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) for i in range(len(self.layers)): @@ -466,7 +479,8 @@ class TransformerDecoder(FairseqIncrementalDecoder): self.max_target_positions = args.max_target_positions self.embed_tokens = embed_tokens - self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + + self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim) self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None @@ -507,6 +521,10 @@ class TransformerDecoder(FairseqIncrementalDecoder): self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None + if getattr(args, 'layernorm_embedding', False): + self.layernorm_embedding = LayerNorm(embed_dim) + else: + self.layernorm_embedding = None def forward( self, @@ -590,6 +608,10 @@ class TransformerDecoder(FairseqIncrementalDecoder): if positions is not None: x += positions + + if self.layernorm_embedding: + x = self.layernorm_embedding(x) + x = F.dropout(x, p=self.dropout, training=self.training) # B x T x C -> T x B x C @@ -758,6 +780,9 @@ def base_architecture(args): args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim) args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim) + args.no_scale_embedding = getattr(args, 'no_scale_embedding', False) + args.layernorm_embedding = getattr(args, 'layernorm_embedding', False) + @register_model_architecture('transformer', 'transformer_iwslt_de_en') def transformer_iwslt_de_en(args): diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index f04dd3603..59ed877ef 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -103,6 +103,10 @@ class TransformerLanguageModel(FairseqLanguageModel): help='LayerDrop probability for decoder') parser.add_argument('--decoder-layers-to-keep', default=None, help='which layers to *keep* when pruning as a comma-separated list') + parser.add_argument('--layernorm-embedding', action='store_true', + help='add layernorm to embedding') + parser.add_argument('--no-scale-embedding', action='store_true', + help='if True, dont scale embeddings') # fmt: on @classmethod @@ -190,6 +194,9 @@ def base_lm_architecture(args): args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False) args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False) + args.no_scale_embedding = getattr(args, 'no_scale_embedding', False) + args.layernorm_embedding = getattr(args, 'layernorm_embedding', False) + @register_model_architecture('transformer_lm', 'transformer_lm_big') def transformer_lm_big(args): diff --git a/fairseq/tasks/denoising.py b/fairseq/tasks/denoising.py new file mode 100644 index 000000000..736589c0a --- /dev/null +++ 
b/fairseq/tasks/denoising.py
@@ -0,0 +1,162 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+from fairseq.data import (
+    data_utils,
+    Dictionary,
+    AppendTokenDataset,
+    DenoisingDataset,
+    PrependTokenDataset,
+    StripTokenDataset,
+    TokenBlockDataset,
+)
+
+from fairseq.data.encoders.utils import get_whole_word_mask
+from . import FairseqTask, register_task
+
+
+@register_task('denoising')
+class DenoisingTask(FairseqTask):
+    """
+    Denoising task for applying sequence-to-sequence denoising (i.e., BART).
+    """
+
+    @staticmethod
+    def add_args(parser):
+        """Add task-specific arguments to the parser."""
+        parser.add_argument('data', help='path to data directory')
+        parser.add_argument('--tokens-per-sample', default=512, type=int,
+                            help='max number of total tokens over all segments'
+                                 ' per sample for dataset')
+        parser.add_argument('--raw-text', default=False, action='store_true',
+                            help='load raw text dataset')
+        parser.add_argument(
+            '--sample-break-mode', default="complete_doc", type=str,
+            help='mode for breaking sentences',
+        )
+        parser.add_argument(
+            '--mask', default=0.0, type=float,
+            help='fraction of words/subwords that will be masked',
+        )
+        parser.add_argument(
+            '--mask-random', default=0.0, type=float,
+            help='instead of using [MASK], use random token this often'
+        )
+        parser.add_argument(
+            '--insert', default=0.0, type=float,
+            help='insert this percentage of additional random tokens',
+        )
+        parser.add_argument(
+            '--permute', default=0.0, type=float,
+            help='take this proportion of subwords and permute them',
+        )
+        parser.add_argument(
+            '--rotate', default=0.5, type=float,
+            help='rotate this proportion of inputs',
+        )
+        parser.add_argument(
+            '--poisson-lambda', default=3.0, type=float,
+            help='lambda for the Poisson distribution used to sample mask span lengths (with --mask-length span-poisson)'
+        )
+        parser.add_argument(
+            '--permute-sentences', default=0.0, type=float,
+            help='shuffle this proportion of sentences in all inputs'
+        )
+        parser.add_argument(
+            '--mask-length', default="subword", type=str,
+            choices=['subword', 'word', 'span-poisson'],
+            help='mask length to choose'
+        )
+        parser.add_argument(
+            '--replace-length', default=-1, type=int,
+            help='when masking N tokens, replace with 0, 1, or N tokens (use -1 for N)'
+        )
+
+    def __init__(self, args, dictionary):
+        super().__init__(args)
+        self.dictionary = dictionary
+        self.seed = args.seed
+
+        # add mask token
+        self.mask_idx = self.dictionary.add_symbol('<mask>')
+
+    @classmethod
+    def setup_task(cls, args, **kwargs):
+        """Setup the task."""
+        dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
+        print('| dictionary: {} types'.format(len(dictionary)))
+        if not hasattr(args, 'shuffle_instance'):
+            args.shuffle_instance = False
+        return cls(args, dictionary)
+
+    def load_dataset(self, split, epoch=0, combine=False):
+        """Load a given dataset split.
+
+        Args:
+            split (str): name of the split (e.g., train, valid, test)
+        """
+
+        paths = self.args.data.split(':')
+        assert len(paths) > 0
+        data_path = paths[epoch % len(paths)]
+        split_path = os.path.join(data_path, split)
+
+        dataset = data_utils.load_indexed_dataset(
+            split_path,
+            self.dictionary,
+            self.args.dataset_impl,
+            combine=combine,
+        )
+        if dataset is None:
+            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))
+
+        dataset = StripTokenDataset(dataset, self.dictionary.eos())
+
+        # create continuous blocks of tokens
+        dataset = TokenBlockDataset(
+            dataset,
+            dataset.sizes,
+            self.args.tokens_per_sample - 2,  # one less for <s> and one for </s>
+            pad=self.dictionary.pad(),
+            eos=self.dictionary.eos(),
+            break_mode=self.args.sample_break_mode,
+            document_sep_len=0
+        )
+
+        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
+        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
+        dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())
+
+        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
+            if self.args.mask_length != 'subword' else None
+
+        self.datasets[split] = DenoisingDataset(
+            dataset, dataset.sizes, self.dictionary, self.mask_idx,
+            mask_whole_words, shuffle=self.args.shuffle_instance,
+            seed=self.seed, args=self.args
+        )
+        print(
+            "| Split: {0}, Loaded {1} samples of denoising_dataset".format(
+                split,
+                len(self.datasets[split]),
+            )
+        )
+
+    def max_positions(self):
+        """Return the max sentence length allowed by the task."""
+        return (self.args.max_source_positions, self.args.max_target_positions)
+
+    @property
+    def source_dictionary(self):
+        """Return the source :class:`~fairseq.data.Dictionary`."""
+        return self.dictionary
+
+    @property
+    def target_dictionary(self):
+        """Return the target :class:`~fairseq.data.Dictionary`."""
+        return self.dictionary
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index 538532b20..5438a2f07 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -7,7 +7,12 @@ import numpy as np
 import torch
 
 from fairseq import tokenizer
-from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
+from fairseq.data import (
+    data_utils,
+    FairseqDataset,
+    iterators,
+    Dictionary,
+)
 
 
 class FairseqTask(object):
diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py
index fbc308fd1..091a690f9 100644
--- a/fairseq/tasks/masked_lm.py
+++ b/fairseq/tasks/masked_lm.py
@@ -6,12 +6,10 @@
 import os
 
 import numpy as np
-import torch
 
 from fairseq.data import (
     data_utils,
     Dictionary,
-    encoders,
     IdDataset,
     MaskTokensDataset,
     NestedDictionaryDataset,
@@ -23,6 +21,7 @@ from fairseq.data import (
     TokenBlockDataset,
 )
 from fairseq.tasks import FairseqTask, register_task
+from fairseq.data.encoders.utils import get_whole_word_mask
 
 
 @register_task('masked_lm')
@@ -106,27 +105,8 @@ class MaskedLMTask(FairseqTask):
         dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
 
         # create masked input and targets
-        if self.args.mask_whole_words:
-            bpe = encoders.build_bpe(self.args)
-            assert bpe is not None
-
-            def is_beginning_of_word(i):
-                if i < self.source_dictionary.nspecial:
-                    # special elements are always considered beginnings
-                    return True
-                tok = self.source_dictionary[i]
-                if tok.startswith('madeupword'):
-                    return True
-                try:
-                    return bpe.is_beginning_of_word(tok)
-                except ValueError:
-                    return True
-
-            mask_whole_words = torch.ByteTensor(list(
-                map(is_beginning_of_word, 
range(len(self.source_dictionary))) - )) - else: - mask_whole_words = None + mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \ + if self.args.mask_whole_words else None src_dataset, tgt_dataset = MaskTokensDataset.apply_mask( dataset, diff --git a/fairseq/tasks/sentence_prediction.py b/fairseq/tasks/sentence_prediction.py index 1454978eb..49bb35c47 100644 --- a/fairseq/tasks/sentence_prediction.py +++ b/fairseq/tasks/sentence_prediction.py @@ -19,6 +19,7 @@ from fairseq.data import ( PrependTokenDataset, RawLabelDataset, RightPadDataset, + RollDataset, SortDataset, StripTokenDataset, TruncateDataset, @@ -51,11 +52,21 @@ class SentencePredictionTask(FairseqTask): parser.add_argument('--no-shuffle', action='store_true', default=False) parser.add_argument('--truncate-sequence', action='store_true', default=False, help='Truncate sequence to max_sequence_length') + parser.add_argument('--add-prev-output-tokens', action='store_true', default=False, + help='Add prev_output_tokens to sample, used for encoder-decoder arch') def __init__(self, args, data_dictionary, label_dictionary): super().__init__(args) self.dictionary = data_dictionary - self.label_dictionary = label_dictionary + self._label_dictionary = label_dictionary + if not hasattr(args, 'max_positions'): + self._max_positions = ( + args.max_source_positions, + args.max_target_positions, + ) + else: + self._max_positions = args.max_positions + args.tokens_per_sample = self._max_positions @classmethod def load_dictionary(cls, args, filename, source=True): @@ -72,8 +83,6 @@ class SentencePredictionTask(FairseqTask): def setup_task(cls, args, **kwargs): assert args.num_classes > 0, 'Must set --num-classes' - args.tokens_per_sample = args.max_positions - # load data dictionary data_dict = cls.load_dictionary( args, @@ -145,6 +154,15 @@ class SentencePredictionTask(FairseqTask): 'ntokens': NumelDataset(src_tokens, reduce=True), } + if self.args.add_prev_output_tokens: + prev_tokens_dataset = RightPadDataset( + RollDataset(src_tokens, 1), + pad_idx=self.dictionary.pad(), + ) + dataset['net_input'].update( + prev_output_tokens=prev_tokens_dataset, + ) + if not self.args.regression_target: label_dataset = make_dataset('label', self.target_dictionary) if label_dataset is not None: @@ -197,7 +215,7 @@ class SentencePredictionTask(FairseqTask): return model def max_positions(self): - return self.args.max_positions + return self._max_positions @property def source_dictionary(self): @@ -205,4 +223,8 @@ class SentencePredictionTask(FairseqTask): @property def target_dictionary(self): - return self.label_dictionary + return self.dictionary + + @property + def label_dictionary(self): + return self._label_dictionary