Make torch.hub interface automatically apply tokenization and BPE

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/926

Differential Revision: D18685772

Pulled By: myleott

fbshipit-source-id: 0f99d79ed6ee72e9d3ced786d75ab9504d0dfcf0
Myle Ott, 2019-11-25 13:38:42 -08:00 (committed by Facebook Github Bot)
parent fb3e1e36d2 · commit cb6c67bcdb
11 changed files with 111 additions and 22 deletions
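In practice, the change means hub consumers no longer pass tokenization/BPE settings by hand: the `archive_map` entries below carry them as defaults. A minimal before/after sketch; the explicit kwargs in the "before" call reflect the pre-change workflow and are illustrative, not copied from this commit:

```python
import torch

# Before: tokenization and BPE had to be requested explicitly, e.g.
# en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model',
#                        tokenizer='moses', bpe='fastbpe')

# After: the hub entry supplies those defaults, so this just works.
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
print(en2de.translate('Hello world', beam=5))  # 'Hallo Welt'
```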

README.md

@@ -55,8 +55,15 @@ Fairseq provides reference implementations of various sequence-to-sequence model
 - mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
 - extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers
 
-We also provide [pre-trained models](#pre-trained-models-and-examples) for several benchmark
-translation and language modeling datasets.
+We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
+with a convenient `torch.hub` interface:
+```python
+en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
+en2de.translate('Hello world', beam=5)
+# 'Hallo Welt'
+```
+See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
+and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
 
 ![Model](fairseq.gif)

examples/stories/README.md

@@ -6,7 +6,7 @@ The following commands provide an example of pre-processing data, training a mod
 Description | Dataset | Model | Test set(s)
 ---|---|---|---
-Stories with Convolutional Model <br> ([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://arxiv.org/abs/1805.04833) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2)
+Stories with Convolutional Model <br> ([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2)
 
 We provide sample stories generated by the [convolutional seq2seq model](https://dl.fbaipublicfiles.com/fairseq/data/seq2seq_stories.txt) and [fusion model](https://dl.fbaipublicfiles.com/fairseq/data/fusion_stories.txt) from [Fan et al., 2018](https://arxiv.org/abs/1805.04833). The corresponding prompts for the fusion model can be found [here](https://dl.fbaipublicfiles.com/fairseq/data/fusion_prompts.txt). Note that there are `<unk>` tokens in the file, as we modeled a small full vocabulary (no BPE or pre-training). We did not use these `<unk>` prompts for human evaluation.

fairseq/hub_utils.py

@@ -30,6 +30,20 @@ def from_pretrained(
         if data_name_or_path is not None and data_name_or_path in archive_map:
             data_name_or_path = archive_map[data_name_or_path]
 
+        # allow archive_map to set default arg_overrides (e.g., tokenizer, bpe)
+        # for each model
+        if isinstance(model_name_or_path, dict):
+            for k, v in model_name_or_path.items():
+                if k == 'checkpoint_file':
+                    checkpoint_file = v
+                elif (
+                    k != 'path'
+                    # only set kwargs that don't already have overrides
+                    and k not in kwargs
+                ):
+                    kwargs[k] = v
+            model_name_or_path = model_name_or_path['path']
+
     model_path = file_utils.load_archive_file(model_name_or_path)
 
     # convenience hack for loading data and BPE codes from model archive
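To make the unpacking concrete, here is a self-contained sketch of the dict-handling logic above. The entry values reuse names from this commit, but the surrounding variables (`entry`, the caller override, the `'model1.pt'` checkpoint name) are illustrative:

```python
entry = {
    'path': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz',
    'checkpoint_file': 'model1.pt',
    'tokenizer': 'moses',
    'bpe': 'fastbpe',
}
kwargs = {'bpe': 'subword_nmt'}  # pretend the caller passed an explicit override
checkpoint_file = 'model.pt'     # the from_pretrained default

for k, v in entry.items():
    if k == 'checkpoint_file':
        checkpoint_file = v
    elif k != 'path' and k not in kwargs:
        kwargs[k] = v

assert checkpoint_file == 'model1.pt'
# caller-supplied kwargs win; only missing keys are filled from the archive entry
assert kwargs == {'bpe': 'subword_nmt', 'tokenizer': 'moses'}
```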

fairseq/models/fconv.py

@@ -43,10 +43,18 @@ class FConvModel(FairseqEncoderDecoderModel):
     @classmethod
     def hub_models(cls):
+
+        def moses_subword(path):
+            return {
+                'path': path,
+                'tokenizer': 'moses',
+                'bpe': 'subword_nmt',
+            }
+
         return {
-            'conv.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2',
-            'conv.wmt14.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2',
-            'conv.wmt17.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2',
+            'conv.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2'),
+            'conv.wmt14.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2'),
+            'conv.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2'),
         }
 
     def __init__(self, encoder, decoder):
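With these defaults in place, loading a convolutional translation model through `torch.hub` should apply Moses tokenization and subword-nmt BPE without extra arguments. A hedged usage sketch (output not verified here):

```python
import torch

en2fr = torch.hub.load('pytorch/fairseq', 'conv.wmt14.en-fr')
# moses tokenization and subword_nmt BPE are applied automatically
print(en2fr.translate('Hello world!', beam=5))
```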

fairseq/models/fconv_self_att.py

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
+import os
 
 import torch
 import torch.nn as nn
@@ -33,7 +34,18 @@ class FConvModelSelfAtt(FairseqEncoderDecoderModel):
     @classmethod
     def hub_models(cls):
         return {
-            'conv.stories': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2',
+            'conv.stories.pretrained': {
+                'path': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz',
+                'checkpoint_file': 'pretrained_checkpoint.pt',
+                'tokenizer': 'nltk',
+            },
+            'conv.stories': {
+                'path': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz',
+                'checkpoint_file': 'fusion_checkpoint.pt',
+                'tokenizer': 'nltk',
+                'pretrained': 'True',
+                'pretrained_checkpoint': './pretrained_checkpoint.pt',
+            },
             # Test set containing dictionaries
             'data.stories': 'https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2',
         }
@@ -97,6 +109,10 @@ class FConvModelSelfAtt(FairseqEncoderDecoderModel):
         pretrained = eval(args.pretrained)
         if pretrained:
             print("| loading pretrained model")
+            if not os.path.exists(args.pretrained_checkpoint):
+                new_pretrained_checkpoint = os.path.join(args.data, args.pretrained_checkpoint)
+                if os.path.exists(new_pretrained_checkpoint):
+                    args.pretrained_checkpoint = new_pretrained_checkpoint
             trained_model = checkpoint_utils.load_model_ensemble(
                 filenames=[args.pretrained_checkpoint],
                 task=task,
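Putting the two hunks together: `conv.stories` now names its fusion checkpoint, the NLTK tokenizer, and a relative path to the pretrained checkpoint, which the `os.path.join(args.data, ...)` fallback resolves inside the downloaded archive. A hedged sketch of loading the story model (the prompt and sampling arguments are illustrative, and NLTK must be installed):

```python
import torch

stories = torch.hub.load('pytorch/fairseq', 'conv.stories')
# NLTK tokenization is applied automatically; there is no BPE for this model
print(stories.sample('A scientist discovers a door in the desert.', beam=5))
```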

fairseq/models/transformer.py

@@ -53,18 +53,33 @@ class TransformerModel(FairseqEncoderDecoderModel):
     @classmethod
     def hub_models(cls):
         # fmt: off
+
+        def moses_subword(path):
+            return {
+                'path': path,
+                'tokenizer': 'moses',
+                'bpe': 'subword_nmt',
+            }
+
+        def moses_fastbpe(path):
+            return {
+                'path': path,
+                'tokenizer': 'moses',
+                'bpe': 'fastbpe',
+            }
+
         return {
-            'transformer.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2',
+            'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'),
             'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
-            'transformer.wmt18.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz',
-            'transformer.wmt19.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz',
-            'transformer.wmt19.en-ru': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz',
-            'transformer.wmt19.de-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz',
-            'transformer.wmt19.ru-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz',
-            'transformer.wmt19.en-de.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz',
-            'transformer.wmt19.en-ru.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz',
-            'transformer.wmt19.de-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz',
-            'transformer.wmt19.ru-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz',
+            'transformer.wmt18.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz'),
+            'transformer.wmt19.en-de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz'),
+            'transformer.wmt19.en-ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz'),
+            'transformer.wmt19.de-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz'),
+            'transformer.wmt19.ru-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz'),
+            'transformer.wmt19.en-de.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz'),
+            'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'),
+            'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'),
+            'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'),
         }
         # fmt: on
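Because `from_pretrained` only copies archive defaults for keys the caller has not set (the "only set kwargs that don't already have overrides" check in hub_utils above), explicit kwargs still take precedence. A brief sketch under that assumption:

```python
import torch

# Defaults from the archive_map entry: tokenizer='moses', bpe='fastbpe'.
en2ru = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-ru.single_model')
print(en2ru.translate('Hello world!', beam=5))

# An explicit kwarg overrides the archive default (redundant here, for illustration).
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model', bpe='fastbpe')
```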

fairseq/models/transformer_lm.py

@@ -26,12 +26,20 @@ class TransformerLanguageModel(FairseqLanguageModel):
     @classmethod
     def hub_models(cls):
+
+        def moses_fastbpe(path):
+            return {
+                'path': path,
+                'tokenizer': 'moses',
+                'bpe': 'fastbpe',
+            }
+
         return {
             'transformer_lm.gbw.adaptive_huge': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2',
             'transformer_lm.wiki103.adaptive': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2',
-            'transformer_lm.wmt19.en': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2',
-            'transformer_lm.wmt19.de': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2',
-            'transformer_lm.wmt19.ru': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2',
+            'transformer_lm.wmt19.en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2'),
+            'transformer_lm.wmt19.de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2'),
+            'transformer_lm.wmt19.ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2'),
         }
 
     def __init__(self, decoder):
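The same convenience extends to language models: sampling from the WMT19 LMs picks up Moses tokenization and fastBPE automatically. A hedged sketch, with a prompt and sampling kwargs that follow fairseq's usual generation options rather than anything in this commit:

```python
import torch

en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en')
# moses tokenization and fastBPE are applied to the prompt and stripped from the output
print(en_lm.sample('Barack Obama', beam=1, sampling=True, sampling_topk=10, temperature=0.8))
```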

fairseq/modules/linearized_convolution.py

@@ -80,7 +80,7 @@ class LinearizedConvolution(ConvTBC):
             kw = self.kernel_size[0]
             weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
             assert weight.size() == (self.out_channels, kw, self.in_channels)
-            self._linearized_weight = weight.view(self.out_channels, -1)
+            self._linearized_weight = torch.nn.Parameter(weight.view(self.out_channels, -1))
         return self._linearized_weight
 
     def _clear_linearized_weight(self, *args):

setup.py

@@ -104,6 +104,16 @@ if 'clean' in sys.argv[1:]:
     subprocess.run(['rm -f fairseq/*.so fairseq/**/*.so'], shell=True)
 
 
+if 'test' in sys.argv[1:]:
+    try:
+        import fairseq.data.token_block_utils_fast
+    except (ImportError, ModuleNotFoundError):
+        raise Exception(
+            'Please install Cython components with `python setup.py build_ext --inplace` '
+            'before running unit tests.'
+        )
+
+
 setup(
     name='fairseq',
     version='0.8.0',

tests/test_bmuf.py

@@ -1,10 +1,16 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
 import argparse
-from multiprocessing import Manager
 import random
 import unittest
+from multiprocessing import Manager
 
 import torch
 import torch.nn as nn
 
 from fairseq import distributed_utils, optim
@@ -143,3 +149,7 @@ class TestBMUF(unittest.TestCase):
     def assertAlmostEqual(self, t1, t2):
         self.assertEqual(t1.size(), t2.size(), "size mismatch")
         self.assertLess((t1 - t2).abs().max(), 1e-4)
+
+
+if __name__ == '__main__':
+    unittest.main()

tests/test_memory_efficient_fp16.py

@@ -35,6 +35,7 @@ class TestMemoryEfficientFP16(unittest.TestCase):
                 fp16_scale_window=1,
                 fp16_scale_tolerance=1,
                 threshold_loss_scale=1,
+                min_loss_scale=1e-4,
             ),
             params,
             optimizer,