Enable final stack post-processing for transformer for correct prenorm behavior (#719)

This PR enables final post-processing of a full transformer stack for correct prenorm behavior.
See issues #715 and #699.

List of changes:

- Add final post-processing in the encoder and decoder when requested via `--transformer-postprocess-top`. The option takes combinations of `d`, `n`, `a`; using `a` adds a skip connection from the bottom of the stack (see the sketch after this list).
- Add `--task transformer-base-prenorm` and `--task transformer-big-prenorm`, which correspond to `--task transformer-base` (resp. `--task transformer-big`) with `--transformer-preprocess n --transformer-postprocess da --transformer-postprocess-top n`.
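As a rough illustration of how such an ops string is interpreted, here is a minimal sketch (hypothetical names, `double` standing in for tensors; this is not Marian's actual post-processing code): the characters are applied left to right.

```cpp
// Minimal sketch with hypothetical names; tensors replaced by double for brevity.
#include <functional>
#include <string>

using Tensor = double;

Tensor applyOps(const std::string& ops, Tensor x, Tensor skip,
                const std::function<Tensor(Tensor)>& dropout,
                const std::function<Tensor(Tensor, Tensor)>& add,
                const std::function<Tensor(Tensor)>& layerNorm) {
  for (char op : ops) {
    if (op == 'd')      x = dropout(x);    // d = dropout
    else if (op == 'a') x = add(x, skip);  // a = add a skip connection
    else if (op == 'n') x = layerNorm(x);  // n = layer normalization
  }
  return x;
}
```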
Author: Marcin Junczys-Dowmunt, 2020-09-09 08:06:20 -07:00 (committed by GitHub)
Commit: 951ecfe932 (parent: 660719cd27)
7 changed files with 105 additions and 0 deletions


@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
- Add --transformer-postprocess-top option to enable correctly normalized prenorm behavior
- Add --task transformer-base-prenorm and --task transformer-big-prenorm
- Turing and Ampere GPU optimisation support, if the CUDA version supports it.
- Printing word-level scores in marian-scorer
- Optimize LayerNormalization on CPU by 6x through vectorization (ffast-math) and fixing performance regression introduced with strides in 77a420


@@ -144,6 +144,87 @@ void ConfigParser::addAliases(cli::CLIWrapper& cli) {
config["valid-mini-batch"] = 8;
config["normalize"] = 1.0;
});
// Transformer base variant with "prenorm", i.e. the layer normalization is performed as the first block-wise
// preprocessing step. This also requires normalizing the final output of the transformer stack to keep the
// activations from blowing up; this blow-up is particularly nasty with mixed-precision training.
// See implementation and comments in tensor2tensor:
// * https://github.com/tensorflow/tensor2tensor/blob/95d021477272c10af15cd62f25b595ad16ad514e/tensor2tensor/models/transformer.py#L1845
// * https://github.com/tensorflow/tensor2tensor/commit/f5c9b17e617ea9179b7d84d36b1e8162cb369f25#diff-4e58a582cf11ca649e76b4362d69e405R78
cli.alias("task", "transformer-base-prenorm", [](YAML::Node& config) {
// Model options
config["type"] = "transformer";
config["enc-depth"] = 6;
config["dec-depth"] = 6;
config["dim-emb"] = 512;
config["tied-embeddings-all"] = true;
config["transformer-dim-ffn"] = 2048;
config["transformer-heads"] = 8;
config["transformer-postprocess"] = "da"; // change from transformer-base is "dan" -> "da"
config["transformer-preprocess"] = "n"; // change from transformer-base is "" -> "n"
config["transformer-postprocess-top"] = "n"; // change from transformer-base is "" -> "n"
config["transformer-ffn-activation"] = "relu";
config["transformer-dropout"] = 0.1;
// Training specific options
config["learn-rate"] = 0.0003;
config["cost-type"] = "ce-mean-words";
config["lr-warmup"] = 16000;
config["lr-decay-inv-sqrt"] = 16000;
config["label-smoothing"] = 0.1;
config["clip-norm"] = 0;
config["sync-sgd"] = true;
config["exponential-smoothing"] = 1e-4;
config["max-length"] = 100;
config["mini-batch-fit"] = true;
config["mini-batch"] = 1000;
config["maxi-batch"] = 1000;
config["workspace"] = 9500;
config["optimizer-params"] = std::vector<float>({0.9f, 0.98f, 1e-09f});
// Validation specific options
config["beam-size"] = 8;
config["valid-mini-batch"] = 16;
config["normalize"] = 1.0;
});
// Transformer big variant with "prenorm". Same changes as above.
cli.alias("task", "transformer-big-prenorm", [](YAML::Node& config) {
// Model options
config["type"] = "transformer";
config["enc-depth"] = 6;
config["dec-depth"] = 6;
config["dim-emb"] = 1024;
config["tied-embeddings-all"] = true;
config["transformer-dim-ffn"] = 4096;
config["transformer-heads"] = 16;
config["transformer-postprocess"] = "da"; // change from transformer-big is "dan" -> "da"
config["transformer-preprocess"] = "n"; // change from transformer-big is "" -> "n"
config["transformer-postprocess-top"] = "n"; // change from transformer-big is "" -> "n"
config["transformer-ffn-activation"] = "relu";
config["transformer-dropout"] = 0.1;
// Training specific options
config["learn-rate"] = 0.0002;
config["cost-type"] = "ce-mean-words";
config["lr-warmup"] = 8000;
config["lr-decay-inv-sqrt"] = 8000;
config["label-smoothing"] = 0.1;
config["clip-norm"] = 0;
config["sync-sgd"] = true;
config["exponential-smoothing"] = 1e-4;
config["max-length"] = 100;
config["mini-batch-fit"] = true;
config["mini-batch"] = 1000;
config["maxi-batch"] = 1000;
config["workspace"] = 13000;
config["optimizer-params"] = std::vector<float>({0.9f, 0.998f, 1e-09f});
// Validation specific options
config["beam-size"] = 8;
config["valid-mini-batch"] = 8;
config["normalize"] = 1.0;
});
}
}
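The prenorm rationale from the comment above can be sketched as follows (hypothetical names, scalars standing in for tensors; not the actual transformer implementation): every layer adds an unnormalized residual branch, so the stack output needs one final normalization, and `a` can optionally add a skip connection from the bottom of the stack.

```cpp
// Minimal sketch, assuming scalar stand-ins for tensors and hypothetical names;
// not the actual transformer implementation.
#include <functional>
#include <vector>

using Expr = double;                        // stand-in for a graph expression
using SubLayer = std::function<Expr(Expr)>; // self-attention or FFN block
using Norm = std::function<Expr(Expr)>;     // layer normalization

Expr prenormStack(Expr x, const std::vector<SubLayer>& layers,
                  const Norm& layerNorm, bool addBottomSkip = false) {
  Expr bottom = x;                  // bottom of the stack, kept for an optional final skip
  for (const auto& sub : layers)
    x = x + sub(layerNorm(x));      // preprocess "n", postprocess "da" (dropout omitted)
  if (addBottomSkip)
    x = x + bottom;                 // 'a' in --transformer-postprocess-top
  return layerNorm(x);              // 'n' in --transformer-postprocess-top
}
```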


@@ -288,6 +288,9 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
cli.add<std::string>("--transformer-postprocess",
"Operation after each transformer layer: d = dropout, a = add, n = normalize",
"dan");
cli.add<std::string>("--transformer-postprocess-top",
"Final operation after a full transformer stack: d = dropout, a = add, n = normalize. The optional skip connection with 'a' by-passes the entire stack.",
"");
cli.add<bool>("--transformer-train-position-embeddings",
"Train positional embeddings instead of using static sinusoidal embeddings");
cli.add<bool>("--transformer-depth-scaling",


@@ -126,6 +126,7 @@ public:
modelFeatures_.insert("transformer-preprocess");
modelFeatures_.insert("transformer-postprocess");
modelFeatures_.insert("transformer-postprocess-emb");
modelFeatures_.insert("transformer-postprocess-top");
modelFeatures_.insert("transformer-decoder-autoreg");
modelFeatures_.insert("transformer-tied-layers");
modelFeatures_.insert("transformer-guided-alignment-layer");


@@ -47,6 +47,7 @@ EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options)
modelFeatures_.insert("transformer-preprocess");
modelFeatures_.insert("transformer-postprocess");
modelFeatures_.insert("transformer-postprocess-emb");
modelFeatures_.insert("transformer-postprocess-top");
modelFeatures_.insert("transformer-decoder-autoreg");
modelFeatures_.insert("transformer-tied-layers");
modelFeatures_.insert("transformer-guided-alignment-layer");


@@ -135,6 +135,7 @@ public:
modelFeatures_.insert("transformer-preprocess");
modelFeatures_.insert("transformer-postprocess");
modelFeatures_.insert("transformer-postprocess-emb");
modelFeatures_.insert("transformer-postprocess-top");
modelFeatures_.insert("transformer-decoder-autoreg");
modelFeatures_.insert("transformer-tied-layers");
modelFeatures_.insert("transformer-guided-alignment-layer");


@@ -558,6 +558,8 @@ public:
auto layer = transposeTimeBatch(batchEmbeddings); // [beam depth=1, batch size, max length, vector dim]
auto layerMask = transposeTimeBatch(batchMask); // [beam depth=1, batch size, max length, vector dim=1]
auto prevLayer = layer; // keep handle to untransformed embeddings, potentially used for a final skip connection
auto opsEmb = opt<std::string>("transformer-postprocess-emb");
float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
layer = preProcess(prefix_ + "_emb", opsEmb, layer, dropProb);
@@ -580,6 +582,12 @@
checkpoint(layer); // sets a manually specified checkpoint if gradient checkpointing is enabled, does nothing otherwise.
}
// This allows running a final layernorm operation after going through the transformer layer stack.
// By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
// it is recommended to normalize here. This can also be used to add a skip connection from the very bottom if requested.
auto opsTop = opt<std::string>("transformer-postprocess-top", "");
layer = postProcess(prefix_ + "_top", opsTop, layer, prevLayer, dropProb);
// restore organization of batch and time steps. This is currently required
// to make RNN-based decoders and beam search work with this. We are looking
// into making this more natural.
@@ -706,6 +714,8 @@ public:
// reorganize batch and timestep
auto query = transposeTimeBatch(scaledEmbeddings); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
auto prevQuery = query; // keep handle to untransformed embeddings, potentially used for a final skip connection
auto opsEmb = opt<std::string>("transformer-postprocess-emb");
float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
@@ -841,6 +851,12 @@ public:
checkpoint(query);
}
// This allows running a final layernorm operation after going through the transformer layer stack.
// By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
// it is recommended to normalize here. This can also be used to add a skip connection from the very bottom if requested.
auto opsTop = opt<std::string>("transformer-postprocess-top", "");
query = postProcess(prefix_ + "_top", opsTop, query, prevQuery, dropProb);
auto decoderContext = transposeTimeBatch(query); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vector dim]
//************************************************************************//
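The blow-up that the final normalization guards against can be made concrete with a toy calculation (an illustrative sketch under the assumption that the stack output behaves like a sum of independent unit-variance residual branches; this is not Marian code): the scale of such a sum grows like the square root of the number of layers, which is what the final `n` of `--transformer-postprocess-top` is meant to renormalize.

```cpp
// Toy illustration only: summing L independent unit-variance contributions
// grows the standard deviation of the result like sqrt(L).
#include <cmath>
#include <cstdio>
#include <random>

int main() {
  std::mt19937 gen(42);
  std::normal_distribution<double> unit(0.0, 1.0); // each residual branch ~ N(0, 1)
  const int samples = 100000;
  const int layerCounts[] = {1, 6, 24};
  for (int layers : layerCounts) {
    double sumSq = 0.0;
    for (int s = 0; s < samples; ++s) {
      double x = 0.0;
      for (int l = 0; l < layers; ++l)
        x += unit(gen);                            // accumulate the residual sum over the stack
      sumSq += x * x;
    }
    std::printf("layers=%2d  stddev of stack output ~ %.2f (sqrt(L) = %.2f)\n",
                layers, std::sqrt(sumSq / samples), std::sqrt((double)layers));
  }
  return 0;
}
```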