rename bert-specific options

Marcin Junczys-Dowmunt 2019-01-27 12:43:04 -08:00
parent cb8c249ec6
commit 83fbd248d0
7 changed files with 303 additions and 40 deletions

View File

@@ -44,11 +44,12 @@ config["transformer-dim-ffn"] = tfModel["bert/encoder/layer_0/intermediate/dense
config["transformer-ffn-activation"] = bertConfig["hidden_act"]
config["transformer-ffn-depth"] = 2
config["transformer-heads"] = bertConfig["num_attention_heads"]
config["transformer-train-positions"] = True
config["transformer-token-types"] = tfModel["bert/embeddings/token_type_embeddings:0"].shape[0]
config["transformer-train-position-embeddings"] = True
config["transformer-preprocess"] = ""
config["transformer-postprocess"] = "dan"
config["transformer-postprocess-emb"] = "nd"
config["bert-train-type-embeddings"] = True
config["bert-type-vocab-size"] = tfModel["bert/embeddings/token_type_embeddings:0"].shape[0]
config["version"] = "bert4marian.py conversion"
# check number of layers
@@ -78,7 +79,7 @@ marianModel["special:model.yml"] = npDesc
# Embedding layers
marianModel["Wemb"] = tfModel["bert/embeddings/word_embeddings:0"]
marianModel["Wpos"] = tfModel["bert/embeddings/position_embeddings:0"]
marianModel["Wsent"] = tfModel["bert/embeddings/token_type_embeddings:0"]
marianModel["Wtype"] = tfModel["bert/embeddings/token_type_embeddings:0"]
marianModel["encoder_emb_ln_scale_pre"] = tfModel["bert/embeddings/LayerNorm/gamma:0"]
marianModel["encoder_emb_ln_bias_pre"] = tfModel["bert/embeddings/LayerNorm/beta:0"]

View File

@@ -200,13 +200,16 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
cli.add<std::string>("--transformer-postprocess",
"Operation after each transformer layer: d = dropout, a = add, n = normalize",
"dan");
cli.add<bool>("--transformer-train-positions",
cli.add<bool>("--transformer-train-position-embeddings",
"Train positional embeddings instead of using static sinusoidal embeddings");
cli.add<std::string>("--bert-mask-symbol", "Masking symbol for BERT masked-LM training", "[MASK]");
cli.add<std::string>("--bert-sep-symbol", "Sentence separator symbol for BERT next sentence prediction training", "[SEP]");
cli.add<std::string>("--bert-class-symbol", "Class symbol BERT classifier training", "[CLS]");
cli.add<float>("--bert-masking-fraction", "Fraction of masked out tokens during training", 0.15);
cli.add<bool>("--bert-train-type-embeddings", "Train bert type embeddings, set to false to use static sinusoidal embeddings", true);
cli.add<int>("--bert-type-vocab-size", "Size of BERT type vocab (sentence A and B)", 2);
#ifdef CUDNN
cli.add<int>("--char-stride",
"Width of max-pooling layer after convolution layer in char-s2s model",

View File

@@ -70,7 +70,8 @@ public:
float maskFraction,
const std::string& maskSymbol,
const std::string& sepSymbol,
const std::string& clsSymbol)
const std::string& clsSymbol,
int dimTypeVocab)
: CorpusBatch(*batch),
maskSymbol_(maskSymbol), sepSymbol_(sepSymbol), clsSymbol_(clsSymbol) {
@@ -119,18 +120,19 @@ public:
words[i] = maskOut(words[i], maskId, engine); // mask that position
}
annotateSentenceIndices();
annotateSentenceIndices(dimTypeVocab);
}
BertBatch(Ptr<CorpusBatch> batch,
const std::string& sepSymbol,
const std::string& clsSymbol)
const std::string& clsSymbol,
int dimTypeVocab)
: CorpusBatch(*batch),
maskSymbol_("dummy"), sepSymbol_(sepSymbol), clsSymbol_(clsSymbol) {
annotateSentenceIndices();
annotateSentenceIndices(dimTypeVocab);
}
void annotateSentenceIndices() {
void annotateSentenceIndices(int dimTypeVocab) {
// BERT expects a textual first stream and a second stream with class labels
auto subBatch = subBatches_.front();
const auto& vocab = *subBatch->vocab();
@@ -144,7 +146,7 @@ public:
int dimBatch = subBatch->batchSize();
int dimWords = subBatch->batchWidth();
int maxSentPos = 1; // Currently only two sentences allowed A at [0] and B at [1] and padding at [2]
int maxSentPos = dimTypeVocab; // Currently only two sentences allowed A at [0] and B at [1] and padding at [2]
// If another separator is seen do not increase position index beyond 2 but use padding.
// @TODO: make this configurable, see below for NextSentencePredictions task where we also restrict to 2.
@@ -180,6 +182,7 @@ public:
std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
std::string modelType = opt<std::string>("type");
int dimTypeVocab = opt<int>("bert-type-vocab-size");
// intercept batch and annotate with BERT-specific concepts
Ptr<data::BertBatch> bertBatch;
@@ -189,11 +192,13 @@ public:
opt<float>("bert-masking-fraction", 0.15f), // 15% by default according to paper
opt<std::string>("bert-mask-symbol"),
opt<std::string>("bert-sep-symbol"),
opt<std::string>("bert-class-symbol"));
opt<std::string>("bert-class-symbol"),
dimTypeVocab);
} else if(modelType == "bert-classifier") { // we are probably fine-tuning a BERT model for a classification task
bertBatch = New<data::BertBatch>(batch,
opt<std::string>("bert-sep-symbol"),
opt<std::string>("bert-class-symbol")); // only annotate sentence separators
opt<std::string>("bert-class-symbol"),
dimTypeVocab); // only annotate sentence separators
} else {
ABORT("Unknown BERT-style model: {}", modelType);
}
@@ -219,7 +224,6 @@ public:
Expr addSentenceEmbeddings(Expr embeddings,
Ptr<data::CorpusBatch> batch,
bool learnedPosEmbeddings) const {
Ptr<data::BertBatch> bertBatch = std::dynamic_pointer_cast<data::BertBatch>(batch);
ABORT_IF(!bertBatch, "Batch must be BertBatch for BERT training or fine-tuning");
@@ -227,11 +231,13 @@ public:
int dimBatch = embeddings->shape()[-2];
int dimWords = embeddings->shape()[-3];
int dimTypeVocab = opt<int>("bert-type-vocab-size", 2);
Expr signal;
if(learnedPosEmbeddings) {
auto sentenceEmbeddings = embedding()
("prefix", "Wsent")
("dimVocab", 2) // sentence A or sentence B plus padding, @TODO: should rather be a parameter
("prefix", "Wtype")
("dimVocab", dimTypeVocab) // sentence A or sentence B
("dimEmb", dimEmb)
.construct(graph_);
signal = sentenceEmbeddings->apply(bertBatch->bertSentenceIndices(), {dimWords, dimBatch, dimEmb});
@@ -247,9 +253,10 @@ public:
}
virtual Expr addSpecialEmbeddings(Expr input, int start = 0, Ptr<data::CorpusBatch> batch = nullptr) const override {
bool trainPosEmbeddings = opt<bool>("transformer-train-positions", true);
bool trainPosEmbeddings = opt<bool>("transformer-train-position-embeddings", true);
bool trainTypeEmbeddings = opt<bool>("bert-train-type-embeddings", true);
input = addPositionalEmbeddings(input, start, trainPosEmbeddings);
input = addSentenceEmbeddings(input, batch, trainPosEmbeddings); // @TODO: separately set learnable pos and sent embeddings
input = addSentenceEmbeddings(input, batch, trainTypeEmbeddings);
return input;
}
};
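For orientation, the annotation that annotateSentenceIndices(dimTypeVocab) performs can be sketched in isolation (plain C++, hypothetical helper, not the Marian implementation): tokens up to and including the first --bert-sep-symbol get type 0 (sentence A), the following tokens get type 1 (sentence B), and the index never grows past the type vocabulary; the committed code additionally reserves a padding slot for anything after the second separator.

// Stand-alone sketch of BERT sentence-type annotation (assumed semantics, simplified).
#include <algorithm>
#include <string>
#include <vector>

std::vector<int> sentenceTypeIndices(const std::vector<std::string>& tokens,
                                     const std::string& sepSymbol, // --bert-sep-symbol, "[SEP]"
                                     int dimTypeVocab) {           // --bert-type-vocab-size, 2
  std::vector<int> types;
  types.reserve(tokens.size());
  int sentPos = 0; // 0 = sentence A, 1 = sentence B
  for(const auto& tok : tokens) {
    types.push_back(sentPos);                            // the separator keeps the current type
    if(tok == sepSymbol)
      sentPos = std::min(sentPos + 1, dimTypeVocab - 1); // never exceed the type vocabulary
  }
  return types;
}

These indices are what the Wtype embedding matrix (dimVocab = dimTypeVocab) looks up and adds on top of the token and position embeddings in addSentenceEmbeddings above.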

View File

@@ -92,6 +92,7 @@ public:
: options_(options),
prefix_(options->get<std::string>("prefix", "")),
inference_(options->get<bool>("inference", false)) {
modelFeatures_ = {"type",
"dim-vocabs",
"dim-emb",
@@ -128,7 +129,9 @@ public:
modelFeatures_.insert("transformer-decoder-autoreg");
modelFeatures_.insert("transformer-tied-layers");
modelFeatures_.insert("transformer-guided-alignment-layer");
modelFeatures_.insert("transformer-train-positions");
modelFeatures_.insert("transformer-train-position-embeddings");
modelFeatures_.insert("bert-train-type-embeddings");
modelFeatures_.insert("bert-type-vocab-size");
}
virtual Ptr<Options> getOptions() override { return options_; }
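modelFeatures_ acts as the whitelist of option keys that are stored together with the model, which is why the renamed transformer-train-position-embeddings and the new bert-* keys have to be registered here; otherwise their values would be lost when the model is reloaded for fine-tuning or decoding. A minimal, library-free illustration of that idea (hypothetical names, not Marian's actual serialization code):

// Keep only the whitelisted option keys when writing the model's embedded config.
#include <map>
#include <set>
#include <string>

std::map<std::string, std::string> configToSave(
    const std::map<std::string, std::string>& allOptions,
    const std::set<std::string>& modelFeatures) {
  std::map<std::string, std::string> saved;
  for(const auto& kv : allOptions)
    if(modelFeatures.count(kv.first)) // e.g. "bert-type-vocab-size"
      saved.insert(kv);               // this option travels with the model file
  return saved;
}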

View File

@@ -8,26 +8,31 @@ EncoderDecoder::EncoderDecoder(Ptr<Options> options)
: options_(options),
prefix_(options->get<std::string>("prefix", "")),
inference_(options->get<bool>("inference", false)) {
modelFeatures_ = {"type",
"dim-vocabs",
"dim-emb",
"dim-rnn",
"enc-cell",
"enc-type",
"enc-cell-depth",
"enc-depth",
"dec-depth",
"dec-cell",
"dec-cell-base-depth",
"dec-cell-high-depth",
"skip",
"layer-normalization",
"right-left",
"input-types",
"special-vocab",
"tied-embeddings",
"tied-embeddings-src",
"tied-embeddings-all"};
std::vector<std::string> encoderDecoderModelFeatures =
{"type",
"dim-vocabs",
"dim-emb",
"dim-rnn",
"enc-cell",
"enc-type",
"enc-cell-depth",
"enc-depth",
"dec-depth",
"dec-cell",
"dec-cell-base-depth",
"dec-cell-high-depth",
"skip",
"layer-normalization",
"right-left",
"input-types",
"special-vocab",
"tied-embeddings",
"tied-embeddings-src",
"tied-embeddings-all"};
for(auto feature : encoderDecoderModelFeatures)
modelFeatures_.insert(feature);
modelFeatures_.insert("transformer-heads");
modelFeatures_.insert("transformer-no-projection");
@@ -44,7 +49,9 @@ EncoderDecoder::EncoderDecoder(Ptr<Options> options)
modelFeatures_.insert("transformer-decoder-autoreg");
modelFeatures_.insert("transformer-tied-layers");
modelFeatures_.insert("transformer-guided-alignment-layer");
modelFeatures_.insert("transformer-train-positions");
modelFeatures_.insert("transformer-train-position-embeddings");
modelFeatures_.insert("bert-train-type-embeddings");
modelFeatures_.insert("bert-type-vocab-size");
}
std::vector<Ptr<EncoderBase>>& EncoderDecoder::getEncoders() {

View File

@@ -359,7 +359,8 @@ protected:
options_->get<float>("bert-masking-fraction"),
options_->get<std::string>("bert-mask-symbol"),
options_->get<std::string>("bert-sep-symbol"),
options_->get<std::string>("bert-class-symbol"));
options_->get<std::string>("bert-class-symbol"),
options_->get<int>("bert-type-vocab-size"));
builder->clear(graph);
auto classifierStates = std::dynamic_pointer_cast<BertEncoderClassifier>(builder)->apply(graph, bertBatch, true);

View File

@@ -0,0 +1,241 @@
#pragma once
#include <algorithm>
#include "marian.h"
#include "translator/history.h"
#include "translator/scorers.h"
#include "translator/helpers.h"
#include "translator/nth_element.h"
namespace marian {
class Classification {
private:
Ptr<Options> options_;
std::vector<Ptr<Scorer>> scorers_;
size_t topN_{1};
Word trgEosId_ = (Word)-1; // id of the end-of-sentence label, mirrors BeamSearch (assumed)
Word trgUnkId_ = (Word)-1; // id of the unknown label, used for optional suppression (assumed)
public:
Classification(Ptr<Options> options, const std::vector<Ptr<Scorer>>& scorers)
: options_(options),
scorers_(scorers),
topN_{options_->get<size_t>("beam-size")} // misuse beam-size for topN display
{}
Beams toHyps(const std::vector<unsigned int> keys,
const std::vector<float> pathScores,
size_t labelNum,
const Beams& beams,
std::vector<Ptr<ScorerState>>& states,
size_t topN,
bool first,
Ptr<data::CorpusBatch> batch) {
Beams newBeams(beams.size());
std::vector<float> align;
if(options_->has("alignment"))
// Use alignments from the first scorer, even if ensemble
align = scorers_[0]->getAlignment();
for(size_t i = 0; i < keys.size(); ++i) {
// Keys contains indices to vocab items in the entire beam.
// Values can be between 0 and topN * number of labels.
Word embIdx = (Word)(keys[i] % labelNum);
auto beamIdx = i / topN;
// Retrieve short list for final softmax (based on words aligned
// to source sentences). If short list has been set, map the indices
// in the sub-selected vocabulary matrix back to their original positions.
auto shortlist = scorers_[0]->getShortlist();
if(shortlist)
embIdx = shortlist->reverseMap(embIdx); // @TODO: should reverseMap accept a size_t or a Word?
if(newBeams[beamIdx].size() < beams[beamIdx].size()) {
auto& beam = beams[beamIdx];
auto& newBeam = newBeams[beamIdx];
auto hypIdx = (IndexType)(keys[i] / labelNum);
float pathScore = pathScores[i];
auto hypIdxTrans
= IndexType((hypIdx / topN) + (hypIdx % topN) * beams.size());
if(first)
hypIdxTrans = hypIdx;
size_t beamHypIdx = hypIdx % topN;
if(beamHypIdx >= (int)beam.size())
beamHypIdx = beamHypIdx % beam.size();
if(first)
beamHypIdx = 0;
auto hyp = New<Hypothesis>(beam[beamHypIdx], embIdx, hypIdxTrans, pathScore);
// Set score breakdown for n-best lists
if(options_->get<bool>("n-best")) {
std::vector<float> breakDown(states.size(), 0);
beam[beamHypIdx]->GetScoreBreakdown().resize(states.size(), 0);
for(size_t j = 0; j < states.size(); ++j) {
size_t key = embIdx + hypIdxTrans * labelNum;
breakDown[j] = states[j]->breakDown(key)
+ beam[beamHypIdx]->GetScoreBreakdown()[j];
}
hyp->GetScoreBreakdown() = breakDown;
}
newBeam.push_back(hyp);
}
}
return newBeams;
}
// main decoding function
Histories search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
int dimBatch = (int)batch->size();
Histories histories;
for(int i = 0; i < dimBatch; ++i) {
size_t sentId = batch->getSentenceIds()[i];
auto history = New<History>(sentId,
options_->get<float>("normalize"),
options_->get<float>("word-penalty"));
histories.push_back(history);
}
auto getNBestList = createGetNBestListFn(topN_, dimBatch, graph->getDeviceId());
Beams beams(dimBatch); // [batchIndex][beamIndex] is one sentence hypothesis
for(auto& beam : beams)
beam.resize(topN_, New<Hypothesis>());
bool first = true;
bool final = false;
size_t localBeamSize = topN_; // candidate-list width carried between steps
for(int i = 0; i < dimBatch; ++i)
histories[i]->Add(beams[i], trgEosId_);
std::vector<Ptr<ScorerState>> states;
for(auto scorer : scorers_) {
scorer->clear(graph);
}
for(auto scorer : scorers_) {
states.push_back(scorer->apply(graph, batch));
}
// main loop over output tokens
do {
//**********************************************************************
// create constant containing previous path scores for current beam
// also create mapping of hyp indices, which are not 1:1 if sentences complete
std::vector<IndexType> hypIndices; // [beamIndex * activeBatchSize + batchIndex] backpointers, concatenated over beam positions. Used for reordering hypotheses
std::vector<IndexType> embIndices;
Expr prevPathScores; // [beam, 1, 1, 1]
if(first) {
// no scores yet
prevPathScores = graph->constant({1, 1, 1, 1}, inits::from_value(0));
} else {
std::vector<float> beamScores;
dimBatch = (int)batch->size();
for(size_t i = 0; i < localBeamSize; ++i) {
for(size_t j = 0; j < beams.size(); ++j) { // loop over batch entries (active sentences)
auto& beam = beams[j];
if(i < beam.size()) {
auto hyp = beam[i];
hypIndices.push_back((IndexType)hyp->GetPrevStateIndex()); // backpointer
embIndices.push_back(hyp->GetWord());
beamScores.push_back(hyp->GetPathScore());
} else { // dummy hypothesis
hypIndices.push_back(0);
embIndices.push_back(0); // (unused)
beamScores.push_back(-9999);
}
}
}
prevPathScores = graph->constant({(int)localBeamSize, 1, dimBatch, 1},
inits::from_vector(beamScores));
}
//**********************************************************************
// prepare scores for beam search
auto pathScores = prevPathScores;
for(size_t i = 0; i < scorers_.size(); ++i) {
states[i] = scorers_[i]->step(
graph, states[i], hypIndices, embIndices, dimBatch, (int)localBeamSize);
if(scorers_[i]->getWeight() != 1.f)
pathScores = pathScores + scorers_[i]->getWeight() * states[i]->getLogProbs();
else
pathScores = pathScores + states[i]->getLogProbs();
}
// make beams continuous
if(dimBatch > 1 && localBeamSize > 1)
pathScores = transpose(pathScores, {2, 1, 0, 3}); // check if this is needed for classification, rather not, beamSize and topN is badly defined here
if(first)
graph->forward();
else
graph->forwardNext();
//**********************************************************************
// suppress specific symbols if not at right positions
if(trgUnkId_ != -1 && options_->has("allow-unk")
&& !options_->get<bool>("allow-unk"))
suppressWord(pathScores, trgUnkId_);
for(auto state : states)
state->blacklist(pathScores, batch);
//**********************************************************************
// perform beam search and pruning
std::vector<unsigned int> outKeys;
std::vector<float> outPathScores;
std::vector<size_t> beamSizes(dimBatch, localBeamSize);
getNBestList(beamSizes, pathScores->val(), outPathScores, outKeys, first);
int dimTrgVoc = pathScores->shape()[-1];
beams = toHyps(outKeys,
outPathScores,
dimTrgVoc,
beams,
states,
localBeamSize,
first,
batch);
auto prunedBeams = pruneBeam(beams);
for(int i = 0; i < dimBatch; ++i) {
if(!beams[i].empty()) {
final = final
|| histories[i]->size()
>= options_->get<float>("max-length-factor")
* batch->front()->batchWidth();
histories[i]->Add(
beams[i], trgEosId_, prunedBeams[i].empty() || final);
}
}
beams = prunedBeams;
// determine beam size for next sentence, as max over still-active sentences
if(!first) {
size_t maxBeam = 0;
for(auto& beam : beams)
if(beam.size() > maxBeam)
maxBeam = beam.size();
localBeamSize = maxBeam;
}
first = false;
} while(localBeamSize != 0 && !final); // end of main loop over output tokens
return histories;
}
};
} // namespace marian
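A hypothetical call site (not part of this commit) would wire the class up much like BeamSearch: construct it from the decoder options and the scorers, run it per batch, and read the resulting Histories, which here carry the top class labels rather than translations.

// Hypothetical usage sketch; assumes the scorer-taking constructor shown above.
auto classifier = New<Classification>(options, scorers);
auto histories  = classifier->search(graph, batch); // one History per sentence in the batch
// topN_ comes from --beam-size, so up to that many labels are kept per sentence.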