Miscellaneous fixes (#2193)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/2193

Reviewed By: ngoyal2707

Differential Revision: D21748548

Pulled By: myleott

fbshipit-source-id: d9f64540b55b4d427b3da6ad04a35f7b988b049a
Authored by Myle Ott on 2020-05-28 07:23:22 -07:00; committed by Facebook GitHub Bot
parent 2f7e3f3323
commit 8e48f45aa4
11 changed files with 60 additions and 38 deletions

View File

@@ -4,13 +4,13 @@
Description | Parameters | Dataset | Model and Test set(s)
---|---:|---|---
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2)
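For reference, a hedged sketch of fetching and unpacking the updated WikiText-103 checkpoint listed above (the destination directory name is an assumption, not part of the README):

```bash
# Download and unpack the updated adaptive-inputs checkpoint for WikiText-103.
# The output directory name here is illustrative only.
wget https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2
mkdir -p adaptive_lm_wiki103.v2
tar -xjf adaptive_lm_wiki103.v2.tar.bz2 -C adaptive_lm_wiki103.v2
```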
## Training an LM with adaptive inputs
First, see the general [language modeling README](../README.md) for instructions
on preprocessing the WikiText-103 data.
First, see the general [language modeling README](README.md) for instructions on
preprocessing the WikiText-103 data.
Then use the following training command to train a model with adaptive inputs
using the `transformer_lm_wiki103` model architecture:
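The README's exact command falls outside this hunk; purely as a hedged sketch, a run with this architecture might look like the following (the flags are real fairseq options, but the values are illustrative assumptions rather than the documented settings):

```bash
# Hedged sketch -- flag values are placeholders, not the README's exact command.
fairseq-train data-bin/wikitext-103 \
    --task language_modeling \
    --arch transformer_lm_wiki103 \
    --criterion adaptive_loss \
    --optimizer nag --lr 1.0 --clip-norm 0.1 \
    --max-tokens 3072 --tokens-per-sample 3072 \
    --max-update 286000 \
    --save-dir checkpoints/transformer_lm_wiki103
```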

View File

@@ -2,8 +2,7 @@
## Example usage
First download and preprocess the data following the main [language modeling
README](../README.md).
First download and preprocess the data following the main [language modeling README](README.md).
Then to train a convolutional LM using the `fconv_lm_dauphin_wikitext103`
architecture:
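The command itself continues past this hunk; as a hedged sketch only, such a run could look like this (flags exist in fairseq, values are placeholders):

```bash
# Hedged sketch -- the flag values below are illustrative assumptions.
fairseq-train data-bin/wikitext-103 \
    --task language_modeling \
    --arch fconv_lm_dauphin_wikitext103 \
    --criterion adaptive_loss \
    --adaptive-softmax-cutoff 10000,20000,200000 \
    --optimizer nag --lr 1.0 --clip-norm 0.1 \
    --max-tokens 1024 --tokens-per-sample 1024 \
    --save-dir checkpoints/fconv_wikitext-103
```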

View File

@@ -5,7 +5,7 @@
Model | Description | Dataset | Download
---|---|---|---
`transformer_lm.gbw.adaptive_huge` | Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 1026M params | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
`transformer_lm.wiki103.adaptive` | Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 247M params | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2)
`transformer_lm.wiki103.adaptive` | Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 247M params | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2)
`transformer_lm.wmt19.en` | English LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.gz)
`transformer_lm.wmt19.de` | German LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.gz)
`transformer_lm.wmt19.ru` | Russian LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.gz)
@@ -72,8 +72,7 @@ fairseq-preprocess \
### 2) Train a language model
Next we'll train a basic transformer language model on wikitext-103. For more
advanced examples (e.g., using [adaptive inputs](https://arxiv.org/abs/1809.10853)),
please see the [Transformer LM README](transformer_lm/README.md).
advanced usage, see the [adaptive inputs README](README.adaptive_inputs.md).
To train a basic LM (assumes 2 GPUs):
```
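# The remainder of this block is cut off by the hunk; a hedged sketch of a
# basic 2-GPU transformer LM run (illustrative flag values, not the README's
# exact command) could look like:
fairseq-train data-bin/wikitext-103 \
    --task language_modeling \
    --arch transformer_lm \
    --share-decoder-input-output-embed \
    --optimizer adam --lr 0.0005 --clip-norm 0.0 \
    --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 2048 --tokens-per-sample 512 \
    --max-update 50000 \
    --distributed-world-size 2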
@@ -120,5 +119,5 @@ dataset, but results in better (lower) perplexity.
## Convolutional language models
Please see the [convolutional LM README](conv_lm/README.md) for instructions to
train convolutional language models.
Please see the [convolutional LM README](README.conv.md) for instructions on
training convolutional language models.

View File

@@ -26,7 +26,11 @@ Model | Description | Download
Evaluate performance of these pre-trained models:
```bash
# Example for Machine Translation
python generate.py /path/to/bped/wmt/data --path nmt_checkpoint.pt --lenpen 0.4 --batch-size 64 --remove-bpe --beam 8 --gen-subset test > wmt16_gen.txt
fairseq-generate /path/to/bped/wmt/data --path nmt_checkpoint.pt \
--beam 8 --lenpen 0.4 \
--batch-size 64 \
--remove-bpe \
--gen-subset test > wmt16_gen.txt
bash scripts/compound_split_bleu.sh wmt16_gen.txt
# prints BLEU4 = 30.17
```
@@ -111,8 +115,10 @@ num. model params: 146163712
```
If you would like to pick up training with a model that has been pruned, simply adding these flags is sufficient. If you would like to use a script that only does evaluation (no training), you may need to pass an override command. A specific example would be for language modeling:
```
python eval_lm.py /path/to/wikitext-103 --path '/path/to/model/checkpoint' --model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}"
```bash
fairseq-eval-lm /path/to/wikitext-103 \
--path /path/to/model/checkpoint.pt \
--model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}"
```
This model override command overrides the training parameters and updates the model arguments so that the pruned model is run instead of the full model.
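For the training case mentioned above ("pick up training with a model that has been pruned"), a hedged sketch of passing the layer-selection flag directly (the paths, architecture, and remaining values are placeholders):

```bash
# Hedged sketch: resume training while keeping only a subset of decoder layers.
# --restore-file and --decoder-layers-to-keep are real fairseq flags; the rest
# of the values here are placeholders.
fairseq-train data-bin/wikitext-103 \
    --task language_modeling \
    --arch transformer_lm \
    --restore-file /path/to/model/checkpoint.pt \
    --decoder-layers-to-keep 0,2,4,6,8,10,12,14 \
    --max-update 50000
```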

View File

@@ -70,7 +70,7 @@ class ASGCriterion(FairseqCriterion):
        self.linseg_maximum = linseg_updates
        self.linseg_message_state = "none" if hide_linseg_messages else "start"

    @staticmethod
    @classmethod
    def build_criterion(cls, args, task):
        return cls(
            task,

View File

@@ -81,7 +81,7 @@ class CTCCriterion(FairseqCriterion):
        super().__init__(task)
        self.blank_idx = task.target_dictionary.index("<ctc_blank>")

    @staticmethod
    @classmethod
    def build_criterion(cls, args, task):
        return cls(task)

View File

@@ -454,7 +454,7 @@ class BufferedIterator(object):
                logger.info(
                    "Data loading buffer is empty or nearly empty. This may "
                    "indicate a data loading bottleneck, and increasing the "
                    "number of workers may help."
                    "number of workers (--num-workers) may help."
                )
                self.warning_time = time.time()
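As a small, hedged illustration of the remedy the updated message points to, the flag is passed on the training command line (dataset and architecture here are placeholders):

```bash
# Illustrative only: raise the number of data-loading workers if the
# buffer warning above shows up during training.
fairseq-train data-bin/wikitext-103 \
    --task language_modeling \
    --arch transformer_lm \
    --num-workers 8
```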

View File

@@ -14,8 +14,8 @@ import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
    FairseqDecoder,
    FairseqLanguageModel,
    FairseqEncoder,
    FairseqEncoderModel,
    register_model,
    register_model_architecture,
)
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
@register_model('roberta')
class RobertaModel(FairseqLanguageModel):
class RobertaModel(FairseqEncoderModel):

    @classmethod
    def hub_models(cls):
@@ -116,12 +116,20 @@ class RobertaModel(FairseqLanguageModel):
        if classification_head_name is not None:
            features_only = True

        x, extra = self.decoder(src_tokens, features_only, return_all_hiddens, **kwargs)
        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output[0].float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        else:
            return F.softmax(logits, dim=-1)

    def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
        """Register a classification head."""
        if name in self.classification_heads:
@@ -163,13 +171,23 @@ class RobertaModel(FairseqLanguageModel):
        return RobertaHubInterface(x['args'], x['task'], x['models'][0])

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + '.' if name != '' else ''

        # rename decoder -> encoder before upgrading children modules
        for k in list(state_dict.keys()):
            if k.startswith(prefix + 'decoder'):
                new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):]
                state_dict[new_k] = state_dict[k]
                del state_dict[k]

        # upgrade children modules
        super().upgrade_state_dict_named(state_dict, name)

        prefix = name + '.' if name != '' else ''
        current_head_names = [] if not hasattr(self, 'classification_heads') else \
            self.classification_heads.keys()

        # Handle new classification heads present in the state dict.
        current_head_names = (
            [] if not hasattr(self, 'classification_heads')
            else self.classification_heads.keys()
        )
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(prefix + 'classification_heads.'):
@@ -261,24 +279,15 @@ class RobertaClassificationHead(nn.Module):
        return x


class RobertaEncoder(FairseqDecoder):
    """RoBERTa encoder.

    Implements the :class:`~fairseq.models.FairseqDecoder` interface required
    by :class:`~fairseq.models.FairseqLanguageModel`.
    """
class RobertaEncoder(FairseqEncoder):
    """RoBERTa encoder."""

    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

        # RoBERTa is a sentence encoder model, so users will intuitively trim
        # encoder layers. However, the implementation uses the fairseq decoder,
        # so we fix here.
        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
            args.decoder_layers_to_keep = args.encoder_layers_to_keep
            args.encoder_layers_to_keep = None

        self.sentence_encoder = TransformerSentenceEncoder(
            padding_idx=dictionary.pad(),

View File

@@ -36,7 +36,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
        return {
            'transformer_lm.gbw.adaptive_huge': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2',
            'transformer_lm.wiki103.adaptive': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2',
            'transformer_lm.wiki103.adaptive': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2',
            'transformer_lm.wmt19.en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2'),
            'transformer_lm.wmt19.de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2'),
            'transformer_lm.wmt19.ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2'),

View File

@@ -870,6 +870,16 @@ class Trainer(object):
                self.task.reduce_metrics(logging_outputs, self.get_criterion())
                del logging_outputs

            # extra warning for criterions that don't properly log a loss value
            if "loss" not in agg:
                if "loss" not in self._warn_once:
                    self._warn_once.add("loss")
                    logger.warning(
                        "Criterion.reduce_metrics did not log a 'loss' value, "
                        "which may break some functionality"
                    )
                metrics.log_scalar("loss", -1)

            # support legacy interface
            if self.tpu:
                logging_output = {}

View File

@@ -356,7 +356,6 @@ def import_user_module(args):
        if module_name not in sys.modules:
            sys.path.insert(0, module_parent)
            importlib.import_module(module_name)
            sys.path.pop(0)


def softmax(x, dim: int, onnx_trace: bool = False):