mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Enable final stack post-processing for transformer for correct prenorm behavior (#719)
This PR enables final post-processing of a full transformer stack for correct prenorm behavior. See issues #715 and #699. List of changes: Add final post-processing in encoder and decoder if requested with --transformer-postprocess-top. Can take combinations of d, n, a. Using a will add a skip connection from the bottom of the stack. Add --task transformer-base-prenorm and --task transformer-big-prenorm which correspond to --task transformer-base --transformer-preprocess n --transformer-postprocess da --transformer-postprocess-top n.
This commit is contained in:
parent
660719cd27
commit
951ecfe932
@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Add --transformer-postprocess-top option to enable correctly normalized prenorm behavior
|
||||
- Add --task transformer-base-prenorm and --task transformer-big-prenorm
|
||||
- Turing and Ampere GPU optimisation support, if the CUDA version supports it.
|
||||
- Printing word-level scores in marian-scorer
|
||||
- Optimize LayerNormalization on CPU by 6x through vectorization (ffast-math) and fixing performance regression introduced with strides in 77a420
|
||||
|
@ -144,6 +144,87 @@ void ConfigParser::addAliases(cli::CLIWrapper& cli) {
|
||||
config["valid-mini-batch"] = 8;
|
||||
config["normalize"] = 1.0;
|
||||
});
|
||||
|
||||
// Transformer base variant with "prenorm" (i.e. the layer normalization is performed as the first block-wise
|
||||
// preprocessing step). This also requires to normalize the final output of a transformer stack to avoid the
|
||||
// activations to blow up. This blow up is particularly nasty with mixed precision training.
|
||||
// See implementation and comments in tensor2tensor:
|
||||
// * https://github.com/tensorflow/tensor2tensor/blob/95d021477272c10af15cd62f25b595ad16ad514e/tensor2tensor/models/transformer.py#L1845
|
||||
// * https://github.com/tensorflow/tensor2tensor/commit/f5c9b17e617ea9179b7d84d36b1e8162cb369f25#diff-4e58a582cf11ca649e76b4362d69e405R78
|
||||
cli.alias("task", "transformer-base-prenorm", [](YAML::Node& config) {
|
||||
// Model options
|
||||
config["type"] = "transformer";
|
||||
config["enc-depth"] = 6;
|
||||
config["dec-depth"] = 6;
|
||||
config["dim-emb"] = 512;
|
||||
config["tied-embeddings-all"] = true;
|
||||
config["transformer-dim-ffn"] = 2048;
|
||||
config["transformer-heads"] = 8;
|
||||
config["transformer-postprocess"] = "da"; // change from transformer-base is "dan" -> "da"
|
||||
config["transformer-preprocess"] = "n"; // change from transformer-base is "" -> "n"
|
||||
config["transformer-postprocess-top"] = "n"; // change from transformer-base is "" -> "n"
|
||||
config["transformer-ffn-activation"] = "relu";
|
||||
config["transformer-dropout"] = 0.1;
|
||||
|
||||
// Training specific options
|
||||
config["learn-rate"] = 0.0003;
|
||||
config["cost-type"] = "ce-mean-words";
|
||||
config["lr-warmup"] = 16000;
|
||||
config["lr-decay-inv-sqrt"] = 16000;
|
||||
config["label-smoothing"] = 0.1;
|
||||
config["clip-norm"] = 0;
|
||||
config["sync-sgd"] = true;
|
||||
config["exponential-smoothing"] = 1e-4;
|
||||
config["max-length"] = 100;
|
||||
config["mini-batch-fit"] = true;
|
||||
config["mini-batch"] = 1000;
|
||||
config["maxi-batch"] = 1000;
|
||||
config["workspace"] = 9500;
|
||||
config["optimizer-params"] = std::vector<float>({0.9f, 0.98f, 1e-09f});
|
||||
|
||||
// Validation specific options
|
||||
config["beam-size"] = 8;
|
||||
config["valid-mini-batch"] = 16;
|
||||
config["normalize"] = 1.0;
|
||||
});
|
||||
|
||||
// Transformer big variant with "prenorm". Same changes as above.
|
||||
cli.alias("task", "transformer-big-prenorm", [](YAML::Node& config) {
|
||||
// Model options
|
||||
config["type"] = "transformer";
|
||||
config["enc-depth"] = 6;
|
||||
config["dec-depth"] = 6;
|
||||
config["dim-emb"] = 1024;
|
||||
config["tied-embeddings-all"] = true;
|
||||
config["transformer-dim-ffn"] = 4096;
|
||||
config["transformer-heads"] = 16;
|
||||
config["transformer-postprocess"] = "da"; // change from transformer-big is "dan" -> "da"
|
||||
config["transformer-preprocess"] = "n"; // change from transformer-big is "" -> "n"
|
||||
config["transformer-postprocess-top"] = "n"; // change from transformer-big is "" -> "n"
|
||||
config["transformer-ffn-activation"] = "relu";
|
||||
config["transformer-dropout"] = 0.1;
|
||||
|
||||
// Training specific options
|
||||
config["learn-rate"] = 0.0002;
|
||||
config["cost-type"] = "ce-mean-words";
|
||||
config["lr-warmup"] = 8000;
|
||||
config["lr-decay-inv-sqrt"] = 8000;
|
||||
config["label-smoothing"] = 0.1;
|
||||
config["clip-norm"] = 0;
|
||||
config["sync-sgd"] = true;
|
||||
config["exponential-smoothing"] = 1e-4;
|
||||
config["max-length"] = 100;
|
||||
config["mini-batch-fit"] = true;
|
||||
config["mini-batch"] = 1000;
|
||||
config["maxi-batch"] = 1000;
|
||||
config["workspace"] = 13000;
|
||||
config["optimizer-params"] = std::vector<float>({0.9f, 0.998f, 1e-09f});
|
||||
|
||||
// Validation specific options
|
||||
config["beam-size"] = 8;
|
||||
config["valid-mini-batch"] = 8;
|
||||
config["normalize"] = 1.0;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -288,6 +288,9 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
||||
cli.add<std::string>("--transformer-postprocess",
|
||||
"Operation after each transformer layer: d = dropout, a = add, n = normalize",
|
||||
"dan");
|
||||
cli.add<std::string>("--transformer-postprocess-top",
|
||||
"Final operation after a full transformer stack: d = dropout, a = add, n = normalize. The optional skip connection with 'a' by-passes the entire stack.",
|
||||
"");
|
||||
cli.add<bool>("--transformer-train-position-embeddings",
|
||||
"Train positional embeddings instead of using static sinusoidal embeddings");
|
||||
cli.add<bool>("--transformer-depth-scaling",
|
||||
|
@ -126,6 +126,7 @@ public:
|
||||
modelFeatures_.insert("transformer-preprocess");
|
||||
modelFeatures_.insert("transformer-postprocess");
|
||||
modelFeatures_.insert("transformer-postprocess-emb");
|
||||
modelFeatures_.insert("transformer-postprocess-top");
|
||||
modelFeatures_.insert("transformer-decoder-autoreg");
|
||||
modelFeatures_.insert("transformer-tied-layers");
|
||||
modelFeatures_.insert("transformer-guided-alignment-layer");
|
||||
|
@ -47,6 +47,7 @@ EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
modelFeatures_.insert("transformer-preprocess");
|
||||
modelFeatures_.insert("transformer-postprocess");
|
||||
modelFeatures_.insert("transformer-postprocess-emb");
|
||||
modelFeatures_.insert("transformer-postprocess-top");
|
||||
modelFeatures_.insert("transformer-decoder-autoreg");
|
||||
modelFeatures_.insert("transformer-tied-layers");
|
||||
modelFeatures_.insert("transformer-guided-alignment-layer");
|
||||
|
@ -135,6 +135,7 @@ public:
|
||||
modelFeatures_.insert("transformer-preprocess");
|
||||
modelFeatures_.insert("transformer-postprocess");
|
||||
modelFeatures_.insert("transformer-postprocess-emb");
|
||||
modelFeatures_.insert("transformer-postprocess-top");
|
||||
modelFeatures_.insert("transformer-decoder-autoreg");
|
||||
modelFeatures_.insert("transformer-tied-layers");
|
||||
modelFeatures_.insert("transformer-guided-alignment-layer");
|
||||
|
@ -558,6 +558,8 @@ public:
|
||||
auto layer = transposeTimeBatch(batchEmbeddings); // [beam depth=1, batch size, max length, vector dim]
|
||||
auto layerMask = transposeTimeBatch(batchMask); // [beam depth=1, batch size, max length, vector dim=1]
|
||||
|
||||
auto prevLayer = layer; // keep handle to untransformed embeddings, potentially used for a final skip connection
|
||||
|
||||
auto opsEmb = opt<std::string>("transformer-postprocess-emb");
|
||||
float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
|
||||
layer = preProcess(prefix_ + "_emb", opsEmb, layer, dropProb);
|
||||
@ -580,6 +582,12 @@ public:
|
||||
checkpoint(layer); // sets a manually specified checkpoint if gradient checkpointing is enabled, does nothing otherwise.
|
||||
}
|
||||
|
||||
// this allows to run a final layernorm operation after going through the transformer layer stack.
|
||||
// By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
|
||||
// it is recommended to normalize here. Can also be used to add a skip connection from the very bottom if requested.
|
||||
auto opsTop = opt<std::string>("transformer-postprocess-top", "");
|
||||
layer = postProcess(prefix_ + "_top", opsTop, layer, prevLayer, dropProb);
|
||||
|
||||
// restore organization of batch and time steps. This is currently required
|
||||
// to make RNN-based decoders and beam search work with this. We are looking
|
||||
// into making this more natural.
|
||||
@ -706,6 +714,8 @@ public:
|
||||
// reorganize batch and timestep
|
||||
auto query = transposeTimeBatch(scaledEmbeddings); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
|
||||
|
||||
auto prevQuery = query; // keep handle to untransformed embeddings, potentially used for a final skip connection
|
||||
|
||||
auto opsEmb = opt<std::string>("transformer-postprocess-emb");
|
||||
float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
|
||||
|
||||
@ -841,6 +851,12 @@ public:
|
||||
checkpoint(query);
|
||||
}
|
||||
|
||||
// This allows to run a final layernorm operation after going through the transformer layer stack.
|
||||
// By default the operations are empty, but with prenorm (--transformer-preprocess n --transformer-postprocess da)
|
||||
// it is recommended to normalize here. Can also be used to add a skip connection from the very bottom if requested.
|
||||
auto opsTop = opt<std::string>("transformer-postprocess-top", "");
|
||||
query = postProcess(prefix_ + "_top", opsTop, query, prevQuery, dropProb);
|
||||
|
||||
auto decoderContext = transposeTimeBatch(query); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vector dim]
|
||||
|
||||
//************************************************************************//
|
||||
|
Loading…
Reference in New Issue
Block a user