mirror of
https://github.com/facebookresearch/fairseq.git
Make torch.hub interface automatically apply tokenization and BPE
Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/926

Differential Revision: D18685772

Pulled By: myleott

fbshipit-source-id: 0f99d79ed6ee72e9d3ced786d75ab9504d0dfcf0
This commit is contained in:
parent fb3e1e36d2
commit cb6c67bcdb

README.md (11 lines changed)
@@ -55,8 +55,15 @@ Fairseq provides reference implementations of various sequence-to-sequence models
 - mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
 - extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers
 
-We also provide [pre-trained models](#pre-trained-models-and-examples) for several benchmark
-translation and language modeling datasets.
+We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
+with a convenient `torch.hub` interface:
+```python
+en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
+en2de.translate('Hello world', beam=5)
+# 'Hallo Welt'
+```
+See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
+and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
 
 ![Model](fairseq.gif)
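For context, the behavior this commit targets: before the change, callers had to pass tokenizer and BPE overrides to `torch.hub.load` by hand; afterwards the `hub_models()` entries supply them. A minimal sketch of the difference (the keyword overrides shown are the ones introduced in this diff):

```python
import torch

# Old style: tokenization/BPE requested explicitly by the caller.
en2de = torch.hub.load(
    'pytorch/fairseq', 'transformer.wmt19.en-de.single_model',
    tokenizer='moses', bpe='fastbpe',
)

# New style: the archive_map entry carries those defaults, so this suffices.
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
print(en2de.translate('Hello world', beam=5))  # 'Hallo Welt'
```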
examples/stories/README.md

@@ -6,7 +6,7 @@ The following commands provide an example of pre-processing data, training a model
 
 Description | Dataset | Model | Test set(s)
 ---|---|---|---
-Stories with Convolutional Model <br> ([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://arxiv.org/abs/1805.04833) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2)
+Stories with Convolutional Model <br> ([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2)
 
 We provide sample stories generated by the [convolutional seq2seq model](https://dl.fbaipublicfiles.com/fairseq/data/seq2seq_stories.txt) and [fusion model](https://dl.fbaipublicfiles.com/fairseq/data/fusion_stories.txt) from [Fan et al., 2018](https://arxiv.org/abs/1805.04833). The corresponding prompts for the fusion model can be found [here](https://dl.fbaipublicfiles.com/fairseq/data/fusion_prompts.txt). Note that there are unk in the file, as we modeled a small full vocabulary (no BPE or pre-training). We did not use these unk prompts for human evaluation.
fairseq/hub_utils.py

@@ -30,6 +30,20 @@ def from_pretrained(
     if data_name_or_path is not None and data_name_or_path in archive_map:
         data_name_or_path = archive_map[data_name_or_path]
 
+    # allow archive_map to set default arg_overrides (e.g., tokenizer, bpe)
+    # for each model
+    if isinstance(model_name_or_path, dict):
+        for k, v in model_name_or_path.items():
+            if k == 'checkpoint_file':
+                checkpoint_file = v
+            elif (
+                k != 'path'
+                # only set kwargs that don't already have overrides
+                and k not in kwargs
+            ):
+                kwargs[k] = v
+        model_name_or_path = model_name_or_path['path']
+
     model_path = file_utils.load_archive_file(model_name_or_path)
 
     # convenience hack for loading data and BPE codes from model archive
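The effect of the hunk above: an `archive_map` value may now be either a plain URL string or a dict of defaults. A standalone sketch of that resolution logic, using an illustrative `resolve()` helper that is not part of fairseq's API:

```python
# Minimal sketch of the archive_map resolution added above (illustrative only).
archive_map = {
    'conv.wmt14.en-fr': {
        'path': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2',
        'tokenizer': 'moses',
        'bpe': 'subword_nmt',
    },
}

def resolve(model_name_or_path, checkpoint_file='model.pt', **kwargs):
    # look the name up in the archive map, as from_pretrained() does
    entry = archive_map.get(model_name_or_path, model_name_or_path)
    if isinstance(entry, dict):
        for k, v in entry.items():
            if k == 'checkpoint_file':
                checkpoint_file = v
            elif k != 'path' and k not in kwargs:
                # user-supplied overrides win over archive defaults
                kwargs[k] = v
        entry = entry['path']
    return entry, checkpoint_file, kwargs

path, ckpt, overrides = resolve('conv.wmt14.en-fr')
assert overrides == {'tokenizer': 'moses', 'bpe': 'subword_nmt'}
```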
fairseq/models/fconv.py

@@ -43,10 +43,18 @@ class FConvModel(FairseqEncoderDecoderModel):
     @classmethod
     def hub_models(cls):
-        return {
-            'conv.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2',
-            'conv.wmt14.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2',
-            'conv.wmt17.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2',
+
+        def moses_subword(path):
+            return {
+                'path': path,
+                'tokenizer': 'moses',
+                'bpe': 'subword_nmt',
+            }
+
+        return {
+            'conv.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2'),
+            'conv.wmt14.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2'),
+            'conv.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2'),
         }
 
     def __init__(self, encoder, decoder):
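With the `moses_subword()` defaults attached, loading one of these convolutional models through `torch.hub` should apply Moses tokenization and subword-nmt BPE automatically, mirroring the README example above. A hedged usage sketch:

```python
import torch

# The 'conv.wmt14.en-fr' entry now carries tokenizer/bpe defaults, so raw
# strings can be translated without manual preprocessing.
en2fr = torch.hub.load('pytorch/fairseq', 'conv.wmt14.en-fr')
print(en2fr.translate('Hello world!', beam=5))
```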
fairseq/models/fconv_self_att.py

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
+import os
 
 import torch
 import torch.nn as nn
@@ -33,7 +34,18 @@ class FConvModelSelfAtt(FairseqEncoderDecoderModel):
     @classmethod
     def hub_models(cls):
         return {
-            'conv.stories': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2',
+            'conv.stories.pretrained': {
+                'path': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz',
+                'checkpoint_file': 'pretrained_checkpoint.pt',
+                'tokenizer': 'nltk',
+            },
+            'conv.stories': {
+                'path': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz',
+                'checkpoint_file': 'fusion_checkpoint.pt',
+                'tokenizer': 'nltk',
+                'pretrained': 'True',
+                'pretrained_checkpoint': './pretrained_checkpoint.pt',
+            },
             # Test set containing dictionaries
             'data.stories': 'https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2',
         }
@@ -97,6 +109,10 @@ class FConvModelSelfAtt(FairseqEncoderDecoderModel):
         pretrained = eval(args.pretrained)
         if pretrained:
             print("| loading pretrained model")
+            if not os.path.exists(args.pretrained_checkpoint):
+                new_pretrained_checkpoint = os.path.join(args.data, args.pretrained_checkpoint)
+                if os.path.exists(new_pretrained_checkpoint):
+                    args.pretrained_checkpoint = new_pretrained_checkpoint
             trained_model = checkpoint_utils.load_model_ensemble(
                 filenames=[args.pretrained_checkpoint],
                 task=task,
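The dict-valued stories entries combine several of the new override keys: `checkpoint_file` selects which checkpoint inside the archive to load, and the path-fallback hunk above lets `pretrained_checkpoint` be located inside the downloaded archive directory. A usage sketch (the `sample()` call follows fairseq's generator hub interface; the prompt and generation arguments are illustrative):

```python
import torch

# 'conv.stories' now resolves to fusion_checkpoint.pt with NLTK tokenization
# applied, and pulls pretrained_checkpoint.pt from the same archive.
stories = torch.hub.load('pytorch/fairseq', 'conv.stories')
print(stories.sample('A dragon arrives in the kingdom.', beam=1, max_len_b=200))
```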
fairseq/models/transformer.py

@@ -53,18 +53,33 @@ class TransformerModel(FairseqEncoderDecoderModel):
     @classmethod
     def hub_models(cls):
         # fmt: off
+
+        def moses_subword(path):
+            return {
+                'path': path,
+                'tokenizer': 'moses',
+                'bpe': 'subword_nmt',
+            }
+
+        def moses_fastbpe(path):
+            return {
+                'path': path,
+                'tokenizer': 'moses',
+                'bpe': 'fastbpe',
+            }
+
         return {
-            'transformer.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2',
+            'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'),
             'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
-            'transformer.wmt18.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz',
-            'transformer.wmt19.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz',
-            'transformer.wmt19.en-ru': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz',
-            'transformer.wmt19.de-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz',
-            'transformer.wmt19.ru-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz',
-            'transformer.wmt19.en-de.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz',
-            'transformer.wmt19.en-ru.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz',
-            'transformer.wmt19.de-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz',
-            'transformer.wmt19.ru-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz',
+            'transformer.wmt18.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz'),
+            'transformer.wmt19.en-de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz'),
+            'transformer.wmt19.en-ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz'),
+            'transformer.wmt19.de-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz'),
+            'transformer.wmt19.ru-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz'),
+            'transformer.wmt19.en-de.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz'),
+            'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'),
+            'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'),
+            'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'),
         }
         # fmt: on
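Note that `transformer.wmt16.en-de` keeps a bare URL string in this diff, so it still loads without automatic tokenization. A usage sketch for one of the ensemble entries; the colon-separated `checkpoint_file` syntax follows fairseq's documented hub usage, though the exact file names inside the archive are assumptions here:

```python
import torch

# The wmt19 ensembles now carry moses+fastbpe defaults; loading all four
# checkpoints uses the colon-separated checkpoint_file syntax.
en2de = torch.hub.load(
    'pytorch/fairseq', 'transformer.wmt19.en-de',
    checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
)
print(en2de.translate('Machine learning is great!', beam=5))
```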
fairseq/models/transformer_lm.py

@@ -26,12 +26,20 @@ class TransformerLanguageModel(FairseqLanguageModel):
 
     @classmethod
     def hub_models(cls):
+
+        def moses_fastbpe(path):
+            return {
+                'path': path,
+                'tokenizer': 'moses',
+                'bpe': 'fastbpe',
+            }
+
         return {
             'transformer_lm.gbw.adaptive_huge': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2',
             'transformer_lm.wiki103.adaptive': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2',
-            'transformer_lm.wmt19.en': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2',
-            'transformer_lm.wmt19.de': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2',
-            'transformer_lm.wmt19.ru': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2',
+            'transformer_lm.wmt19.en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2'),
+            'transformer_lm.wmt19.de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2'),
+            'transformer_lm.wmt19.ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2'),
         }
 
     def __init__(self, decoder):
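Language models get the same treatment, so sampling can start from raw text. A sketch mirroring the pattern used in fairseq's WMT19 LM documentation (the sampling arguments are illustrative):

```python
import torch

# moses+fastbpe are applied automatically when loading the WMT19 English LM.
en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en')
print(en_lm.sample('Barack Obama', beam=1, sampling=True, sampling_topk=10))
```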
fairseq/modules/linearized_convolution.py

@@ -80,7 +80,7 @@ class LinearizedConvolution(ConvTBC):
             kw = self.kernel_size[0]
             weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
             assert weight.size() == (self.out_channels, kw, self.in_channels)
-            self._linearized_weight = weight.view(self.out_channels, -1)
+            self._linearized_weight = torch.nn.Parameter(weight.view(self.out_channels, -1))
         return self._linearized_weight
 
     def _clear_linearized_weight(self, *args):
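A plausible reason for wrapping the cached weight in `torch.nn.Parameter`: registered parameters follow module-wide dtype/device conversions (e.g. `model.half()` for fp16 hub models), while bare tensor attributes do not. A self-contained demonstration of that PyTorch behavior, unrelated to fairseq itself:

```python
import torch
import torch.nn as nn

class Cache(nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = torch.ones(2, 2)                # bare tensor attribute
        self.param = nn.Parameter(torch.ones(2, 2))  # registered parameter

m = Cache().half()
print(m.plain.dtype)  # torch.float32 -- bare attributes are left untouched
print(m.param.dtype)  # torch.float16 -- parameters follow .half()/.to() calls
```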
setup.py (10 lines changed)
@@ -104,6 +104,16 @@ if 'clean' in sys.argv[1:]:
     subprocess.run(['rm -f fairseq/*.so fairseq/**/*.so'], shell=True)
 
 
+if 'test' in sys.argv[1:]:
+    try:
+        import fairseq.data.token_block_utils_fast
+    except (ImportError, ModuleNotFoundError):
+        raise Exception(
+            'Please install Cython components with `python setup.py build_ext --inplace` '
+            'before running unit tests.'
+        )
+
+
 setup(
     name='fairseq',
     version='0.8.0',
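With this guard, `python setup.py test` fails fast with an actionable message when the Cython extensions have not yet been built via `python setup.py build_ext --inplace`, instead of surfacing as obscure import errors partway through the test run.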
tests/test_bmuf.py

@@ -1,10 +1,16 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
 import argparse
-from multiprocessing import Manager
 import random
 import unittest
+from multiprocessing import Manager
+
+import torch
+import torch.nn as nn
+
+from fairseq import distributed_utils, optim
@@ -143,3 +149,7 @@ class TestBMUF(unittest.TestCase):
     def assertAlmostEqual(self, t1, t2):
         self.assertEqual(t1.size(), t2.size(), "size mismatch")
         self.assertLess((t1 - t2).abs().max(), 1e-4)
+
+
+if __name__ == '__main__':
+    unittest.main()
tests/test_memory_efficient_fp16.py

@@ -35,6 +35,7 @@ class TestMemoryEfficientFP16(unittest.TestCase):
                 fp16_scale_window=1,
                 fp16_scale_tolerance=1,
                 threshold_loss_scale=1,
+                min_loss_scale=1e-4,
             ),
             params,
             optimizer,