Mirror of https://github.com/marian-nmt/marian.git, synced 2024-09-17 09:47:34 +03:00.
Merged PR 20879: Adjustable ffn width and depth in transformer decoder
This commit: 03fe175876 (parent: d796a3c3b7).
@ -255,10 +255,16 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
||||
"Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
|
||||
cli.add<int>("--transformer-dim-ffn",
|
||||
"Size of position-wise feed-forward network (transformer)",
|
||||
2048);
|
||||
2048);
|
||||
cli.add<int>("--transformer-decoder-dim-ffn",
|
||||
"Size of position-wise feed-forward network in decoder (transformer). Uses --transformer-dim-ffn if 0.",
|
||||
0);
|
||||
cli.add<int>("--transformer-ffn-depth",
|
||||
"Depth of filters (transformer)",
|
||||
2);
|
||||
cli.add<int>("--transformer-decoder-ffn-depth",
|
||||
"Depth of filters in decoder (transformer). Uses --transformer-ffn-depth if 0",
|
||||
0);
|
||||
cli.add<std::string>("--transformer-ffn-activation",
|
||||
"Activation between filters: swish or relu (transformer)",
|
||||
"swish");
|
||||
|
@ -38,7 +38,9 @@ EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
modelFeatures_.insert("transformer-heads");
|
||||
modelFeatures_.insert("transformer-no-projection");
|
||||
modelFeatures_.insert("transformer-dim-ffn");
|
||||
modelFeatures_.insert("transformer-decoder-dim-ffn");
|
||||
modelFeatures_.insert("transformer-ffn-depth");
|
||||
modelFeatures_.insert("transformer-decoder-ffn-depth");
|
||||
modelFeatures_.insert("transformer-ffn-activation");
|
||||
modelFeatures_.insert("transformer-dim-aan");
|
||||
modelFeatures_.insert("transformer-aan-depth");
|
||||
|
@ -400,7 +400,7 @@ public:
|
||||
opt<int>("transformer-heads"), /*cache=*/false);
|
||||
}
|
||||
|
||||
Expr LayerFFN(std::string prefix, Expr input) const {
|
||||
Expr LayerFFN(std::string prefix, Expr input, bool isDecoder=false) const {
|
||||
int dimModel = input->shape()[-1];
|
||||
|
||||
float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
|
||||
@ -408,13 +408,22 @@ public:
|
||||
auto output = preProcess(prefix + "_ffn", opsPre, input, dropProb);
|
||||
|
||||
auto actName = opt<std::string>("transformer-ffn-activation");
|
||||
|
||||
int dimFfn = opt<int>("transformer-dim-ffn");
|
||||
int depthFfn = opt<int>("transformer-ffn-depth");
|
||||
float ffnDropProb
|
||||
= inference_ ? 0 : opt<float>("transformer-dropout-ffn");
|
||||
|
||||
if(isDecoder) {
|
||||
int decDimFfn = opt<int>("transformer-decoder-dim-ffn", 0);
|
||||
if(decDimFfn != 0)
|
||||
dimFfn = decDimFfn;
|
||||
|
||||
int decDepthFfn = opt<int>("transformer-decoder-ffn-depth", 0);
|
||||
if(decDepthFfn != 0)
|
||||
depthFfn = decDepthFfn;
|
||||
}
|
||||
|
||||
ABORT_IF(depthFfn < 1, "Filter depth {} is smaller than 1", depthFfn);
|
||||
|
||||
|
||||
float ffnDropProb = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
|
||||
auto initFn = inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f);
|
||||
|
||||
// the stack of FF layers
|
||||
@ -861,7 +870,7 @@ public:
|
||||
// remember decoder state
|
||||
decoderStates.push_back(decoderState);
|
||||
|
||||
query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
|
||||
query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query, /*isDecoder=*/true); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
|
||||
|
||||
checkpoint(query);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user