mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
rename bert-specific options
This commit is contained in:
parent
cb8c249ec6
commit
83fbd248d0
@ -44,11 +44,12 @@ config["transformer-dim-ffn"] = tfModel["bert/encoder/layer_0/intermediate/dense
|
||||
config["transformer-ffn-activation"] = bertConfig["hidden_act"]
|
||||
config["transformer-ffn-depth"] = 2
|
||||
config["transformer-heads"] = bertConfig["num_attention_heads"]
|
||||
config["transformer-train-positions"] = True
|
||||
config["transformer-token-types"] = tfModel["bert/embeddings/token_type_embeddings:0"].shape[0]
|
||||
config["transformer-train-position-embeddings"] = True
|
||||
config["transformer-preprocess"] = ""
|
||||
config["transformer-postprocess"] = "dan"
|
||||
config["transformer-postprocess-emb"] = "nd"
|
||||
config["bert-train-type-embeddings"] = True
|
||||
config["bert-type-vocab-size"] = tfModel["bert/embeddings/token_type_embeddings:0"].shape[0]
|
||||
config["version"] = "bert4marian.py conversion"
|
||||
|
||||
# check number of layers
|
||||
@ -78,7 +79,7 @@ marianModel["special:model.yml"] = npDesc
|
||||
# Embedding layers
|
||||
marianModel["Wemb"] = tfModel["bert/embeddings/word_embeddings:0"]
|
||||
marianModel["Wpos"] = tfModel["bert/embeddings/position_embeddings:0"]
|
||||
marianModel["Wsent"] = tfModel["bert/embeddings/token_type_embeddings:0"]
|
||||
marianModel["Wtype"] = tfModel["bert/embeddings/token_type_embeddings:0"]
|
||||
marianModel["encoder_emb_ln_scale_pre"] = tfModel["bert/embeddings/LayerNorm/gamma:0"]
|
||||
marianModel["encoder_emb_ln_bias_pre"] = tfModel["bert/embeddings/LayerNorm/beta:0"]
|
||||
|
||||
|
@ -200,13 +200,16 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
|
||||
cli.add<std::string>("--transformer-postprocess",
|
||||
"Operation after each transformer layer: d = dropout, a = add, n = normalize",
|
||||
"dan");
|
||||
cli.add<bool>("--transformer-train-positions",
|
||||
cli.add<bool>("--transformer-train-position-embeddings",
|
||||
"Train positional embeddings instead of using static sinusoidal embeddings");
|
||||
|
||||
cli.add<std::string>("--bert-mask-symbol", "Masking symbol for BERT masked-LM training", "[MASK]");
|
||||
cli.add<std::string>("--bert-sep-symbol", "Sentence separator symbol for BERT next sentence prediction training", "[SEP]");
|
||||
cli.add<std::string>("--bert-class-symbol", "Class symbol BERT classifier training", "[CLS]");
|
||||
cli.add<float>("--bert-masking-fraction", "Fraction of masked out tokens during training", 0.15);
|
||||
cli.add<bool>("--bert-train-type-embeddings", "Train bert type embeddings, set to false to use static sinusoidal embeddings", true);
|
||||
cli.add<int>("--bert-type-vocab-size", "Size of BERT type vocab (sentence A and B)", 2);
|
||||
|
||||
#ifdef CUDNN
|
||||
cli.add<int>("--char-stride",
|
||||
"Width of max-pooling layer after convolution layer in char-s2s model",
|
||||
|
@ -70,7 +70,8 @@ public:
|
||||
float maskFraction,
|
||||
const std::string& maskSymbol,
|
||||
const std::string& sepSymbol,
|
||||
const std::string& clsSymbol)
|
||||
const std::string& clsSymbol,
|
||||
int dimTypeVocab)
|
||||
: CorpusBatch(*batch),
|
||||
maskSymbol_(maskSymbol), sepSymbol_(sepSymbol), clsSymbol_(clsSymbol) {
|
||||
|
||||
@ -119,18 +120,19 @@ public:
|
||||
words[i] = maskOut(words[i], maskId, engine); // mask that position
|
||||
}
|
||||
|
||||
annotateSentenceIndices();
|
||||
annotateSentenceIndices(dimTypeVocab);
|
||||
}
|
||||
|
||||
BertBatch(Ptr<CorpusBatch> batch,
|
||||
const std::string& sepSymbol,
|
||||
const std::string& clsSymbol)
|
||||
const std::string& clsSymbol,
|
||||
int dimTypeVocab)
|
||||
: CorpusBatch(*batch),
|
||||
maskSymbol_("dummy"), sepSymbol_(sepSymbol), clsSymbol_(clsSymbol) {
|
||||
annotateSentenceIndices();
|
||||
annotateSentenceIndices(dimTypeVocab);
|
||||
}
|
||||
|
||||
void annotateSentenceIndices() {
|
||||
void annotateSentenceIndices(int dimTypeVocab) {
|
||||
// BERT expects a textual first stream and a second stream with class labels
|
||||
auto subBatch = subBatches_.front();
|
||||
const auto& vocab = *subBatch->vocab();
|
||||
@ -144,7 +146,7 @@ public:
|
||||
int dimBatch = subBatch->batchSize();
|
||||
int dimWords = subBatch->batchWidth();
|
||||
|
||||
int maxSentPos = 1; // Currently only two sentences allowed A at [0] and B at [1] and padding at [2]
|
||||
int maxSentPos = dimTypeVocab; // Currently only two sentences allowed A at [0] and B at [1] and padding at [2]
|
||||
// If another separator is seen do not increase position index beyond 2 but use padding.
|
||||
// @TODO: make this configurable, see below for NextSentencePredictions task where we also restrict to 2.
|
||||
|
||||
@ -180,6 +182,7 @@ public:
|
||||
|
||||
std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
|
||||
std::string modelType = opt<std::string>("type");
|
||||
int dimTypeVocab = opt<int>("bert-type-vocab-size");
|
||||
|
||||
// intercept batch and annotate with BERT-specific concepts
|
||||
Ptr<data::BertBatch> bertBatch;
|
||||
@ -189,11 +192,13 @@ public:
|
||||
opt<float>("bert-masking-fraction", 0.15f), // 15% by default according to paper
|
||||
opt<std::string>("bert-mask-symbol"),
|
||||
opt<std::string>("bert-sep-symbol"),
|
||||
opt<std::string>("bert-class-symbol"));
|
||||
opt<std::string>("bert-class-symbol"),
|
||||
dimTypeVocab);
|
||||
} else if(modelType == "bert-classifier") { // we are probably fine-tuning a BERT model for a classification task
|
||||
bertBatch = New<data::BertBatch>(batch,
|
||||
opt<std::string>("bert-sep-symbol"),
|
||||
opt<std::string>("bert-class-symbol")); // only annotate sentence separators
|
||||
opt<std::string>("bert-class-symbol"),
|
||||
dimTypeVocab); // only annotate sentence separators
|
||||
} else {
|
||||
ABORT("Unknown BERT-style model: {}", modelType);
|
||||
}
|
||||
@ -219,7 +224,6 @@ public:
|
||||
Expr addSentenceEmbeddings(Expr embeddings,
|
||||
Ptr<data::CorpusBatch> batch,
|
||||
bool learnedPosEmbeddings) const {
|
||||
|
||||
Ptr<data::BertBatch> bertBatch = std::dynamic_pointer_cast<data::BertBatch>(batch);
|
||||
ABORT_IF(!bertBatch, "Batch must be BertBatch for BERT training or fine-tuning");
|
||||
|
||||
@ -227,11 +231,13 @@ public:
|
||||
int dimBatch = embeddings->shape()[-2];
|
||||
int dimWords = embeddings->shape()[-3];
|
||||
|
||||
int dimTypeVocab = opt<int>("bert-type-vocab-size", 2);
|
||||
|
||||
Expr signal;
|
||||
if(learnedPosEmbeddings) {
|
||||
auto sentenceEmbeddings = embedding()
|
||||
("prefix", "Wsent")
|
||||
("dimVocab", 2) // sentence A or sentence B plus padding, @TODO: should rather be a parameter
|
||||
("prefix", "Wtype")
|
||||
("dimVocab", dimTypeVocab) // sentence A or sentence B
|
||||
("dimEmb", dimEmb)
|
||||
.construct(graph_);
|
||||
signal = sentenceEmbeddings->apply(bertBatch->bertSentenceIndices(), {dimWords, dimBatch, dimEmb});
|
||||
@ -247,9 +253,10 @@ public:
|
||||
}
|
||||
|
||||
virtual Expr addSpecialEmbeddings(Expr input, int start = 0, Ptr<data::CorpusBatch> batch = nullptr) const override {
|
||||
bool trainPosEmbeddings = opt<bool>("transformer-train-positions", true);
|
||||
bool trainPosEmbeddings = opt<bool>("transformer-train-position-embeddings", true);
|
||||
bool trainTypeEmbeddings = opt<bool>("bert-train-type-embeddings", true);
|
||||
input = addPositionalEmbeddings(input, start, trainPosEmbeddings);
|
||||
input = addSentenceEmbeddings(input, batch, trainPosEmbeddings); // @TODO: separately set learnable pos and sent embeddings
|
||||
input = addSentenceEmbeddings(input, batch, trainTypeEmbeddings);
|
||||
return input;
|
||||
}
|
||||
};
|
||||
|
@ -92,6 +92,7 @@ public:
|
||||
: options_(options),
|
||||
prefix_(options->get<std::string>("prefix", "")),
|
||||
inference_(options->get<bool>("inference", false)) {
|
||||
|
||||
modelFeatures_ = {"type",
|
||||
"dim-vocabs",
|
||||
"dim-emb",
|
||||
@ -128,7 +129,9 @@ public:
|
||||
modelFeatures_.insert("transformer-decoder-autoreg");
|
||||
modelFeatures_.insert("transformer-tied-layers");
|
||||
modelFeatures_.insert("transformer-guided-alignment-layer");
|
||||
modelFeatures_.insert("transformer-train-positions");
|
||||
modelFeatures_.insert("transformer-train-position-embeddings");
|
||||
modelFeatures_.insert("bert-train-type-embeddings");
|
||||
modelFeatures_.insert("bert-type-vocab-size");
|
||||
}
|
||||
|
||||
virtual Ptr<Options> getOptions() override { return options_; }
|
||||
|
@ -8,26 +8,31 @@ EncoderDecoder::EncoderDecoder(Ptr<Options> options)
|
||||
: options_(options),
|
||||
prefix_(options->get<std::string>("prefix", "")),
|
||||
inference_(options->get<bool>("inference", false)) {
|
||||
modelFeatures_ = {"type",
|
||||
"dim-vocabs",
|
||||
"dim-emb",
|
||||
"dim-rnn",
|
||||
"enc-cell",
|
||||
"enc-type",
|
||||
"enc-cell-depth",
|
||||
"enc-depth",
|
||||
"dec-depth",
|
||||
"dec-cell",
|
||||
"dec-cell-base-depth",
|
||||
"dec-cell-high-depth",
|
||||
"skip",
|
||||
"layer-normalization",
|
||||
"right-left",
|
||||
"input-types",
|
||||
"special-vocab",
|
||||
"tied-embeddings",
|
||||
"tied-embeddings-src",
|
||||
"tied-embeddings-all"};
|
||||
|
||||
std::vector<std::string> encoderDecoderModelFeatures =
|
||||
{"type",
|
||||
"dim-vocabs",
|
||||
"dim-emb",
|
||||
"dim-rnn",
|
||||
"enc-cell",
|
||||
"enc-type",
|
||||
"enc-cell-depth",
|
||||
"enc-depth",
|
||||
"dec-depth",
|
||||
"dec-cell",
|
||||
"dec-cell-base-depth",
|
||||
"dec-cell-high-depth",
|
||||
"skip",
|
||||
"layer-normalization",
|
||||
"right-left",
|
||||
"input-types",
|
||||
"special-vocab",
|
||||
"tied-embeddings",
|
||||
"tied-embeddings-src",
|
||||
"tied-embeddings-all"};
|
||||
|
||||
for(auto feature : encoderDecoderModelFeatures)
|
||||
modelFeatures_.insert(feature);
|
||||
|
||||
modelFeatures_.insert("transformer-heads");
|
||||
modelFeatures_.insert("transformer-no-projection");
|
||||
@ -44,7 +49,9 @@ EncoderDecoder::EncoderDecoder(Ptr<Options> options)
|
||||
modelFeatures_.insert("transformer-decoder-autoreg");
|
||||
modelFeatures_.insert("transformer-tied-layers");
|
||||
modelFeatures_.insert("transformer-guided-alignment-layer");
|
||||
modelFeatures_.insert("transformer-train-positions");
|
||||
modelFeatures_.insert("transformer-train-position-embeddings");
|
||||
modelFeatures_.insert("bert-train-type-embeddings");
|
||||
modelFeatures_.insert("bert-type-vocab-size");
|
||||
}
|
||||
|
||||
std::vector<Ptr<EncoderBase>>& EncoderDecoder::getEncoders() {
|
||||
|
@ -359,7 +359,8 @@ protected:
|
||||
options_->get<float>("bert-masking-fraction"),
|
||||
options_->get<std::string>("bert-mask-symbol"),
|
||||
options_->get<std::string>("bert-sep-symbol"),
|
||||
options_->get<std::string>("bert-class-symbol"));
|
||||
options_->get<std::string>("bert-class-symbol"),
|
||||
options_->get<int>("bert-type-vocab-size"));
|
||||
|
||||
builder->clear(graph);
|
||||
auto classifierStates = std::dynamic_pointer_cast<BertEncoderClassifier>(builder)->apply(graph, bertBatch, true);
|
||||
|
241
src/translator/classification.h
Normal file
241
src/translator/classification.h
Normal file
@ -0,0 +1,241 @@
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
|
||||
#include "marian.h"
|
||||
#include "translator/history.h"
|
||||
#include "translator/scorers.h"
|
||||
|
||||
#include "translator/helpers.h"
|
||||
#include "translator/nth_element.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
class Classification {
|
||||
private:
|
||||
Ptr<Options> options_;
|
||||
std::vector<Ptr<Scorer>> scorers_;
|
||||
size_t topN_{1};
|
||||
|
||||
public:
|
||||
Classification(Ptr<Options> options)
|
||||
: options_(options),
|
||||
scorers_(scorers),
|
||||
topN_{options_->get<size_t>("beam-size")} // misuse beam-size for topN display
|
||||
{}
|
||||
|
||||
Beams toHyps(const std::vector<unsigned int> keys,
|
||||
const std::vector<float> pathScores,
|
||||
size_t labelNum,
|
||||
const Beams& beams,
|
||||
std::vector<Ptr<ScorerState>>& states,
|
||||
size_t topN,
|
||||
bool first,
|
||||
Ptr<data::CorpusBatch> batch) {
|
||||
Beams newBeams(beams.size());
|
||||
|
||||
std::vector<float> align;
|
||||
if(options_->has("alignment"))
|
||||
// Use alignments from the first scorer, even if ensemble
|
||||
align = scorers_[0]->getAlignment();
|
||||
|
||||
for(size_t i = 0; i < keys.size(); ++i) {
|
||||
// Keys contains indices to vocab items in the entire beam.
|
||||
// Values can be between 0 and topN * number of lables.
|
||||
Word embIdx = (Word)(keys[i] % labelNum);
|
||||
auto beamIdx = i / topN;
|
||||
|
||||
// Retrieve short list for final softmax (based on words aligned
|
||||
// to source sentences). If short list has been set, map the indices
|
||||
// in the sub-selected vocabulary matrix back to their original positions.
|
||||
auto shortlist = scorers_[0]->getShortlist();
|
||||
if(shortlist)
|
||||
embIdx = shortlist->reverseMap(embIdx); // @TODO: should reverseMap accept a size_t or a Word?
|
||||
|
||||
if(newBeams[beamIdx].size() < beams[beamIdx].size()) {
|
||||
auto& beam = beams[beamIdx];
|
||||
auto& newBeam = newBeams[beamIdx];
|
||||
|
||||
auto hypIdx = (IndexType)(keys[i] / labelNum);
|
||||
float pathScore = pathScores[i];
|
||||
|
||||
auto hypIdxTrans
|
||||
= IndexType((hypIdx / topN) + (hypIdx % topN) * beams.size());
|
||||
if(first)
|
||||
hypIdxTrans = hypIdx;
|
||||
|
||||
size_t beamHypIdx = hypIdx % topN;
|
||||
if(beamHypIdx >= (int)beam.size())
|
||||
beamHypIdx = beamHypIdx % beam.size();
|
||||
|
||||
if(first)
|
||||
beamHypIdx = 0;
|
||||
|
||||
auto hyp = New<Hypothesis>(beam[beamHypIdx], embIdx, hypIdxTrans, pathScore);
|
||||
|
||||
// Set score breakdown for n-best lists
|
||||
if(options_->get<bool>("n-best")) {
|
||||
std::vector<float> breakDown(states.size(), 0);
|
||||
beam[beamHypIdx]->GetScoreBreakdown().resize(states.size(), 0);
|
||||
for(size_t j = 0; j < states.size(); ++j) {
|
||||
size_t key = embIdx + hypIdxTrans * labelNum;
|
||||
breakDown[j] = states[j]->breakDown(key)
|
||||
+ beam[beamHypIdx]->GetScoreBreakdown()[j];
|
||||
}
|
||||
hyp->GetScoreBreakdown() = breakDown;
|
||||
}
|
||||
|
||||
newBeam.push_back(hyp);
|
||||
}
|
||||
}
|
||||
return newBeams;
|
||||
}
|
||||
|
||||
// main decoding function
|
||||
Histories search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
|
||||
int dimBatch = (int)batch->size();
|
||||
|
||||
Histories histories;
|
||||
for(int i = 0; i < dimBatch; ++i) {
|
||||
size_t sentId = batch->getSentenceIds()[i];
|
||||
auto history = New<History>(sentId,
|
||||
options_->get<float>("normalize"),
|
||||
options_->get<float>("word-penalty"));
|
||||
histories.push_back(history);
|
||||
}
|
||||
|
||||
auto getNBestList = createGetNBestListFn(topN_, dimBatch, graph->getDeviceId());
|
||||
|
||||
Beams beams(dimBatch); // [batchIndex][beamIndex] is one sentence hypothesis
|
||||
for(auto& beam : beams)
|
||||
beam.resize(topN_, New<Hypothesis>());
|
||||
|
||||
bool first = true;
|
||||
bool final = false;
|
||||
|
||||
for(int i = 0; i < dimBatch; ++i)
|
||||
histories[i]->Add(beams[i], trgEosId_);
|
||||
|
||||
std::vector<Ptr<ScorerState>> states;
|
||||
|
||||
for(auto scorer : scorers_) {
|
||||
scorer->clear(graph);
|
||||
}
|
||||
|
||||
for(auto scorer : scorers_) {
|
||||
states.push_back(scorer->apply(graph, batch));
|
||||
}
|
||||
|
||||
// main loop over output tokens
|
||||
do {
|
||||
//**********************************************************************
|
||||
// create constant containing previous path scores for current beam
|
||||
// also create mapping of hyp indices, which are not 1:1 if sentences complete
|
||||
std::vector<IndexType> hypIndices; // [beamIndex * activeBatchSize + batchIndex] backpointers, concatenated over beam positions. Used for reordering hypotheses
|
||||
std::vector<IndexType> embIndices;
|
||||
Expr prevPathScores; // [beam, 1, 1, 1]
|
||||
if(first) {
|
||||
// no scores yet
|
||||
prevPathScores = graph->constant({1, 1, 1, 1}, inits::from_value(0));
|
||||
} else {
|
||||
std::vector<float> beamScores;
|
||||
|
||||
dimBatch = (int)batch->size();
|
||||
|
||||
for(size_t i = 0; i < localBeamSize; ++i) {
|
||||
for(size_t j = 0; j < beams.size(); ++j) { // loop over batch entries (active sentences)
|
||||
auto& beam = beams[j];
|
||||
if(i < beam.size()) {
|
||||
auto hyp = beam[i];
|
||||
hypIndices.push_back((IndexType)hyp->GetPrevStateIndex()); // backpointer
|
||||
embIndices.push_back(hyp->GetWord());
|
||||
beamScores.push_back(hyp->GetPathScore());
|
||||
} else { // dummy hypothesis
|
||||
hypIndices.push_back(0);
|
||||
embIndices.push_back(0); // (unused)
|
||||
beamScores.push_back(-9999);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
prevPathScores = graph->constant({(int)localBeamSize, 1, dimBatch, 1},
|
||||
inits::from_vector(beamScores));
|
||||
}
|
||||
|
||||
//**********************************************************************
|
||||
// prepare scores for beam search
|
||||
auto pathScores = prevPathScores;
|
||||
|
||||
for(size_t i = 0; i < scorers_.size(); ++i) {
|
||||
states[i] = scorers_[i]->step(
|
||||
graph, states[i], hypIndices, embIndices, dimBatch, (int)localBeamSize);
|
||||
|
||||
if(scorers_[i]->getWeight() != 1.f)
|
||||
pathScores = pathScores + scorers_[i]->getWeight() * states[i]->getLogProbs();
|
||||
else
|
||||
pathScores = pathScores + states[i]->getLogProbs();
|
||||
}
|
||||
|
||||
// make beams continuous
|
||||
if(dimBatch > 1 && localBeamSize > 1)
|
||||
pathScores = transpose(pathScores, {2, 1, 0, 3}); // check if this is needed for classification, rather not, beamSize and topN is badly defined here
|
||||
|
||||
if(first)
|
||||
graph->forward();
|
||||
else
|
||||
graph->forwardNext();
|
||||
|
||||
//**********************************************************************
|
||||
// suppress specific symbols if not at right positions
|
||||
if(trgUnkId_ != -1 && options_->has("allow-unk")
|
||||
&& !options_->get<bool>("allow-unk"))
|
||||
suppressWord(pathScores, trgUnkId_);
|
||||
for(auto state : states)
|
||||
state->blacklist(pathScores, batch);
|
||||
|
||||
//**********************************************************************
|
||||
// perform beam search and pruning
|
||||
std::vector<unsigned int> outKeys;
|
||||
std::vector<float> outPathScores;
|
||||
|
||||
std::vector<size_t> beamSizes(dimBatch, localBeamSize);
|
||||
getNBestList(beamSizes, pathScores->val(), outPathScores, outKeys, first);
|
||||
|
||||
int dimTrgVoc = pathScores->shape()[-1];
|
||||
beams = toHyps(outKeys,
|
||||
outPathScores,
|
||||
dimTrgVoc,
|
||||
beams,
|
||||
states,
|
||||
localBeamSize,
|
||||
first,
|
||||
batch);
|
||||
|
||||
auto prunedBeams = pruneBeam(beams);
|
||||
for(int i = 0; i < dimBatch; ++i) {
|
||||
if(!beams[i].empty()) {
|
||||
final = final
|
||||
|| histories[i]->size()
|
||||
>= options_->get<float>("max-length-factor")
|
||||
* batch->front()->batchWidth();
|
||||
histories[i]->Add(
|
||||
beams[i], trgEosId_, prunedBeams[i].empty() || final);
|
||||
}
|
||||
}
|
||||
beams = prunedBeams;
|
||||
|
||||
// determine beam size for next sentence, as max over still-active sentences
|
||||
if(!first) {
|
||||
size_t maxBeam = 0;
|
||||
for(auto& beam : beams)
|
||||
if(beam.size() > maxBeam)
|
||||
maxBeam = beam.size();
|
||||
localBeamSize = maxBeam;
|
||||
}
|
||||
first = false;
|
||||
|
||||
} while(localBeamSize != 0 && !final); // end of main loop over output tokens
|
||||
|
||||
return histories;
|
||||
}
|
||||
};
|
||||
} // namespace marian
|
Loading…
Reference in New Issue
Block a user