Add LayerDrop code for training, pruning, and README (#890)

Summary:
TEST 1: EVALUATION TIME WORKS
checked
achieves the correct model perplexity: 18.68

TEST 2: TRAINING NEW MODEL WORKS
checked

without layerdrop:
--decoder-layerdrop 0 OR no flag at all
| epoch 001:     10 / 11201 loss=27.469, nll_loss=27.469, ppl=185799477.36, wps=1764, ups=0, wpb=9216.000, bsz=3.000, num_updates=7, lr=0.0004376, gnorm=25.471, clip=1.000, oom=0.000, loss_scale=8.000, wall=37, train_wall=30
| epoch 001:     20 / 11201 loss=27.443, nll_loss=27.443, ppl=182500427.22, wps=2449, ups=0, wpb=9216.000, bsz=3.000, num_updates=17, lr=0.0010626, gnorm=25.273, clip=1.000, oom=0.000, loss_scale=8.000, wall=64, train_wall=57
| epoch 001:     30 / 11201 loss=27.404, nll_loss=27.404, ppl=177612215.78, wps=2720, ups=0, wpb=9216.000, bsz=3.000, num_updates=27, lr=0.0016876, gnorm=25.136, clip=1.000, oom=0.000, loss_scale=8.000, wall=91, train_wall=84
| epoch 001:     40 / 11201 loss=27.009, nll_loss=27.009, ppl=135079983.00, wps=2865, ups=0, wpb=9216.000, bsz=3.000, num_updates=37, lr=0.0023126, gnorm=24.311, clip=1.000, oom=0.000, loss_scale=8.000, wall=119, train_wall=112
| epoch 001:     50 / 11201 loss=26.418, nll_loss=26.418, ppl=89680259.41, wps=2952, ups=0, wpb=9216.000, bsz=3.000, num_updates=47, lr=0.0029376, gnorm=22.775, clip=1.000, oom=0.000, loss_scale=8.000, wall=147, train_wall=140

with layerdrop (regularization effect should be seen in PPL):
--decoder-layerdrop 0.2

| epoch 001:     10 / 11201 loss=25.186, nll_loss=25.186, ppl=38182937.27, wps=2428, ups=0, wpb=9216.000, bsz=3.000, num_updates=8, lr=0.0005001, gnorm=17.082, clip=1.000, oom=0.000, loss_scale=16.000, wall=30, train_wall=24
| epoch 001:     20 / 11201 loss=25.270, nll_loss=25.270, ppl=40451933.50, wps=3173, ups=0, wpb=9216.000, bsz=3.000, num_updates=18, lr=0.0011251, gnorm=17.162, clip=1.000, oom=0.000, loss_scale=16.000, wall=52, train_wall=45
| epoch 001:     30 / 11201 loss=25.349, nll_loss=25.349, ppl=42752256.68, wps=3454, ups=0, wpb=9216.000, bsz=3.000, num_updates=28, lr=0.0017501, gnorm=17.370, clip=1.000, oom=0.000, loss_scale=16.000, wall=75, train_wall=68
| epoch 001:     40 / 11201 loss=25.115, nll_loss=25.115, ppl=36343806.30, wps=3619, ups=0, wpb=9216.000, bsz=3.000, num_updates=38, lr=0.0023751, gnorm=16.945, clip=1.000, oom=0.000, loss_scale=16.000, wall=97, train_wall=90
| epoch 001:     50 / 11201 loss=24.804, nll_loss=24.804, ppl=29284345.78, wps=3716, ups=0, wpb=9216.000, bsz=3.000, num_updates=48, lr=0.0030001, gnorm=16.406, clip=1.000, oom=0.000, loss_scale=16.000, wall=119, train_wall=112

TEST 3: PICKING UP TRAINING FROM EXISTING MODEL
checked

| loaded checkpoint /checkpoint/angelafan/structured_0.1_block_8_sd02/checkpoint_last.pt (epoch 272 @ 381066 updates)
| loading train data for epoch 272
| loaded 1801350 examples from: /private/home/angelafan/lm_work/fairseq-py/data-bin/wikitext-103/train

TEST 4: EVALUATING AN EXISTING BERT MODEL REPRODUCES RESULTS
| [input] dictionary: 50265 types
| [label] dictionary: 9 types
| Accuracy:  0.9231651376146789
achieves the correct accuracy on SST-2 for this model

TEST 5: TRAINING NEW BERT MODEL WORKS
checked and works

TEST 6: NMT

without layerdrop
--encoder-layerdrop 0 --decoder-layerdrop 0, OR combinations of the flags specified and not specified

| epoch 001:     10 / 92203 loss=15.820, nll_loss=15.830, ppl=58267.93, wps=4902, ups=0, wpb=1477.818, bsz=51.636, num_updates=11, lr=1.47473e-06, gnorm=7.207, clip=0.000, oom=0.000, loss_scale=128.000, wall=60, train_wall=3
| epoch 001:     20 / 92203 loss=15.523, nll_loss=15.501, ppl=46359.29, wps=5037, ups=0, wpb=1496.476, bsz=45.333, num_updates=21, lr=2.72448e-06, gnorm=6.869, clip=0.000, oom=0.000, loss_scale=128.000, wall=63, train_wall=6
| epoch 001:     30 / 92203 loss=15.185, nll_loss=15.123, ppl=35695.79, wps=5085, ups=0, wpb=1519.355, bsz=44.645, num_updates=31, lr=3.97423e-06, gnorm=6.186, clip=0.000, oom=0.000, loss_scale=128.000, wall=66, train_wall=9
| epoch 001:     40 / 92203 loss=14.940, nll_loss=14.849, ppl=29505.60, wps=5116, ups=1, wpb=1521.244, bsz=42.927, num_updates=41, lr=5.22398e-06, gnorm=5.610, clip=0.000, oom=0.000, loss_scale=128.000, wall=69, train_wall=12
| epoch 001:     50 / 92203 loss=14.745, nll_loss=14.630, ppl=25346.87, wps=5070, ups=1, wpb=1507.961, bsz=41.725, num_updates=51, lr=6.47373e-06, gnorm=5.104, clip=0.000, oom=0.000, loss_scale=128.000, wall=71, train_wall=15

with layerdrop (regularization effect should be seen in PPL)

A) works with --encoder-layerdrop 0.2 --decoder-layerdrop 0.2
B) works with different settings --encoder-layerdrop 0.3 --decoder-layerdrop 0.5
C) works with one on and one off --encoder-layerdrop 0.2 --decoder-layerdrop 0

| epoch 001:     10 / 92203 loss=15.817, nll_loss=15.828, ppl=58158.54, wps=5355, ups=0, wpb=1477.818, bsz=51.636, num_updates=11, lr=1.47473e-06, gnorm=6.959, clip=0.000, oom=0.000, loss_scale=128.000, wall=59, train_wall=3
| epoch 001:     20 / 92203 loss=15.650, nll_loss=15.641, ppl=51111.63, wps=5515, ups=0, wpb=1496.476, bsz=45.333, num_updates=21, lr=2.72448e-06, gnorm=6.825, clip=0.000, oom=0.000, loss_scale=128.000, wall=61, train_wall=6
| epoch 001:     30 / 92203 loss=15.440, nll_loss=15.408, ppl=43491.58, wps=5602, ups=0, wpb=1519.355, bsz=44.645, num_updates=31, lr=3.97423e-06, gnorm=6.576, clip=0.000, oom=0.000, loss_scale=128.000, wall=64, train_wall=8
| epoch 001:     40 / 92203 loss=15.247, nll_loss=15.193, ppl=37457.14, wps=5676, ups=1, wpb=1521.244, bsz=42.927, num_updates=41, lr=5.22398e-06, gnorm=6.124, clip=0.000, oom=0.000, loss_scale=128.000, wall=67, train_wall=11
| epoch 001:     50 / 92203 loss=15.055, nll_loss=14.977, ppl=32259.92, wps=5598, ups=1, wpb=1507.961, bsz=41.725, num_updates=51, lr=6.47373e-06, gnorm=5.661, clip=0.000, oom=0.000, loss_scale=128.000, wall=69, train_wall=14

TEST 7: PRUNING TESTCASES

A) after adding the pruning flags, the model can still be evaluated as a full model
checked, reaches the correct PPL
num. model params: 246933504
| Evaluated 217646 tokens in 196.3s (1108.99 tokens/s)
| Loss: 2.9275, Perplexity: 18.68

B) after adding the pruning flags, the model can be pruned; this works with multiple flag settings
checked three cases:
num. model params: 146163712
| Evaluated 217646 tokens in 106.0s (2054.07 tokens/s)
| Loss: 3.0932, Perplexity: 22.05

num. model params: 209144832
| Evaluated 217646 tokens in 162.8s (1336.99 tokens/s)
| Loss: 2.9526, Perplexity: 19.16

C) the model can pick up training if you want to fine-tune the pruned model
checked:
| loading train data for epoch 272
| loaded 1801350 examples from: /private/home/angelafan/lm_work/fairseq-py/data-bin/wikitext-103/train
| WARNING: overflow detected, setting loss scale to: 64.0
| WARNING: overflow detected, setting loss scale to: 32.0
| epoch 272:   1500 / 5601 loss=5.015, nll_loss=5.015, ppl=32.33, wps=11598, ups=1, wpb=18432.000, bsz=6.000, num_updates=98, lr=0.0061251, gnorm=0.613, clip=1.000, oom=0.000, loss_scale=32.000, wall=156, train_wall=252396

D) works with BERT
checked:
without specifying any flags, reproduces the correct standard accuracy
with flags, produces the correct pruned accuracy

| [input] dictionary: 50265 types
| [label] dictionary: 9 types
| Accuracy:  0.9231651376146789

| [input] dictionary: 50265 types
| [label] dictionary: 9 types
| Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop
| Accuracy:  0.9220183486238532
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/890

Reviewed By: edunov

Differential Revision: D18094657

Pulled By: huihuifan

fbshipit-source-id: 2bbaa2ff0039e906782694fc2038b8c17a8693e7
Angela Fan 2019-10-27 12:09:29 -07:00 committed by Facebook Github Bot
parent eb68afca02
commit dabbef4676
8 changed files with 209 additions and 24 deletions

View File

@ -0,0 +1,66 @@
# Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)
This page contains information on how to train models with LayerDrop.
Looking for pretrained models? They will be added shortly.
Looking for code for other forms of Structured Dropout? It will be added shortly.
## Citation:
```bibtex
@article{fan2019reducing,
title={Reducing Transformer Depth on Demand with Structured Dropout},
author={Fan, Angela and Grave, Edouard and Joulin, Armand},
journal={arXiv preprint arXiv:1909.11556},
year={2019}
}
```
## Example usage
To train a model with LayerDrop, add the following flags. We recommend 0.2, a value that worked well in our experiments. For Language Models, which are decoder-only, you need only the decoder flag. For RoBERTa, which is encoder-only, you need only the encoder flag. The encoder and decoder LayerDrop values can be set independently.
```
--encoder-layerdrop 0.2 --decoder-layerdrop 0.2
```
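The flag controls the probability of skipping each layer during training. As a minimal standalone sketch of the mechanism (the class and variable names below are illustrative, not part of fairseq):
```python
import random
import torch.nn as nn

class LayerDropStack(nn.Module):
    """Illustrative sketch: during training each layer is skipped with
    probability `layerdrop`; at inference time every layer runs."""
    def __init__(self, layers, layerdrop=0.2):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.layerdrop = layerdrop

    def forward(self, x):
        for layer in self.layers:
            # same check used in the fairseq encoder/decoder forward passes
            if not self.training or random.uniform(0, 1) > self.layerdrop:
                x = layer(x)
        return x
```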
To prune a model that has been trained with LayerDrop, add the following flags, each followed by a comma-separated list of the layers you would like to keep.
```
--encoder-layers-to-keep 0,2,4,6,8,10,12,14 --decoder-layers-to-keep 0,2,4,6,8,10,12,14
```
Setting these flags should print a message such as:
```
| Pruning model to specified layer configuration
```
You should also see a smaller number of parameters in the model; for example, the 16-layer Transformer Language Model prints:
```
num. model params: 246933504
```
while the same model pruned to 8 layers prints:
```
num. model params: 146163712
```
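The ~100M-parameter gap between these two counts is consistent with removing 8 decoder layers. As a rough sanity check, assuming the standard Wikitext-103 Transformer LM dimensions (1024-d model, 4096-d FFN, two LayerNorms per layer), which are not stated on this page:
```python
# Assumed dimensions for the 16-layer Wikitext-103 Transformer LM;
# treat these as an assumption used only for this sanity check.
d, ffn = 1024, 4096
attn = 4 * (d * d + d)                   # q, k, v, out projections (with biases)
ffwd = (d * ffn + ffn) + (ffn * d + d)   # two feed-forward linear layers
norms = 2 * (2 * d)                      # two LayerNorms (weight + bias each)
per_layer = attn + ffwd + norms          # 12,596,224 parameters per layer

full, pruned = 246933504, 146163712
assert full - pruned == 8 * per_layer    # the 8 removed layers account for the gap
```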
If you would like to pick up training with a model that has been pruned, simply adding these flags is sufficient. If you would like to use a script that only does evaluation (no training), you may need to pass a model override so the pruning is applied when the checkpoint is loaded. A specific example for language modeling:
```
python eval_lm.py /path/to/wikitext-103 --path '/path/to/model/checkpoint' --model-overrides "{'decoder_layers_to_keep':'0,2,4,6,8,10,12,14'}"
```
This override updates the saved model arguments so that the pruned model is built and evaluated instead of the full model.
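If you load checkpoints programmatically, the same override can be passed through `arg_overrides` to the checkpoint loader touched in this PR (a sketch; the checkpoint path is a placeholder):
```python
from fairseq import checkpoint_utils

# Placeholder path; the override mirrors the --model-overrides example above.
models, args, task = checkpoint_utils.load_model_ensemble_and_task(
    ['/path/to/model/checkpoint'],
    arg_overrides={'decoder_layers_to_keep': '0,2,4,6,8,10,12,14'},
)
model = models[0]  # built and loaded with only the 8 kept decoder layers
```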
Looking to reproduce the results in the paper?
1. For Translation on WMT en-de, we followed this setting [here](https://github.com/pytorch/fairseq/blob/master/examples/scaling_nmt/README.md)
2. To train RoBERTa, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/roberta)
3. To train Language Models on Wikitext-103, we followed this setting [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model)
## Tips
1. If you would like to train large models with better performance, LayerDrop should be set to a smaller value such as 0.1 or 0.2. Too much LayerDrop means too much regularization, so the model may not reach its best performance. Since LayerDrop adds regularization, you may achieve the best performance by slightly reducing the amount of standard dropout (for example, reduce it by 0.1).
2. If you would like to train large models to be pruned and made smaller, LayerDrop should be set to a larger value such as 0.5 if you want to prune very aggressively (such as removing half the network or more). If you would like to prune fewer layers away, LayerDrop can be set to a smaller value such as 0.2.
3. When pruning layers at inference time, it is best to spread out the remaining layers so they are evenly spaced throughout the network. For example, if you want to remove 50% of the network, keeping every other layer is good (see the sketch below).
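A small helper along these lines (illustrative, not part of fairseq) produces an evenly spaced keep-list:
```python
def evenly_spaced_layers(num_layers, keep_every):
    """Illustrative helper: keep every `keep_every`-th layer of `num_layers`."""
    return ",".join(str(i) for i in range(0, num_layers, keep_every))

# For a 16-layer model, keeping every other layer:
print(evenly_spaced_layers(16, 2))  # "0,2,4,6,8,10,12,14"
```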
## Having an issue or have a question?
Please open an issue in this repository with the details of your question. Thanks!

View File

@ -183,7 +183,7 @@ def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None):
# build model for ensemble
model = task.build_model(args)
model.load_state_dict(state['model'], strict=True)
model.load_state_dict(state['model'], strict=True, args=args)
ensemble.append(model)
return ensemble, args, task
@ -334,6 +334,70 @@ def _upgrade_state_dict(state):
return state
def prune_state_dict(state_dict, args):
"""Prune the given state_dict if desired for LayerDrop
(https://arxiv.org/abs/1909.11556).
Training with LayerDrop allows models to be robust to pruning at inference
time. This function prunes state_dict to allow smaller models to be loaded
from a larger model and re-maps the existing state_dict for this to occur.
It's called by functions that load models from checkpoints and does not
need to be called directly.
"""
if not args:
# args should not be none, but don't crash if it is.
return state_dict
encoder_layers_to_keep = args.encoder_layers_to_keep if "encoder_layers_to_keep" in vars(args) else None
decoder_layers_to_keep = args.decoder_layers_to_keep if "decoder_layers_to_keep" in vars(args) else None
if not encoder_layers_to_keep and not decoder_layers_to_keep:
return state_dict
# apply pruning
print("| Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop")
def create_pruning_pass(layers_to_keep, layer_name):
keep_layers = sorted([int(layer_string) for layer_string in layers_to_keep.split(",")])
mapping_dict = {}
for i in range(len(keep_layers)):
mapping_dict[str(keep_layers[i])] = str(i)
regex = re.compile("^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
return {
"substitution_regex": regex,
"mapping_dict": mapping_dict
}
pruning_passes = []
if encoder_layers_to_keep:
pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
if decoder_layers_to_keep:
pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
new_state_dict = {}
for layer_name in state_dict.keys():
match = re.search("\.layers\.(\d+)\.", layer_name)
# if layer has no number in it, it is a supporting layer, such as an
# embedding
if not match:
new_state_dict[layer_name] = state_dict[layer_name]
continue
# otherwise, layer should be pruned.
original_layer_number = match.group(1)
# figure out which mapping dict to replace from
for pruning_pass in pruning_passes:
if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name):
new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
substitution_match = pruning_pass["substitution_regex"].search(layer_name)
new_state_key = layer_name[:substitution_match.start(1)] + new_layer_number + layer_name[substitution_match.end(1):]
new_state_dict[new_state_key] = state_dict[layer_name]
return new_state_dict
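For concreteness, a toy illustration (not part of this diff) of the key remapping the function above performs, assuming `decoder_layers_to_keep='0,2'` and made-up parameter names:
```python
import re

# Toy state dict with two kept layers (0 and 2) and one dropped layer (1).
state_dict = {
    'decoder.embed_tokens.weight': '...',
    'decoder.layers.0.fc1.weight': '...',
    'decoder.layers.1.fc1.weight': '...',
    'decoder.layers.2.fc1.weight': '...',
}
keep = ['0', '2']
mapping = {old: str(new) for new, old in enumerate(keep)}  # {'0': '0', '2': '1'}

pruned = {}
for key, value in state_dict.items():
    m = re.search(r'\.layers\.(\d+)\.', key)
    if not m:                      # supporting layers (embeddings, etc.) are kept as-is
        pruned[key] = value
    elif m.group(1) in mapping:    # kept layers are renumbered to be contiguous
        pruned[key[:m.start(1)] + mapping[m.group(1)] + key[m.end(1):]] = value
# pruned == {'decoder.embed_tokens.weight': '...',
#            'decoder.layers.0.fc1.weight': '...',   # was layer 0
#            'decoder.layers.1.fc1.weight': '...'}   # was layer 2
```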
def load_pretrained_component_from_model(
component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
):

View File

@ -13,6 +13,7 @@ import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.checkpoint_utils import prune_state_dict
from fairseq.data import Dictionary
from fairseq.models import FairseqDecoder, FairseqEncoder
@ -58,7 +59,7 @@ class BaseFairseqModel(nn.Module):
"""Maximum length supported by the model."""
return None
def load_state_dict(self, state_dict, strict=True):
def load_state_dict(self, state_dict, strict=True, args=None):
"""Copies parameters and buffers from *state_dict* into this module and
its descendants.
@ -66,7 +67,8 @@ class BaseFairseqModel(nn.Module):
this additionally "upgrades" *state_dicts* from old checkpoints.
"""
self.upgrade_state_dict(state_dict)
return super().load_state_dict(state_dict, strict)
new_state_dict = prune_state_dict(state_dict, args)
return super().load_state_dict(new_state_dict, strict)
def upgrade_state_dict(self, state_dict):
"""Upgrade old state dicts to work with newer code."""

View File

@ -78,6 +78,11 @@ class RobertaModel(FairseqLanguageModel):
help='number of positional embeddings to learn')
parser.add_argument('--load-checkpoint-heads', action='store_true',
help='(re-)register and load heads when loading checkpoints')
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for encoder')
parser.add_argument('--encoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
@classmethod
def build_model(cls, args, task):
@ -245,6 +250,15 @@ class RobertaEncoder(FairseqDecoder):
def __init__(self, args, dictionary):
super().__init__(dictionary)
self.args = args
# RoBERTa is a sentence encoder model, so users will intuitively trim
# encoder layers. However, the implementation uses the fairseq decoder,
# so we fix here.
if args.encoder_layers_to_keep:
args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
args.decoder_layers_to_keep = args.encoder_layers_to_keep
args.encoder_layers_to_keep = None
self.sentence_encoder = TransformerSentenceEncoder(
padding_idx=dictionary.pad(),
vocab_size=len(dictionary),
@ -255,6 +269,7 @@ class RobertaEncoder(FairseqDecoder):
dropout=args.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
layerdrop=args.encoder_layerdrop,
max_seq_len=args.max_positions,
num_segments=0,
encoder_normalize_before=True,

View File

@ -25,6 +25,7 @@ from fairseq.modules import (
TransformerDecoderLayer,
TransformerEncoderLayer,
)
import random
DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024
@ -130,6 +131,15 @@ class TransformerModel(FairseqEncoderDecoderModel):
help='perform cross+self-attention')
parser.add_argument('--layer-wise-attention', default=False, action='store_true',
help='perform layer-wise attention (cross-attention or cross+self-attention)')
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
parser.add_argument('--encoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for encoder')
parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for decoder')
parser.add_argument('--encoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
parser.add_argument('--decoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
# fmt: on
@classmethod
@ -139,6 +149,11 @@ class TransformerModel(FairseqEncoderDecoderModel):
# make sure all arguments are present in older models
base_architecture(args)
if args.encoder_layers_to_keep:
args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
if args.decoder_layers_to_keep:
args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if not hasattr(args, 'max_target_positions'):
@ -275,6 +290,7 @@ class TransformerEncoder(FairseqEncoder):
self.register_buffer('version', torch.Tensor([3]))
self.dropout = args.dropout
self.encoder_layerdrop = args.encoder_layerdrop
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
@ -300,6 +316,7 @@ class TransformerEncoder(FairseqEncoder):
else:
self.layer_norm = None
def forward_embedding(self, src_tokens):
# embed tokens and positions
embed = self.embed_scale * self.embed_tokens(src_tokens)
@ -345,9 +362,12 @@ class TransformerEncoder(FairseqEncoder):
# encoder layers
for layer in self.layers:
x = layer(x, encoder_padding_mask)
if return_all_hiddens:
encoder_states.append(x)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not self.training or (dropout_probability > self.encoder_layerdrop):
x = layer(x, encoder_padding_mask)
if return_all_hiddens:
encoder_states.append(x)
if self.layer_norm:
x = self.layer_norm(x)
@ -435,6 +455,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.register_buffer('version', torch.Tensor([3]))
self.dropout = args.dropout
self.decoder_layerdrop = args.decoder_layerdrop
self.share_input_output_embed = args.share_decoder_input_output_embed
input_embed_dim = embed_tokens.embedding_dim
@ -594,20 +615,22 @@ class TransformerDecoder(FairseqIncrementalDecoder):
else:
self_attn_mask = None
x, layer_attn = layer(
x,
encoder_state,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=(idx == alignment_layer),
need_head_weights=(idx == alignment_layer),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float()
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not self.training or (dropout_probability > self.decoder_layerdrop):
x, layer_attn = layer(
x,
encoder_state,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=(idx == alignment_layer),
need_head_weights=(idx == alignment_layer),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float()
if attn is not None:
if alignment_heads is not None:

View File

@ -98,6 +98,11 @@ class TransformerLanguageModel(FairseqLanguageModel):
help='if set, ties the projection weights of adaptive softmax and adaptive input')
parser.add_argument('--decoder-learned-pos', action='store_true',
help='use learned positional embeddings in the decoder')
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
parser.add_argument('--decoder-layerdrop', type=float, metavar='D', default=0,
help='LayerDrop probability for decoder')
parser.add_argument('--decoder-layers-to-keep', default=None,
help='which layers to *keep* when pruning as a comma-separated list')
# fmt: on
@classmethod
@ -107,6 +112,9 @@ class TransformerLanguageModel(FairseqLanguageModel):
# make sure all arguments are present in older models
base_lm_architecture(args)
if args.decoder_layers_to_keep:
args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
if getattr(args, 'max_target_positions', None) is None:
args.max_target_positions = getattr(args, 'tokens_per_sample', DEFAULT_MAX_TARGET_POSITIONS)

View File

@ -14,6 +14,7 @@ from fairseq.modules import (
PositionalEmbedding,
TransformerSentenceEncoderLayer,
)
import random
def init_bert_params(module):
@ -77,6 +78,7 @@ class TransformerSentenceEncoder(nn.Module):
dropout: float = 0.1,
attention_dropout: float = 0.1,
activation_dropout: float = 0.1,
layerdrop : float = 0.0,
max_seq_len: int = 256,
num_segments: int = 2,
use_position_embeddings: bool = True,
@ -97,6 +99,7 @@ class TransformerSentenceEncoder(nn.Module):
self.padding_idx = padding_idx
self.vocab_size = vocab_size
self.dropout = dropout
self.layerdrop = layerdrop
self.max_seq_len = max_seq_len
self.embedding_dim = embedding_dim
self.num_segments = num_segments
@ -208,9 +211,13 @@ class TransformerSentenceEncoder(nn.Module):
inner_states.append(x)
for layer in self.layers:
x, _ = layer(x, self_attn_padding_mask=padding_mask)
if not last_state_only:
inner_states.append(x)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not self.training or (dropout_probability > self.layerdrop):
x, _ = layer(x, self_attn_padding_mask=padding_mask)
if not last_state_only:
inner_states.append(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)

View File

@ -181,7 +181,7 @@ class Trainer(object):
# load model parameters
try:
self.get_model().load_state_dict(state['model'], strict=True)
self.get_model().load_state_dict(state['model'], strict=True, args=self.args)
if utils.has_parameters(self.get_criterion()):
self.get_criterion().load_state_dict(state['criterion'], strict=True)
except Exception: