Merged PR 20879: Adjustable ffn width and depth in transformer decoder

Marcin Junczys-Dowmunt 2021-09-28 17:19:07 +00:00
parent d796a3c3b7
commit 03fe175876
3 changed files with 24 additions and 7 deletions

@@ -255,10 +255,16 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
       "Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
   cli.add<int>("--transformer-dim-ffn",
       "Size of position-wise feed-forward network (transformer)",
       2048);
+  cli.add<int>("--transformer-decoder-dim-ffn",
+      "Size of position-wise feed-forward network in decoder (transformer). Uses --transformer-dim-ffn if 0.",
+      0);
   cli.add<int>("--transformer-ffn-depth",
       "Depth of filters (transformer)",
       2);
+  cli.add<int>("--transformer-decoder-ffn-depth",
+      "Depth of filters in decoder (transformer). Uses --transformer-ffn-depth if 0",
+      0);
   cli.add<std::string>("--transformer-ffn-activation",
       "Activation between filters: swish or relu (transformer)",
       "swish");

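Both new options default to 0, which means "use the shared --transformer-* setting". As an illustrative example (values chosen arbitrarily), training with --transformer-dim-ffn 4096 --transformer-decoder-dim-ffn 2048 --transformer-decoder-ffn-depth 1 gives encoder layers a depth-2, 4096-wide FFN while decoder layers get a single 2048-wide filter; leaving the decoder options at 0 reproduces the previous behaviour, where encoder and decoder share the same FFN shape.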
@@ -38,7 +38,9 @@ EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options)
   modelFeatures_.insert("transformer-heads");
   modelFeatures_.insert("transformer-no-projection");
   modelFeatures_.insert("transformer-dim-ffn");
+  modelFeatures_.insert("transformer-decoder-dim-ffn");
   modelFeatures_.insert("transformer-ffn-depth");
+  modelFeatures_.insert("transformer-decoder-ffn-depth");
   modelFeatures_.insert("transformer-ffn-activation");
   modelFeatures_.insert("transformer-dim-aan");
   modelFeatures_.insert("transformer-aan-depth");

@@ -400,7 +400,7 @@ public:
                          opt<int>("transformer-heads"), /*cache=*/false);
   }

-  Expr LayerFFN(std::string prefix, Expr input) const {
+  Expr LayerFFN(std::string prefix, Expr input, bool isDecoder=false) const {
     int dimModel = input->shape()[-1];
     float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
@@ -408,13 +408,22 @@ public:
     auto output = preProcess(prefix + "_ffn", opsPre, input, dropProb);
     auto actName = opt<std::string>("transformer-ffn-activation");
     int dimFfn = opt<int>("transformer-dim-ffn");
     int depthFfn = opt<int>("transformer-ffn-depth");
-    float ffnDropProb
-      = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
+    if(isDecoder) {
+      int decDimFfn = opt<int>("transformer-decoder-dim-ffn", 0);
+      if(decDimFfn != 0)
+        dimFfn = decDimFfn;
+      int decDepthFfn = opt<int>("transformer-decoder-ffn-depth", 0);
+      if(decDepthFfn != 0)
+        depthFfn = decDepthFfn;
+    }
+    ABORT_IF(depthFfn < 1, "Filter depth {} is smaller than 1", depthFfn);
+    float ffnDropProb = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
     auto initFn = inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f);
     // the stack of FF layers
@@ -861,7 +870,7 @@ public:
       // remember decoder state
       decoderStates.push_back(decoderState);

-      query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
+      query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query, /*isDecoder=*/true); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]

       checkpoint(query);
     }
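For reference, the decoder-specific override that LayerFFN now applies can be summarized in the following standalone sketch. FfnDims, resolveFfnDims, and the plain std::map standing in for Marian's Options are hypothetical names used only for illustration, not part of the Marian API.

// Illustrative sketch only, not Marian code: resolve the effective FFN width
// and depth for a layer, mirroring the fallback logic added in this commit.
#include <map>
#include <stdexcept>
#include <string>

struct FfnDims { int dim; int depth; };

// A decoder-specific value of 0 means "fall back to the shared setting".
FfnDims resolveFfnDims(const std::map<std::string, int>& opts, bool isDecoder) {
  FfnDims d{opts.at("transformer-dim-ffn"), opts.at("transformer-ffn-depth")};
  if(isDecoder) {
    auto get = [&](const std::string& key) {
      auto it = opts.find(key);
      return it == opts.end() ? 0 : it->second;  // a missing key behaves like 0
    };
    if(int decDim = get("transformer-decoder-dim-ffn"))
      d.dim = decDim;
    if(int decDepth = get("transformer-decoder-ffn-depth"))
      d.depth = decDepth;
  }
  if(d.depth < 1)  // mirrors the ABORT_IF guard in LayerFFN
    throw std::runtime_error("Filter depth is smaller than 1");
  return d;
}

With the defaults from the option definitions above (dim 2048, depth 2, decoder overrides 0), resolveFfnDims(opts, /*isDecoder=*/true) returns {2048, 2}, i.e. encoder and decoder FFNs stay identical unless the new flags are set.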