Merged PR 15925: Training embedder separation with margin losses

* This PR adds training of embedding spaces with better separation, based on https://arxiv.org/abs/2007.01852.
* We can now train with in-batch negative examples or with a handful of hand-constructed negative examples provided in a TSV file (see the loss sketch below).
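
For reference, a sketch of the softmax-margin formulation the code below implements; the notation is ours, not taken from the PR: a is the anchor, p its positive example, n_i the negative examples, \phi a dot-product similarity, and m the margin:

\mathcal{L}(a, p) = -\log \frac{e^{\phi(a,p) - m}}{e^{\phi(a,p) - m} + \sum_i e^{\phi(a, n_i)}}

The implementation adds a symmetric second term whose denominator uses the similarities \phi(p, n_i) of the positive example to the negatives, sums both terms over the batch, and optionally replaces the sum over negatives by a scaled mean (the second parameter of --train-embedder-rank).
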
Martin Junczys-Dowmunt 2020-11-07 17:46:39 +00:00
parent 3d233ec592
commit d5e773f937
9 changed files with 179 additions and 26 deletions


@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
- Add --train-embedder-rank for fine-tuning any encoder(-decoder) model for multi-lingual similarity via softmax-margin loss
- Add --logical-epoch that allows redefining the displayed epoch counter as a multiple of n data epochs, updates or labels. A second argument can define the width of the fractional part.
- Add --metrics chrf for computing ChrF according to https://www.aclweb.org/anthology/W15-3049/ and SacreBLEU reference implementation
- Add --after option which is meant to replace --after-batches and --after-epochs and can take label based criteria

@ -1 +1 @@
Subproject commit 75977846abfccd29941e4bfd3c615a111599f7f4
Subproject commit cdad78089484d7817d91c803d6fc7049328e20db


@ -458,7 +458,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"epoch+stalled");
cli.add<std::vector<size_t>>("--lr-decay-start",
"The first number of (epoch, batches, stalled) validations to start learning rate decaying (tuple)",
{10,1});
{10, 1});
cli.add<size_t>("--lr-decay-freq",
"Learning rate decaying frequency for batches, requires --lr-decay-strategy to be batches",
50000);
@ -534,6 +534,11 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
cli.add<bool>("--normalize-gradient",
"Normalize gradient by multiplying with no. devices / total labels");
cli.add<std::vector<std::string>>("--train-embedder-rank",
"Override model configuration and train a embedding similarity ranker with the model encoder, "
"parameters encode margin and an optional normalization factor")
->implicit_val("0.3f 0.0f");
// multi-node training
cli.add<bool>("--multi-node",
"Enable asynchronous multi-node training through MPI (and legacy sync if combined with --sync-sgd)");


@ -37,7 +37,7 @@ public:
Expr build(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
auto embedder = std::dynamic_pointer_cast<EncoderPooler>(model_);
ABORT_IF(!embedder, "Could not cast to EncoderPooler");
return embedder->apply(graph, batch, /*clearGraph=*/true);
return embedder->apply(graph, batch, /*clearGraph=*/true)[0];
}
};
@ -56,7 +56,8 @@ public:
Embed(Ptr<Options> options) : options_(options) {
options_ = options_->with("inference", true,
"shuffle", "none");
"shuffle", "none",
"input-types", std::vector<std::string>({"sequence"}));
// if a similarity is computed then double the input types and vocabs for
// the two encoders that are used in the model.
@ -68,7 +69,8 @@ public:
vDimVocabs.push_back(vDimVocabs.back());
options_ = options_->with("vocabs", vVocabs,
"dim-vocabs", vDimVocabs);
"dim-vocabs", vDimVocabs,
"input-types", std::vector<std::string>(vVocabs.size(), "sequence"));
}
corpus_ = New<Corpus>(options_);


@ -27,7 +27,7 @@ public:
Expr build(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
auto embedder = std::dynamic_pointer_cast<EncoderPooler>(model_);
ABORT_IF(!embedder, "Could not cast to EncoderPooler");
return embedder->apply(graph, batch, /*clearGraph=*/true);
return embedder->apply(graph, batch, /*clearGraph=*/true)[0];
}
};


@ -6,6 +6,7 @@
#include "layers/weight.h"
#include "models/encoder_decoder.h"
#include "models/encoder_classifier.h"
#include "models/encoder_pooler.h"
namespace marian {
namespace models {
@ -130,6 +131,68 @@ public:
}
};
// Produces a margin-based ranking cost from the dot products returned by an EncoderPooler. @TODO: Needs refactoring
class EncoderPoolerRankCost : public ICost {
protected:
Ptr<Options> options_;
const bool inference_{false};
float margin_{0.3f};
float normalizer_{0.0f};
public:
EncoderPoolerRankCost(Ptr<Options> options)
: options_(options),
inference_(options->get<bool>("inference", false)) {
auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {});
ABORT_IF(trainEmbedderRank.empty(), "EncoderPoolerRankCost expects train-embedder-rank to be set");
margin_ = std::stof(trainEmbedderRank[0]);
if(trainEmbedderRank.size() > 1)
normalizer_ = std::stof(trainEmbedderRank[1]);
}
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto encpool = std::static_pointer_cast<EncoderPooler>(model);
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
std::vector<Expr> dotProducts = encpool->apply(graph, corpusBatch, clearGraph);
int dimBatch = dotProducts[0]->shape()[-2];
Ptr<MultiRationalLoss> multiLoss = New<SumMultiRationalLoss>();
ABORT_IF(inference_, "Rank training does not work in inference mode");
ABORT_IF(dotProducts.size() != 3, "Three dot products required for margin loss");
// multi-objective training
auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability
auto exponent = dotProducts[0] - maxDot - margin_; // subtract maximum and margin from dot product
auto dp = exp(exponent);
Expr dn1, dn2;
if(normalizer_ != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the magnitude of the sum of negative examples in the denominator.
dn1 = normalizer_ * mean(exp(dotProducts[1] - maxDot), -1); // dot products of anchor and negative examples
dn2 = normalizer_ * mean(exp(dotProducts[2] - maxDot), -1); // dot products of positive example and negative examples
} else {
dn1 = sum(exp(dotProducts[1] - maxDot), -1); // dot products of anchor and negative examples
dn2 = sum(exp(dotProducts[2] - maxDot), -1); // dot products of positive example and negative examples
}
// We rewrite the loss so it looks more like a log-softmax, presumably more stable?
// Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m)
auto marginLoss1 = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples
auto marginLoss2 = log(dp + dn2) - exponent; // symmetric version of the above with positive example vs negative examples
auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2);
RationalLoss loss(marginLoss, (float)dimBatch);
multiLoss->push_back(loss);
return multiLoss;
}
};
class Trainer : public ICriterionFunction {
protected:
Ptr<IModel> model_;
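
As a reading aid, here is a minimal standalone sketch of the stabilized softmax-margin computation from EncoderPoolerRankCost on plain floats; the function name and signature are hypothetical and not part of the PR:

#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical helper mirroring the loss above: anchorPos is the dot product of the anchor
// and its positive example, anchorNegs holds the dot products of the anchor with the
// negative examples, margin is the softmax margin.
float softmaxMarginLoss(float anchorPos, const std::vector<float>& anchorNegs, float margin) {
  float maxDot = anchorPos;                     // maximum over all dot products for numeric stability
  for(float d : anchorNegs)
    maxDot = std::max(maxDot, d);
  float exponent = anchorPos - maxDot - margin; // subtract maximum and margin
  float dp = std::exp(exponent);
  float dn = 0.f;                               // sum over negative examples
  for(float d : anchorNegs)
    dn += std::exp(d - maxDot);
  // -log(dp / (dp + dn)) rewritten as log(dp + dn) - (phi - max - m), as in the comment above
  return std::log(dp + dn) - exponent;
}

Because the maximum cancels out mathematically, calling the same function with the dot products between the positive example and the negatives reproduces the symmetric second term of the loss above.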


@ -42,7 +42,7 @@ public:
virtual void clear(Ptr<ExpressionGraph> graph) override = 0;
virtual Expr apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, bool) = 0;
virtual std::vector<Expr> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, bool) = 0;
virtual Logits build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
@ -204,7 +204,7 @@ public:
/*********************************************************************/
virtual Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
virtual std::vector<Expr> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
if(clearGraph)
clear(graph);


@ -128,15 +128,24 @@ Ptr<IModel> createBaseModelByType(std::string type, usage use, Ptr<Options> opti
Ptr<ExpressionGraph> graph = nullptr; // graph unknown at this stage
// clang-format off
if(use == usage::embedding) { // hijacking an EncoderDecoder model for embedding only
int dimVocab = options->get<std::vector<int>>("dim-vocabs")[0];
bool trainEmbedderRank = options->hasAndNotEmpty("train-embedder-rank");
if(use == usage::embedding || trainEmbedderRank) { // hijacking an EncoderDecoder model for embedding only
auto dimVocabs = options->get<std::vector<int>>("dim-vocabs");
size_t fields = trainEmbedderRank ? dimVocabs.size() : 0;
int dimVocab = dimVocabs[0];
Ptr<Options> newOptions;
if(options->get<bool>("compute-similarity")) {
if(options->get<bool>("compute-similarity", false)) {
newOptions = options->with("usage", use,
"original-type", type,
"input-types", std::vector<std::string>({"sequence", "sequence"}),
"dim-vocabs", std::vector<int>(2, dimVocab));
} else if(trainEmbedderRank) {
newOptions = options->with("usage", use,
"original-type", type,
"input-types", std::vector<std::string>(fields, "sequence"),
"dim-vocabs", std::vector<int>(fields, dimVocab));
} else {
newOptions = options->with("usage", use,
"original-type", type,
@ -145,10 +154,15 @@ Ptr<IModel> createBaseModelByType(std::string type, usage use, Ptr<Options> opti
}
auto res = New<EncoderPooler>(newOptions);
if(options->get<bool>("compute-similarity")) {
if(options->get<bool>("compute-similarity", false)) {
res->push_back(models::encoder(newOptions->with("index", 0)).construct(graph));
res->push_back(models::encoder(newOptions->with("index", 1)).construct(graph));
res->push_back(New<SimPooler>(graph, newOptions->with("type", "sim-pooler")));
} else if(trainEmbedderRank) {
LOG(info, "Using {} input fields for embedder ranking training", fields);
for(int i = 0; i < fields; ++i)
res->push_back(models::encoder(newOptions->with("index", i)).construct(graph));
res->push_back(New<SimPooler>(graph, newOptions->with("type", "sim-pooler")));
} else {
res->push_back(models::encoder(newOptions->with("index", 0)).construct(graph));
if(type == "laser")
@ -400,6 +414,8 @@ Ptr<ICriterionFunction> createCriterionFunctionFromOptions(Ptr<Options> options,
return New<Trainer>(baseModel, New<MNISTCrossEntropyCost>());
#endif
#endif
else if (std::dynamic_pointer_cast<EncoderPooler>(baseModel))
return New<Trainer>(baseModel, New<EncoderPoolerRankCost>(options));
else
ABORT("Criterion function unknown for model type: {}", type);
}


@ -29,7 +29,7 @@ public:
virtual ~PoolerBase() {}
virtual Expr apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, const std::vector<Ptr<EncoderState>>&) = 0;
virtual std::vector<Expr> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, const std::vector<Ptr<EncoderState>>&) = 0;
template <typename T>
T opt(const std::string& key) const {
@ -48,7 +48,7 @@ public:
MaxPooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: PoolerBase(graph, options) {}
Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
std::vector<Expr> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
ABORT_IF(encoderStates.size() != 1, "Pooler expects exactly one encoder state");
auto context = encoderStates[0]->getContext();
@ -58,7 +58,7 @@ public:
Expr logMask = (1.f - batchMask) * -9999.f;
Expr maxPool = max(context * batchMask + logMask, /*axis=*/-3);
return maxPool;
return {maxPool};
}
void clear() override {}
@ -73,7 +73,7 @@ public:
SlicePooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: PoolerBase(graph, options) {}
Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
std::vector<Expr> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
ABORT_IF(encoderStates.size() != 1, "Pooler expects exactly one encoder state");
auto context = encoderStates[0]->getContext();
@ -83,7 +83,7 @@ public:
// @TODO: unify this better, this is currently hacky
Expr slicePool = slice(context * batchMask, /*axis=*/-3, 0);
return slicePool;
return {slicePool};
}
void clear() override {}
@ -98,8 +98,8 @@ public:
SimPooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: PoolerBase(graph, options) {}
Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
ABORT_IF(encoderStates.size() != 2, "SimPooler expects exactly two encoder states");
std::vector<Expr> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
ABORT_IF(encoderStates.size() < 2, "SimPooler expects at least two encoder states not {}", encoderStates.size());
std::vector<Expr> vecs;
for(auto encoderState : encoderStates) {
@ -108,7 +108,7 @@ public:
Expr pool;
auto type = options_->get<std::string>("original-type");
if(type == "laser") {
if(type == "laser" || type == "laser-sim") {
// LASER models do a max pool here
Expr logMask = (1.f - batchMask) * -9999.f;
pool = max(context * batchMask + logMask, /*axis=*/-3);
@ -119,17 +119,83 @@ public:
// @TODO: make SimPooler take Pooler objects as arguments then it won't need to know this.
ABORT("Don't know what type of pooler to use for model type {}", type);
}
vecs.push_back(pool);
}
auto scalars = scalar_product(vecs[0], vecs[1], /*axis*/-1);
auto length1 = sqrt(sum(square(vecs[0]), /*axis=*/-1));
auto length2 = sqrt(sum(square(vecs[1]), /*axis=*/-1));
std::vector<Expr> outputs;
bool trainRank = options_->hasAndNotEmpty("train-embedder-rank");
auto cosine = scalars / ( length1 * length2 );
if(!trainRank) { // inference, compute one cosine similarity only
ABORT_IF(vecs.size() != 2, "We are expecting two inputs for similarity computation");
return cosine;
// efficiently compute vector length with bdot
auto vnorm = [](Expr e) {
int dimModel = e->shape()[-1];
int dimBatch = e->shape()[-2];
e = reshape(e, {dimBatch, 1, dimModel});
return reshape(sqrt(bdot(e, e, false, true)), {dimBatch, 1});
};
auto dotProduct = scalar_product(vecs[0], vecs[1], /*axis*/-1);
auto length0 = vnorm(vecs[0]); // will be hashed and reused in the graph
auto length1 = vnorm(vecs[1]);
auto cosine = dotProduct / ( length0 * length1 );
cosine = maximum(0, cosine); // clip to [0, 1] - should we actually do that?
outputs.push_back(cosine);
} else { // compute outputs for embedding similarity ranking
if(vecs.size() == 2) { // implies we are sampling negative examples from the batch, since otherwise there is nothing to train
LOG_ONCE(info, "Sampling negative examples from batch");
auto src = vecs[0];
auto trg = vecs[1];
int dimModel = src->shape()[-1];
int dimBatch = src->shape()[-2];
src = reshape(src, {dimBatch, dimModel});
trg = reshape(trg, {dimBatch, dimModel});
// compute dot products between all batch entries; this produces the full dimBatch x dimBatch matrix
auto dotProduct = dot(src, trg, false, true); // [dimBatch, dimBatch] - computes dot product matrix
auto positiveMask = dotProduct->graph()->constant({dimBatch, dimBatch}, inits::eye()); // a mask for the diagonal (positive examples are on the diagonal)
auto negativeMask = 1.f - positiveMask; // mask for all negative examples;
auto positive = sum(dotProduct * positiveMask, /*axis=*/-1); // we sum across the last dim to get a column vector of positive scores (everything else is zero)
outputs.push_back(positive);
auto negative1 = dotProduct * negativeMask; // get negative examples for src -> trg (in a row)
outputs.push_back(negative1);
auto negative2 = transpose(negative1); // get negative examples for trg -> src via transpose so they are located in a row
outputs.push_back(negative2);
} else {
LOG_ONCE(info, "Using provided {} negative examples", vecs.size() - 2);
// For inference and training with given set of negative examples provided in additional streams.
// Assuming that enc0 is query, enc1 is positive example and remaining encoders are optional negative examples. Here we only use column vectors [dimBatch, 1]
auto positive = scalar_product(vecs[0], vecs[1], /*axis*/-1);
outputs.push_back(positive); // first tensor contains similarity between anchor and positive example
std::vector<Expr> dotProductsNegative1, dotProductsNegative2;
for(int i = 2; i < vecs.size(); ++i) {
// compute similarity with anchor
auto negative1 = scalar_product(vecs[0], vecs[i], /*axis*/-1);
dotProductsNegative1.push_back(negative1);
// for negative examples also add symmetric dot product with the positive example
auto negative2 = scalar_product(vecs[1], vecs[i], /*axis*/-1);
dotProductsNegative2.push_back(negative2);
}
auto negative1 = concatenate(dotProductsNegative1, /*axis=*/-1);
outputs.push_back(negative1); // second tensor contains similarities between anchor and all negative examples
auto negative2 = concatenate(dotProductsNegative2, /*axis=*/-1);
outputs.push_back(negative2); // third tensor contains similarities between positive and all negative examples (symmetric)
}
}
return outputs;
}
void clear() override {}
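
To illustrate the in-batch sampling branch above, a small standalone sketch on a plain dot-product matrix; the helper struct and names are hypothetical and not part of the PR. Entry [i][j] holds the dot product of source embedding i and target embedding j, so the diagonal contains the positive pairs and all off-diagonal entries serve as negative examples:

#include <cstddef>
#include <vector>

// Hypothetical container mirroring the three outputs of SimPooler in rank-training mode.
struct RankingScores {
  std::vector<float> positive;                 // [dimBatch]           anchor . positive (diagonal)
  std::vector<std::vector<float>> negative1;   // [dimBatch, dimBatch] anchor i vs. targets j != i
  std::vector<std::vector<float>> negative2;   // [dimBatch, dimBatch] positive i vs. sources j != i (transpose of negative1)
};

RankingScores inBatchNegatives(const std::vector<std::vector<float>>& dotProduct) {
  std::size_t dimBatch = dotProduct.size();
  RankingScores out;
  out.positive.resize(dimBatch);
  out.negative1.assign(dimBatch, std::vector<float>(dimBatch, 0.f));
  out.negative2.assign(dimBatch, std::vector<float>(dimBatch, 0.f));
  for(std::size_t i = 0; i < dimBatch; ++i) {
    for(std::size_t j = 0; j < dimBatch; ++j) {
      if(i == j)
        out.positive[i] = dotProduct[i][j];     // diagonal entry: positive pair
      else {
        out.negative1[i][j] = dotProduct[i][j]; // row i collects negatives for anchor i
        out.negative2[j][i] = dotProduct[i][j]; // transpose: row j collects negatives for positive example j
      }
    }
  }
  return out;
}

In the graph code above the same selection is done without explicit loops, using an identity-matrix mask (inits::eye()) for the diagonal and a transpose for the symmetric set of negatives.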