Miscellaneous fixes (#2193)
Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/2193

Reviewed By: ngoyal2707

Differential Revision: D21748548

Pulled By: myleott

fbshipit-source-id: d9f64540b55b4d427b3da6ad04a35f7b988b049a
parent 2f7e3f3323
commit 8e48f45aa4
@@ -4,13 +4,13 @@

 Description | Parameters | Dataset | Model and Test set(s)
 ---|---:|---|---
-Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.bz2)
-Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.bz2)
+Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
+Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2)

 ## Training an LM with adaptive inputs

-First, see the general [language modeling README](../README.md) for instructions
-on preprocessing the WikiText-103 data.
+First, see the general [language modeling README](README.md) for instructions on
+preprocessing the WikiText-103 data.

 Then use the following training command to train a model with adaptive inputs
 using the `transformer_lm_wiki103` model architecture:
@@ -2,8 +2,7 @@

 ## Example usage

-First download and preprocess the data following the main [language modeling
-README](../README.md).
+First download and preprocess the data following the main [language modeling README](README.md).

 Then to train a convolutional LM using the `fconv_lm_dauphin_wikitext103`
 architecture:
@@ -5,7 +5,7 @@
 Model | Description | Dataset | Download
 ---|---|---|---
 `transformer_lm.gbw.adaptive_huge` | Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 1026M params | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
-`transformer_lm.wiki103.adaptive` | Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 247M params | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2)
+`transformer_lm.wiki103.adaptive` | Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 247M params | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2)
 `transformer_lm.wmt19.en` | English LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.gz)
 `transformer_lm.wmt19.de` | German LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.gz)
 `transformer_lm.wmt19.ru` | Russian LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.gz)
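The hub names in this table can be loaded directly through torch.hub. Below is a minimal sketch of that workflow following fairseq's published hub interface; the example sentence, and the choice of `transformer_lm.wmt19.en` with its Moses/fastBPE preprocessing, are illustrative only.

```python
import torch

# Download and load a pre-trained LM by its hub name (cached after the first call).
lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en',
                    tokenizer='moses', bpe='fastbpe')
lm.eval()  # disable dropout

# Score a sentence; 'positional_scores' holds per-token log-probabilities,
# so the expression below is the sentence-level perplexity.
out = lm.score('Machine translation is a sequence to sequence task.')
print(out['positional_scores'].mean().neg().exp().item())

# Sample a short continuation from the language model.
print(lm.sample('Language modeling is', beam=1, sampling=True, sampling_topk=10))
```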
@@ -72,8 +72,7 @@ fairseq-preprocess \
 ### 2) Train a language model

 Next we'll train a basic transformer language model on wikitext-103. For more
-advanced examples (e.g., using [adaptive inputs](https://arxiv.org/abs/1809.10853)),
-please see the [Transformer LM README](transformer_lm/README.md).
+advanced usage, see the [adaptive inputs README](README.adaptive_inputs.md).

 To train a basic LM (assumes 2 GPUs):
 ```
@@ -120,5 +119,5 @@ dataset, but results in better (lower) perplexity.

 ## Convolutional language models

-Please see the [convolutional LM README](conv_lm/README.md) for instructions to
-train convolutional language models.
+Please see the [convolutional LM README](README.conv.md) for instructions on
+training convolutional language models.
@@ -26,7 +26,11 @@ Model | Description | Download
 Evaluate performance of these pre-trained models:
 ```bash
 # Example for Machine Translation
-python generate.py /path/to/bped/wmt/data --path nmt_checkpoint.pt --lenpen 0.4 --batch-size 64 --remove-bpe --beam 8 --gen-subset test > wmt16_gen.txt
+fairseq-generate /path/to/bped/wmt/data --path nmt_checkpoint.pt \
+    --beam 8 --lenpen 0.4 \
+    --batch-size 64 \
+    --remove-bpe \
+    --gen-subset test > wmt16_gen.txt
 bash scripts/compound_split_bleu.sh wmt16_gen.txt
 # prints BLEU4 = 30.17
 ```
@@ -111,8 +115,10 @@ num. model params: 146163712
 ```

 If you would like to pick up training with a model that has been pruned, simply adding these flags is sufficient. If you would like to use a script that only does evaluation (no training), you may need to pass an override command. A specific example would be for language modeling:
-```
-python eval_lm.py /path/to/wikitext-103 --path '/path/to/model/checkpoint' --model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}"
+```bash
+fairseq-eval-lm /path/to/wikitext-103 \
+    --path /path/to/model/checkpoint.pt \
+    --model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}"
 ```
 This model override command overrides the training parameters and updates the model arguments so that the pruned model is run instead of the full model.

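For context, the `--model-overrides` string is just a Python dict literal: it is parsed and merged into the arguments saved in the checkpoint before the model is rebuilt, which is why only the kept decoder layers end up being instantiated. A rough, self-contained sketch of that mechanism (the `apply_model_overrides` helper below is illustrative, not fairseq's actual function):

```python
import ast
from argparse import Namespace


def apply_model_overrides(saved_args: Namespace, overrides: str) -> Namespace:
    """Safely evaluate an override dict literal and patch the saved args."""
    for key, value in ast.literal_eval(overrides).items():
        setattr(saved_args, key, value)  # e.g. decoder_layers_to_keep = '0,2,...'
    return saved_args


# Illustrative checkpoint args: a 16-layer model pruned down to 8 kept layers.
args = Namespace(decoder_layers=16, decoder_layers_to_keep=None)
args = apply_model_overrides(args, "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}")
print(args.decoder_layers_to_keep)  # -> 0,2,4,6,8,10,12,14
```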
@@ -70,7 +70,7 @@ class ASGCriterion(FairseqCriterion):
         self.linseg_maximum = linseg_updates
         self.linseg_message_state = "none" if hide_linseg_messages else "start"

-    @staticmethod
+    @classmethod
     def build_criterion(cls, args, task):
         return cls(
             task,
@@ -81,7 +81,7 @@ class CTCCriterion(FairseqCriterion):
         super().__init__(task)
         self.blank_idx = task.target_dictionary.index("<ctc_blank>")

-    @staticmethod
+    @classmethod
     def build_criterion(cls, args, task):
         return cls(task)

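A note on the `@staticmethod` to `@classmethod` change in the two criteria above: `build_criterion` is written as a factory that receives `cls` and instantiates it, so it only works if Python binds `cls` automatically; as a staticmethod, a call like `SomeCriterion.build_criterion(args, task)` would mis-assign its arguments. A toy sketch of the pattern in plain Python (the class names here are made up):

```python
class BaseCriterion:
    def __init__(self, task):
        self.task = task

    @classmethod
    def build_criterion(cls, args, task):
        # cls is the concrete subclass, so every subclass inherits a working factory.
        return cls(task)


class CTCLikeCriterion(BaseCriterion):
    pass


crit = CTCLikeCriterion.build_criterion(args=None, task='dummy-task')
print(type(crit).__name__)  # -> CTCLikeCriterion, not BaseCriterion
```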
@@ -454,7 +454,7 @@ class BufferedIterator(object):
                     logger.info(
                         "Data loading buffer is empty or nearly empty. This may "
                         "indicate a data loading bottleneck, and increasing the "
-                        "number of workers may help."
+                        "number of workers (--num-workers) may help."
                     )
                     self.warning_time = time.time()

@@ -14,8 +14,8 @@ import torch.nn.functional as F

 from fairseq import utils
 from fairseq.models import (
-    FairseqDecoder,
-    FairseqLanguageModel,
+    FairseqEncoder,
+    FairseqEncoderModel,
     register_model,
     register_model_architecture,
 )
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)


 @register_model('roberta')
-class RobertaModel(FairseqLanguageModel):
+class RobertaModel(FairseqEncoderModel):

     @classmethod
     def hub_models(cls):
@@ -116,12 +116,20 @@ class RobertaModel(FairseqLanguageModel):
         if classification_head_name is not None:
             features_only = True

-        x, extra = self.decoder(src_tokens, features_only, return_all_hiddens, **kwargs)
+        x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)

         if classification_head_name is not None:
             x = self.classification_heads[classification_head_name](x)
         return x, extra

+    def get_normalized_probs(self, net_output, log_probs, sample=None):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        logits = net_output[0].float()
+        if log_probs:
+            return F.log_softmax(logits, dim=-1)
+        else:
+            return F.softmax(logits, dim=-1)
+
     def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
         """Register a classification head."""
         if name in self.classification_heads:
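As a usage note for the head-selection logic above: through the hub interface, a classification head registered by name is what `classification_head_name` selects at forward time. A sketch assuming the published `roberta.base` hub entry; the `sentiment` head is made up and freshly initialized, so its outputs are meaningless until fine-tuned:

```python
import torch

roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
roberta.eval()

# Register a head by name; forward(..., classification_head_name='sentiment')
# routes the pooled encoder features through it, as in the code above.
roberta.register_classification_head('sentiment', num_classes=2)

tokens = roberta.encode('fairseq is a sequence modeling toolkit.')
logprobs = roberta.predict('sentiment', tokens)  # untrained head: random outputs
print(logprobs)
```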
@@ -163,13 +171,23 @@ class RobertaModel(FairseqLanguageModel):
         return RobertaHubInterface(x['args'], x['task'], x['models'][0])

     def upgrade_state_dict_named(self, state_dict, name):
+        prefix = name + '.' if name != '' else ''
+
+        # rename decoder -> encoder before upgrading children modules
+        for k in list(state_dict.keys()):
+            if k.startswith(prefix + 'decoder'):
+                new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):]
+                state_dict[new_k] = state_dict[k]
+                del state_dict[k]
+
+        # upgrade children modules
         super().upgrade_state_dict_named(state_dict, name)

-        prefix = name + '.' if name != '' else ''
-        current_head_names = [] if not hasattr(self, 'classification_heads') else \
-            self.classification_heads.keys()
-
         # Handle new classification heads present in the state dict.
+        current_head_names = (
+            [] if not hasattr(self, 'classification_heads')
+            else self.classification_heads.keys()
+        )
         keys_to_delete = []
         for k in state_dict.keys():
             if not k.startswith(prefix + 'classification_heads.'):
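The rename block above exists for backward compatibility: older RoBERTa checkpoints store their weights under `decoder.*`, while the model now registers everything under `encoder.*`. The same key rewrite on a plain state dict, as a self-contained sketch (the checkpoint fragment below is fabricated for illustration):

```python
import torch
from collections import OrderedDict


def rename_decoder_to_encoder(state_dict, prefix=''):
    """Rewrite 'decoder.*' keys to 'encoder.*', mirroring upgrade_state_dict_named."""
    for k in list(state_dict.keys()):
        if k.startswith(prefix + 'decoder'):
            new_k = prefix + 'encoder' + k[len(prefix + 'decoder'):]
            state_dict[new_k] = state_dict.pop(k)
    return state_dict


old_checkpoint = OrderedDict([
    ('decoder.sentence_encoder.embed_tokens.weight', torch.zeros(4, 8)),
    ('decoder.lm_head.bias', torch.zeros(4)),
])
print(list(rename_decoder_to_encoder(old_checkpoint).keys()))
# -> ['encoder.sentence_encoder.embed_tokens.weight', 'encoder.lm_head.bias']
```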
@@ -261,24 +279,15 @@ class RobertaClassificationHead(nn.Module):
         return x


-class RobertaEncoder(FairseqDecoder):
-    """RoBERTa encoder.
-
-    Implements the :class:`~fairseq.models.FairseqDecoder` interface required
-    by :class:`~fairseq.models.FairseqLanguageModel`.
-    """
+class RobertaEncoder(FairseqEncoder):
+    """RoBERTa encoder."""

     def __init__(self, args, dictionary):
         super().__init__(dictionary)
         self.args = args

-        # RoBERTa is a sentence encoder model, so users will intuitively trim
-        # encoder layers. However, the implementation uses the fairseq decoder,
-        # so we fix here.
         if args.encoder_layers_to_keep:
             args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
-            args.decoder_layers_to_keep = args.encoder_layers_to_keep
-            args.encoder_layers_to_keep = None

         self.sentence_encoder = TransformerSentenceEncoder(
             padding_idx=dictionary.pad(),
@@ -36,7 +36,7 @@ class TransformerLanguageModel(FairseqLanguageModel):

         return {
             'transformer_lm.gbw.adaptive_huge': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2',
-            'transformer_lm.wiki103.adaptive': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2',
+            'transformer_lm.wiki103.adaptive': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2',
             'transformer_lm.wmt19.en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2'),
             'transformer_lm.wmt19.de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2'),
             'transformer_lm.wmt19.ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2'),
@@ -870,6 +870,16 @@ class Trainer(object):
                 self.task.reduce_metrics(logging_outputs, self.get_criterion())
                 del logging_outputs

+            # extra warning for criterions that don't properly log a loss value
+            if "loss" not in agg:
+                if "loss" not in self._warn_once:
+                    self._warn_once.add("loss")
+                    logger.warning(
+                        "Criterion.reduce_metrics did not log a 'loss' value, "
+                        "which may break some functionality"
+                    )
+                metrics.log_scalar("loss", -1)
+
         # support legacy interface
         if self.tpu:
             logging_output = {}
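The warning added above fires when a criterion's `reduce_metrics` never logs a `loss` scalar. For reference, a minimal sketch of a `reduce_metrics` that does log one, assuming the usual fairseq convention that each logging output carries a summed `loss` and a `sample_size`; the criterion name and fields here are illustrative:

```python
import math

from fairseq import metrics
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('my_toy_criterion')  # illustrative name
class MyToyCriterion(FairseqCriterion):
    # forward(...) omitted: assume it returns logging outputs that contain
    # 'loss' (summed over the batch, in nats) and 'sample_size'.

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        # Logging a 'loss' scalar (here in bits) is what keeps the new warning quiet.
        metrics.log_scalar('loss', loss_sum / max(sample_size, 1) / math.log(2),
                           sample_size, round=3)
```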
@@ -356,7 +356,6 @@ def import_user_module(args):
         if module_name not in sys.modules:
             sys.path.insert(0, module_parent)
             importlib.import_module(module_name)
-            sys.path.pop(0)


 def softmax(x, dim: int, onnx_trace: bool = False):