Add language models from Baevski & Auli (2018)

This commit is contained in:
Alexei Baevski 2019-03-04 11:16:24 -08:00 committed by Myle Ott
parent 02f2734ec5
commit 998ba4fb9f
15 changed files with 288 additions and 83 deletions

View File

@ -5,7 +5,7 @@ developers to train custom models for translation, summarization, language
modeling and other text generation tasks. It provides reference implementations
of various sequence-to-sequence models, including:
- **Convolutional Neural Networks (CNN)**
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/conv_lm/README.md)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/language_model/conv_lm/README.md)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Edunov et al. (2018): Classical Structured Prediction Losses for Sequence to Sequence Learning](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
@ -18,7 +18,8 @@ of various sequence-to-sequence models, including:
- [Vaswani et al. (2017): Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
- [Edunov et al. (2018): Understanding Back-Translation at Scale](examples/backtranslation/README.md)
- **_New_** [Shen et al. (2019) Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
- **_New_** [Baevski and Auli (2018): Adaptive Input Representations for Neural Language Modeling](examples/language_model/transformer_lm/README.md)
- **_New_** [Shen et al. (2019): Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
@ -88,7 +89,7 @@ We also have more detailed READMEs to reproduce results from specific papers:
- [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
- [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
- [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/conv_lm/README.md)
- [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/language_model/conv_lm/README.md)
# Join the fairseq community

View File

@ -14,6 +14,7 @@ import numpy as np
import torch
from fairseq import options, progress_bar, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module
@ -65,11 +66,22 @@ def main(parsed_args):
for arg in vars(parsed_args).keys():
if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
setattr(args, arg, getattr(parsed_args, arg))
# reduce tokens per sample by the required context window size
args.tokens_per_sample -= args.context_window
task = tasks.setup_task(args)
# Load dataset splits
task.load_dataset(args.gen_subset)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))
dataset = task.dataset(args.gen_subset)
if args.context_window > 0:
dataset = LMContextWindowDataset(
dataset=dataset,
tokens_per_sample=args.tokens_per_sample,
context_window=args.context_window,
pad_idx=task.source_dictionary.pad(),
)
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))
# Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
for model in models:
@ -84,7 +96,7 @@ def main(parsed_args):
print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))
itr = task.get_batch_iterator(
dataset=task.dataset(args.gen_subset),
dataset=dataset,
max_tokens=args.max_tokens or 36000,
max_sentences=args.max_sentences,
max_positions=utils.resolve_max_positions(*[
@ -97,7 +109,7 @@ def main(parsed_args):
).next_epoch_itr(shuffle=False)
gen_timer = StopwatchMeter()
scorer = SequenceScorer(task.target_dictionary)
scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)
score_sum = 0.
count = 0
@ -107,7 +119,11 @@ def main(parsed_args):
raise NotImplementedError
else:
bpe_cont = args.remove_bpe.rstrip()
bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
bpe_toks = set(
i
for i in range(len(task.source_dictionary))
if task.source_dictionary[i].endswith(bpe_cont)
)
bpe_len = len(bpe_cont)
else:
bpe_toks = None
@ -117,23 +133,28 @@ def main(parsed_args):
with progress_bar.build_progress_bar(args, itr) as t:
wps_meter = TimeMeter()
for sample in t:
sample = utils.move_to_cuda(sample) if use_cuda else sample
if 'net_input' not in sample:
continue
sample = utils.move_to_cuda(sample) if use_cuda else sample
gen_timer.start()
hypos = scorer.generate(models, sample)
gen_timer.stop(sample['ntokens'])
for hypos_i in hypos:
hypo = hypos_i[0]
pos_scores = hypo['positional_scores']
tokens = hypo['tokens']
tgt_len = tokens.numel()
pos_scores = hypo['positional_scores'].float()
skipped_toks = 0
if bpe_toks is not None:
for i in range(len(hypo['tokens']) - 1):
if hypo['tokens'][i].item() in bpe_toks:
for i in range(tgt_len - 1):
if tokens[i].item() in bpe_toks:
skipped_toks += 1
pos_scores[i + 1] += pos_scores[i]
pos_scores[i] = 0
@ -141,7 +162,7 @@ def main(parsed_args):
inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
if inf_scores.any():
print('| Skipping tokens with inf scores:',
task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
task.target_dictionary.string(tokens[inf_scores.nonzero()]))
pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum().cpu()
count += pos_scores.numel() - skipped_toks
@ -150,9 +171,9 @@ def main(parsed_args):
w = ''
word_prob = []
is_bpe = False
for i in range(len(hypo['tokens'])):
w_ind = hypo['tokens'][i].item()
w += task.dictionary[w_ind]
for i in range(len(tokens)):
w_ind = tokens[i].item()
w += task.source_dictionary[w_ind]
if bpe_toks is not None and w_ind in bpe_toks:
w = w[:-bpe_len]
is_bpe = True
@ -161,7 +182,7 @@ def main(parsed_args):
next_prob = None
ind = i + 1
while ind < len(hypo['tokens']):
while ind < len(tokens):
if pos_scores[ind].item() != 0:
next_prob = pos_scores[ind]
break
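
The eval_lm.py hunks above fold the score of each BPE continuation piece into the following piece, so perplexity is counted over whole words rather than subword units. A tiny self-contained illustration of that folding (toy tokens and scores; not fairseq code):

```python
import torch

# Toy version of the merging loop above: pieces that end with the BPE
# continuation marker pass their log-prob to the next piece and are
# excluded from the token count.
tokens = ['new@@', 'york', 'is', 'big']                # '@@' marks a continuation
pos_scores = torch.tensor([-1.0, -2.0, -0.5, -0.75])   # per-piece log-probs
bpe_cont = '@@'

skipped_toks = 0
for i in range(len(tokens) - 1):
    if tokens[i].endswith(bpe_cont):
        skipped_toks += 1
        pos_scores[i + 1] += pos_scores[i]
        pos_scores[i] = 0

print(pos_scores.tolist())          # [0.0, -3.0, -0.5, -0.75]
print(len(tokens) - skipped_toks)   # 3 word-level scores are counted
```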

examples/.gitignore vendored
View File

@ -1,3 +1,2 @@
*/*
!*/*.sh
!*/*.md

View File

@ -1,26 +0,0 @@
# Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)
## Pre-trained models
Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/gbw_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/gbw_test_lm.tar.bz2)
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wiki103_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wiki103_test_lm.tar.bz2)
## Example usage
See the [language modeling README](../language_model/README.md) for instructions on reproducing results for WikiText-103
using the `fconv_lm_dauphin_wikitext103` model architecture.
## Citation
```bibtex
@inproceedings{dauphin2017language,
title={Language Modeling with Gated Convolutional Networks},
author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David},
booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70},
pages={933--941},
year={2017},
organization={JMLR}
}
```

View File

@ -2,10 +2,10 @@
## Pre-trained models
Description | Dataset | Model | Test set(s)
---|---|---|---
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/gbw_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/gbw_test_lm.tar.bz2)
Convolutional <br> ([Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)) | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wiki103_fconv_lm.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wiki103_test_lm.tar.bz2)
Description | Parameters | Dataset | Model and Test set(s)
---|---:|---|---
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)
## Example usage
@ -16,6 +16,8 @@ These scripts provide an example of pre-processing data for the Language Modelin
Provides an example of pre-processing for [WikiText-103 language modeling task](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/):
Example usage:
Prepare data:
```
$ cd examples/language_model/
$ bash prepare-wikitext-103.sh
@ -27,17 +29,39 @@ $ TEXT=examples/language_model/wikitext-103
$ fairseq-preprocess --only-source \
--trainpref $TEXT/wiki.train.tokens --validpref $TEXT/wiki.valid.tokens --testpref $TEXT/wiki.test.tokens \
--destdir data-bin/wikitext-103
```
# Train the model:
# If it runs out of memory, try to reduce max-tokens and max-target-positions
$ mkdir -p checkpoints/wikitext-103
Train a transformer language model with adaptive inputs ([Baevski and Auli (2018): Adaptive Input Representations for Neural Language Modeling](transformer_lm/README.md)):
```
# If it runs out of memory, try to reduce max-tokens and tokens-per-sample
$ mkdir -p checkpoints/transformer_wikitext-103
$ fairseq-train --task language_modeling data-bin/wikitext-103 \
--save-dir checkpoints/transformer_wikitext-103 --arch transformer_lm_wiki103 \
--max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
--warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
--criterion adaptive_loss --max-tokens 3072 --update-freq 4 --tokens-per-sample 3072 --seed 1 \
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
# Evaluate:
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/transformer_wikitext-103/checkpoint_best.pt' \
--sample-break-mode complete --max-tokens 3072 --context-window 2560 --softmax-batch 1024
```
Train a convolutional language model ([Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](conv_lm/README.md)):
```
# If it runs out of memory, try to reduce max-tokens and tokens-per-sample
$ mkdir -p checkpoints/fconv_wikitext-103
$ fairseq-train --task language_modeling data-bin/wikitext-103 \
--save-dir checkpoints/fconv_wikitext-103 \
--max-epoch 35 --arch fconv_lm_dauphin_wikitext103 --optimizer nag \
--lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
--clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
--adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024 \
--ddp-backend=no_c10d
# Evaluate:
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/wiki103/checkpoint_best.pt'
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/fconv_wikitext-103/checkpoint_best.pt'
```

View File

@ -0,0 +1,19 @@
# Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)
## Example usage
See the [language modeling README](../README.md) for instructions on reproducing results for WikiText-103
using the `fconv_lm_dauphin_wikitext103` model architecture.
## Citation
```bibtex
@inproceedings{dauphin2017language,
title={Language Modeling with Gated Convolutional Networks},
author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David},
booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70},
pages={933--941},
year={2017},
organization={JMLR}
}
```

examples/language_model/prepare-wikitext-103.sh Executable file → Normal file
View File

@ -21,13 +21,13 @@ for ((i=0;i<${#URLS[@]};++i)); do
echo "$url not successfully downloaded."
exit -1
fi
fi
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
elif [ ${file: -4} == ".zip" ]; then
unzip $file
if [ ${file: -4} == ".tgz" ]; then
tar zxvf $file
elif [ ${file: -4} == ".tar" ]; then
tar xvf $file
elif [ ${file: -4} == ".zip" ]; then
unzip $file
fi
fi
done
cd ..

View File

@ -0,0 +1,26 @@
# Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)
## Pre-trained models
Description | Parameters | Dataset | Model and Test set(s)
---|---:|---|---
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)
## Example usage
See the [language modeling README](../README.md) for instructions on reproducing results for WikiText-103
using the `transformer_lm_wiki103` model architecture.
## Citation
```bibtex
@inproceedings{
baevski2018adaptive,
title={Adaptive Input Representations for Neural Language Modeling},
author={Alexei Baevski and Michael Auli},
booktitle={International Conference on Learning Representations},
year={2019},
url={https://openreview.net/forum?id=ByxZX20qFQ},
}
```

View File

@ -11,6 +11,7 @@ from .backtranslation_dataset import BacktranslationDataset
from .concat_dataset import ConcatDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .monolingual_dataset import MonolingualDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .token_block_dataset import TokenBlockDataset
@ -35,6 +36,7 @@ __all__ = [
'IndexedDataset',
'IndexedRawTextDataset',
'LanguagePairDataset',
'LMContextWindowDataset',
'MonolingualDataset',
'RoundRobinZipDatasets',
'ShardedIterator',

View File

@ -0,0 +1,83 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch
from fairseq.data.monolingual_dataset import MonolingualDataset
from . import FairseqDataset
class LMContextWindowDataset(FairseqDataset):
"""Wraps a MonolingualDataset and provides more context for evaluation."""
def __init__(self, dataset, tokens_per_sample, context_window, pad_idx):
assert isinstance(dataset, MonolingualDataset)
assert context_window > 0
self.dataset = dataset
self.tokens_per_sample = tokens_per_sample
self.context_window = context_window
self.pad_idx = pad_idx
self.prev_tokens = np.empty([0])
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples):
sample = self.dataset.collater(samples)
pad = self.pad_idx
max_sample_len = self.tokens_per_sample + self.context_window
bsz, tsz = sample['net_input']['src_tokens'].shape
start_idxs = [0] * bsz
toks = sample['net_input']['src_tokens']
lengths = sample['net_input']['src_lengths']
tgt = sample['target']
new_toks = np.empty([bsz, tsz + self.context_window], dtype=np.int64)
new_tgt = np.full([bsz, tsz + self.context_window], pad, dtype=np.int64)
sample_lens = toks.ne(pad).long().sum(dim=1).cpu()
for i in range(bsz):
sample_len = sample_lens[i]
extra = len(self.prev_tokens) + sample_len - max_sample_len
if extra > 0:
self.prev_tokens = self.prev_tokens[extra:]
pads = np.full(self.context_window - len(self.prev_tokens), pad)
new_toks[i] = np.concatenate([self.prev_tokens, toks[i].numpy(), pads])
new_tgt[i, len(self.prev_tokens):len(self.prev_tokens) + len(tgt[i])] = tgt[i]
start_idxs[i] = len(self.prev_tokens)
lengths[i] += len(self.prev_tokens)
self.prev_tokens = new_toks[i][new_toks[i] != pad][-self.context_window:]
sample['net_input']['src_tokens'] = torch.from_numpy(new_toks)
sample['target'] = torch.from_numpy(new_tgt)
sample['start_indices'] = start_idxs
return sample
def get_dummy_batch(self, *args, **kwargs):
return self.dataset.get_dummy_batch(*args, **kwargs)
def num_tokens(self, index):
return self.dataset.num_tokens(index)
def size(self, index):
return self.dataset.size(index)
def ordered_indices(self):
# NOTE we don't shuffle the data to retain access to the previous dataset elements
return np.arange(len(self.dataset))
@property
def supports_prefetch(self):
return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
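
For context, the eval_lm.py hunk earlier in this commit is what wires this wrapper up at evaluation time. A condensed sketch of that wiring (the helper name `wrap_eval_dataset` is made up here; `task` and `args` are assumed to come from fairseq's usual option parsing, as in eval_lm.py):

```python
from fairseq.data import LMContextWindowDataset

def wrap_eval_dataset(task, args):
    """Mirror of the eval_lm.py wiring shown earlier in this commit.

    Assumes args.tokens_per_sample has already been reduced by
    args.context_window, as eval_lm.py does before setting up the task.
    """
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,                           # must be a MonolingualDataset
            tokens_per_sample=args.tokens_per_sample,  # positions left for freshly scored tokens
            context_window=args.context_window,        # extra left context carried between batch rows
            pad_idx=task.source_dictionary.pad(),
        )
    return dataset
```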

View File

@ -196,7 +196,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
help='size of character embeddings')
parser.add_argument('--char-embedder-highway-layers', type=int, metavar='N', default=2,
help='number of highway layers for character token embedder')
parser.add_argument('--adaptive-input', default=False, action='store_true',
parser.add_argument('--adaptive-input', action='store_true',
help='if set, uses adaptive input')
parser.add_argument('--adaptive-input-factor', type=float, metavar='N',
help='adaptive input factor')
@ -811,6 +811,7 @@ def base_lm_architecture(args):
@register_model_architecture('transformer_lm', 'transformer_lm_big')
def transformer_lm_big(args):
args.decoder_layers = getattr(args, 'decoder_layers', 12)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
@ -819,7 +820,16 @@ def transformer_lm_big(args):
@register_model_architecture('transformer_lm', 'transformer_lm_wiki103')
def transformer_lm_wiki103(args):
args.decoder_layers = getattr(args, 'decoder_layers', 16)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
args.dropout = getattr(args, 'dropout', 0.3)
args.adaptive_input = getattr(args, 'adaptive_input', True)
args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', True)
args.adaptive_input_cutoff = getattr(args, 'adaptive_input_cutoff', '20000,60000')
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '20000,60000')
args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0.2)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
transformer_lm_big(args)
@ -855,6 +865,7 @@ def base_architecture(args):
args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
args.adaptive_input = getattr(args, 'adaptive_input', False)
args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)

View File

@ -80,7 +80,7 @@ class AdaptiveSoftmax(nn.Module):
else:
self.head = nn.Linear(input_dim, output_dim, bias=False)
self._make_tail(True, adaptive_inputs, tie_proj)
self._make_tail(adaptive_inputs, tie_proj)
def init_weights(m):
if hasattr(m, 'weight') and not isinstance(m, TiedLinear) and not isinstance(m, TiedHeadModule):
@ -89,15 +89,11 @@ class AdaptiveSoftmax(nn.Module):
self.apply(init_weights)
self.register_buffer('version', torch.LongTensor([1]))
# versions prior to 1 had a bug that offset indices on the head by 1
self.buggy_offset = 0
def _make_tail(self, fix_exponent, adaptive_inputs=None, tie_proj=False):
extra_denom = 1 if fix_exponent else 0
def _make_tail(self, adaptive_inputs=None, tie_proj=False):
self.tail = nn.ModuleList()
for i in range(len(self.cutoff) - 1):
dim = int(self.input_dim // self.factor ** (i + extra_denom))
dim = int(self.input_dim // self.factor ** (i + 1))
tied_emb, tied_proj = adaptive_inputs.weights_for_band(i + 1) \
if adaptive_inputs is not None else (None, None)
@ -123,9 +119,7 @@ class AdaptiveSoftmax(nn.Module):
def upgrade_state_dict_named(self, state_dict, name):
version_name = name + '.version'
if version_name not in state_dict:
self.buggy_offset = 1
self._make_tail(False)
state_dict[version_name] = torch.LongTensor([1])
raise Exception('This version of the model is no longer supported')
def adapt_target(self, target):
"""
@ -141,7 +135,7 @@ class AdaptiveSoftmax(nn.Module):
for i in range(len(self.cutoff) - 1):
mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
new_target[0][mask] = self.cutoff[0] + i - self.buggy_offset
new_target[0][mask] = self.cutoff[0] + i
if mask.any():
target_idxs.append(mask.nonzero().squeeze(1))
@ -194,7 +188,7 @@ class AdaptiveSoftmax(nn.Module):
head_sz = self.cutoff[0] + len(self.tail)
log_probs[:, :head_sz] = self.lsm(head_y)
tail_priors = log_probs[:, self.cutoff[0] - self.buggy_offset: head_sz - self.buggy_offset].clone()
tail_priors = log_probs[:, self.cutoff[0]: head_sz].clone()
for i in range(len(self.tail)):
start = self.cutoff[i]
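
With the buggy offset gone, the tail projections are always sized with the corrected exponent `factor ** (i + 1)`. A quick worked example of the resulting dimensions (the embedding dim of 1024 and factor of 4 are assumptions for illustration, not values taken from this diff):

```python
# Worked example of the tail projection sizes computed in _make_tail above.
input_dim = 1024    # assumed decoder output dimension
factor = 4.0        # assumed adaptive-softmax factor
n_tail_bands = 2    # e.g. cutoffs '20000,60000' as in transformer_lm_wiki103 above

dims = [int(input_dim // factor ** (i + 1)) for i in range(n_tail_bands)]
print(dims)  # [256, 64]: each successive tail band gets a 4x smaller projection
```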

View File

@ -8,6 +8,7 @@
import argparse
import torch
import sys
from fairseq.criterions import CRITERION_REGISTRY
from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
@ -140,7 +141,7 @@ def get_parser(desc, default_task='translation'):
choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
help='path to save logs for tensorboard, should match --logdir '
'of running tensorboard (default: no tensorboard logging)')
'of running tensorboard (default: no tensorboard logging)')
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
@ -374,6 +375,12 @@ def add_eval_lm_args(parser):
help='if set, outputs words and their predicted log probabilities to standard output')
group.add_argument('--output-word-stats', action='store_true',
help='if set, outputs word statistics such as word count, average probability, etc')
group.add_argument('--context-window', default=0, type=int, metavar='N',
help='ensures that every evaluated token has access to a context of at least this size,'
' if possible')
group.add_argument('--softmax-batch', default=sys.maxsize, type=int, metavar='N',
help='if BxT is more than this, will batch the softmax over vocab to this amount of tokens'
' in order to fit into GPU memory')
# fmt: on
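
eval_lm.py (earlier in this commit) subtracts `--context-window` from `--tokens-per-sample` before loading data, so a context-augmented row never exceeds the model's maximum positions. A small arithmetic sketch of how the two new options interact, using the values from the README example in this commit:

```python
# Values from the README's WikiText-103 example in this commit.
tokens_per_sample = 3072   # from the trained checkpoint's --tokens-per-sample
context_window = 2560      # --context-window passed to fairseq-eval-lm

scored_per_row = tokens_per_sample - context_window   # what eval_lm.py actually loads
fed_to_model = scored_per_row + context_window        # what LMContextWindowDataset emits

print(scored_per_row)  # 512 tokens are freshly scored per row
print(fed_to_model)    # 3072 tokens (scored tokens plus carried-over context) reach the model
```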

View File

@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
import torch
import sys
from fairseq import utils
@ -13,14 +14,40 @@ from fairseq import utils
class SequenceScorer(object):
"""Scores the target for a given source sentence."""
def __init__(self, tgt_dict):
def __init__(self, tgt_dict, softmax_batch=None):
self.pad = tgt_dict.pad()
self.softmax_batch = softmax_batch or sys.maxsize
assert self.softmax_batch > 0
@torch.no_grad()
def generate(self, models, sample, **kwargs):
"""Score a batch of translations."""
net_input = sample['net_input']
def batch_for_softmax(dec_out, target):
# assumes decoder_out[0] is the only thing needed (may not be correct for future models!)
first, rest = dec_out[0], dec_out[1:]
bsz, tsz, dim = first.shape
if bsz * tsz < self.softmax_batch:
yield dec_out, target, True
else:
flat = first.contiguous().view(1, -1, dim)
flat_tgt = target.contiguous().view(flat.shape[:-1])
s = 0
while s < flat.size(1):
e = s + self.softmax_batch
yield (flat[:, s:e],) + rest, flat_tgt[:, s:e], False
s = e
def gather_target_probs(probs, target):
probs = probs.gather(
dim=2,
index=target.unsqueeze(-1),
)
return probs
orig_target = sample['target']
# compute scores for each model in the ensemble
avg_probs = None
avg_attn = None
@ -29,7 +56,25 @@ class SequenceScorer(object):
decoder_out = model.forward(**net_input)
attn = decoder_out[1]
probs = model.get_normalized_probs(decoder_out, log_probs=len(models) == 1, sample=sample)
batched = batch_for_softmax(decoder_out, orig_target)
probs, idx = None, 0
for bd, tgt, is_single in batched:
sample['target'] = tgt
curr_prob = model.get_normalized_probs(bd, log_probs=len(models) == 1, sample=sample).data
if is_single:
probs = gather_target_probs(curr_prob, orig_target)
else:
if probs is None:
probs = curr_prob.new(orig_target.numel())
step = curr_prob.size(0) * curr_prob.size(1)
end = step + idx
tgt_probs = gather_target_probs(curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt)
probs[idx:end] = tgt_probs.view(-1)
idx = end
sample['target'] = orig_target
probs = probs.view(sample['target'].shape)
if avg_probs is None:
avg_probs = probs
else:
@ -45,20 +90,19 @@ class SequenceScorer(object):
avg_probs.log_()
if avg_attn is not None:
avg_attn.div_(len(models))
avg_probs = avg_probs.gather(
dim=2,
index=sample['target'].unsqueeze(-1),
).squeeze(2)
bsz = avg_probs.size(0)
hypos = []
for i in range(avg_probs.size(0)):
start_idxs = sample['start_indices'] if 'start_indices' in sample else [0] * bsz
for i in range(bsz):
# remove padding from ref
ref = utils.strip_pad(sample['target'][i, :], self.pad) if sample['target'] is not None else None
ref = utils.strip_pad(sample['target'][i, start_idxs[i]:], self.pad) \
if sample['target'] is not None else None
tgt_len = ref.numel()
avg_probs_i = avg_probs[i][:tgt_len]
avg_probs_i = avg_probs[i][start_idxs[i]:start_idxs[i] + tgt_len]
score_i = avg_probs_i.sum() / tgt_len
if avg_attn is not None:
avg_attn_i = avg_attn[i]
avg_attn_i = avg_attn[i, start_idxs[i]:]
_, alignment = avg_attn_i.max(dim=0)
else:
avg_attn_i = alignment = None
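
The new `--softmax-batch` path in `batch_for_softmax` flattens the decoder output over batch and time and runs the output projection and softmax over a limited number of positions at a time. A small self-contained illustration of just the chunking arithmetic (toy sizes; not the fairseq code path itself):

```python
import torch

# A 16 x 3072 batch is 49,152 positions; with --softmax-batch 1024 the
# vocabulary projection/softmax is computed 1,024 positions at a time.
bsz, tsz, dim = 16, 3072, 8
softmax_batch = 1024

first = torch.randn(bsz, tsz, dim)           # stand-in for decoder_out[0]
flat = first.contiguous().view(1, -1, dim)   # flatten batch and time into one row

chunks = []
s = 0
while s < flat.size(1):
    chunks.append(flat[:, s:s + softmax_batch])
    s += softmax_batch

assert sum(c.size(1) for c in chunks) == bsz * tsz
print(len(chunks))  # 48 chunks of 1,024 positions each
```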

View File

@ -223,7 +223,7 @@ class LanguageModelingTask(FairseqTask):
def source_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.output_dictionary
return self.dictionary
@property
def target_dictionary(self):