mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 15925: Training embedder separation with margin losses
* This PR adds training of embedding spaces with better separation based on https://arxiv.org/abs/2007.01852 * We can now train with in-batch negative examples or a handful of hand-constructed negative examples provided in a tsv-file.
This commit is contained in:
parent
3d233ec592
commit
d5e773f937
@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Add --train-embedder-rank for fine-tuning any encoder(-decoder) model for multi-lingual similarity via softmax-margin loss
|
||||
- Add --logical-epoch, which allows redefining the displayed epoch counter as a multiple of n data epochs, updates or labels. A second argument optionally sets the width of the fractional part.
|
||||
- Add --metrics chrf for computing ChrF according to https://www.aclweb.org/anthology/W15-3049/ and SacreBLEU reference implementation
|
||||
- Add --after option which is meant to replace --after-batches and --after-epochs and can take label based criteria
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit 75977846abfccd29941e4bfd3c615a111599f7f4
|
||||
Subproject commit cdad78089484d7817d91c803d6fc7049328e20db
|
@ -458,7 +458,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
||||
"epoch+stalled");
|
||||
cli.add<std::vector<size_t>>("--lr-decay-start",
|
||||
"The first number of (epoch, batches, stalled) validations to start learning rate decaying (tuple)",
|
||||
{10,1});
|
||||
{10, 1});
|
||||
cli.add<size_t>("--lr-decay-freq",
|
||||
"Learning rate decaying frequency for batches, requires --lr-decay-strategy to be batches",
|
||||
50000);
|
||||
@ -534,6 +534,11 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
||||
cli.add<bool>("--normalize-gradient",
|
||||
"Normalize gradient by multiplying with no. devices / total labels");
|
||||
|
||||
// Optional list option: when given without values, the implicit value "0.3f 0.0f" is used
// (margin, followed by the optional normalization factor).
cli.add<std::vector<std::string>>("--train-embedder-rank",
    "Override model configuration and train an embedding similarity ranker with the model encoder, "
    "parameters encode margin and an optional normalization factor")
  ->implicit_val("0.3f 0.0f");
|
||||
|
||||
// multi-node training
|
||||
cli.add<bool>("--multi-node",
|
||||
"Enable asynchronous multi-node training through MPI (and legacy sync if combined with --sync-sgd)");
|
||||
|
@ -37,7 +37,7 @@ public:
|
||||
Expr build(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
|
||||
auto embedder = std::dynamic_pointer_cast<EncoderPooler>(model_);
|
||||
ABORT_IF(!embedder, "Could not cast to EncoderPooler");
|
||||
return embedder->apply(graph, batch, /*clearGraph=*/true);
|
||||
return embedder->apply(graph, batch, /*clearGraph=*/true)[0];
|
||||
}
|
||||
};
|
||||
|
||||
@ -56,7 +56,8 @@ public:
|
||||
Embed(Ptr<Options> options) : options_(options) {
|
||||
|
||||
options_ = options_->with("inference", true,
|
||||
"shuffle", "none");
|
||||
"shuffle", "none",
|
||||
"input-types", std::vector<std::string>({"sequence"}));
|
||||
|
||||
// if a similarity is computed then double the input types and vocabs for
|
||||
// the two encoders that are used in the model.
|
||||
@ -68,7 +69,8 @@ public:
|
||||
vDimVocabs.push_back(vDimVocabs.back());
|
||||
|
||||
options_ = options_->with("vocabs", vVocabs,
|
||||
"dim-vocabs", vDimVocabs);
|
||||
"dim-vocabs", vDimVocabs,
|
||||
"input-types", std::vector<std::string>(vVocabs.size(), "sequence"));
|
||||
}
|
||||
|
||||
corpus_ = New<Corpus>(options_);
|
||||
|
@ -27,7 +27,7 @@ public:
|
||||
Expr build(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
|
||||
auto embedder = std::dynamic_pointer_cast<EncoderPooler>(model_);
|
||||
ABORT_IF(!embedder, "Could not cast to EncoderPooler");
|
||||
return embedder->apply(graph, batch, /*clearGraph=*/true);
|
||||
return embedder->apply(graph, batch, /*clearGraph=*/true)[0];
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include "layers/weight.h"
|
||||
#include "models/encoder_decoder.h"
|
||||
#include "models/encoder_classifier.h"
|
||||
#include "models/encoder_pooler.h"
|
||||
|
||||
namespace marian {
|
||||
namespace models {
|
||||
@ -130,6 +131,68 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// Wraps an EncoderClassifier so it can produce a cost from raw logits. @TODO: Needs refactoring
|
||||
class EncoderPoolerRankCost : public ICost {
|
||||
protected:
|
||||
Ptr<Options> options_;
|
||||
const bool inference_{false};
|
||||
float margin_{0.3f};
|
||||
float normalizer_{0.0f};
|
||||
|
||||
public:
|
||||
EncoderPoolerRankCost(Ptr<Options> options)
|
||||
: options_(options),
|
||||
inference_(options->get<bool>("inference", false)) {
|
||||
auto trainEmbedderRank = options->get<std::vector<std::string>>("train-embedder-rank", {});
|
||||
ABORT_IF(trainEmbedderRank.empty(), "EncoderPoolerRankCost expects train-embedder-rank to be set");
|
||||
|
||||
margin_ = std::stof(trainEmbedderRank[0]);
|
||||
if(trainEmbedderRank.size() > 1)
|
||||
normalizer_ = std::stof(trainEmbedderRank[1]);
|
||||
}
|
||||
|
||||
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
|
||||
Ptr<ExpressionGraph> graph,
|
||||
Ptr<data::Batch> batch,
|
||||
bool clearGraph = true) override {
|
||||
|
||||
auto encpool = std::static_pointer_cast<EncoderPooler>(model);
|
||||
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
|
||||
std::vector<Expr> dotProducts = encpool->apply(graph, corpusBatch, clearGraph);
|
||||
|
||||
int dimBatch = dotProducts[0]->shape()[-2];
|
||||
Ptr<MultiRationalLoss> multiLoss = New<SumMultiRationalLoss>();
|
||||
|
||||
ABORT_IF(inference_, "Rank training does not work in inference mode");
|
||||
ABORT_IF(dotProducts.size() != 3, "Three dot products required for margin loss");
|
||||
|
||||
// multi-objective training
|
||||
auto maxDot = max(concatenate(dotProducts, -1), -1); // compute maximum for numeric stability
|
||||
auto exponent = dotProducts[0] - maxDot - margin_; // substract maximum and margin from dot product
|
||||
auto dp = exp(exponent);
|
||||
|
||||
Expr dn1, dn2;
|
||||
if(normalizer_ != 0.0f) { // the normalizer may be useful for fluctuating batch sizes since it limits the magnitude of the sum of negative examples in the denominator.
|
||||
dn1 = normalizer_ * mean(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example
|
||||
dn2 = normalizer_ * mean(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example
|
||||
} else {
|
||||
dn1 = sum(exp(dotProducts[1] - maxDot), -1); // dot product of anchor and first negative example
|
||||
dn2 = sum(exp(dotProducts[2] - maxDot), -1); // dot product of positive examples and first negative example
|
||||
}
|
||||
|
||||
// We rewrite the loss so it looks more like a log-softmax, presumably more stable?
|
||||
// Let dp = exp(phi - m) then -log(dp / (dp + sum(dn))) = -log(dp) + log(dp + sum(dn)) = log(dp + sum(dn)) - log(dp) = log(dp + sum(dn)) - (phi - m)
|
||||
auto marginLoss1 = log(dp + dn1) - exponent; // softmax-margin loss for anchor vs negative examples
|
||||
auto marginLoss2 = log(dp + dn2) - exponent; // symmetric version of the above with positive example vs negative examples
|
||||
auto marginLoss = sum(marginLoss1 + marginLoss2, /*axis=*/-2);
|
||||
|
||||
RationalLoss loss(marginLoss, (float)dimBatch);
|
||||
multiLoss->push_back(loss);
|
||||
|
||||
return multiLoss;
|
||||
}
|
||||
};
|
||||
|
||||
class Trainer : public ICriterionFunction {
|
||||
protected:
|
||||
Ptr<IModel> model_;
|
||||
|
@ -42,7 +42,7 @@ public:
|
||||
|
||||
virtual void clear(Ptr<ExpressionGraph> graph) override = 0;
|
||||
|
||||
virtual Expr apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, bool) = 0;
|
||||
virtual std::vector<Expr> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, bool) = 0;
|
||||
|
||||
virtual Logits build(Ptr<ExpressionGraph> graph,
|
||||
Ptr<data::Batch> batch,
|
||||
@ -204,7 +204,7 @@ public:
|
||||
|
||||
/*********************************************************************/
|
||||
|
||||
virtual Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
|
||||
virtual std::vector<Expr> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
|
||||
if(clearGraph)
|
||||
clear(graph);
|
||||
|
||||
|
@ -128,15 +128,24 @@ Ptr<IModel> createBaseModelByType(std::string type, usage use, Ptr<Options> opti
|
||||
Ptr<ExpressionGraph> graph = nullptr; // graph unknown at this stage
|
||||
// clang-format off
|
||||
|
||||
if(use == usage::embedding) { // hijacking an EncoderDecoder model for embedding only
|
||||
int dimVocab = options->get<std::vector<int>>("dim-vocabs")[0];
|
||||
bool trainEmbedderRank = options->hasAndNotEmpty("train-embedder-rank");
|
||||
if(use == usage::embedding || trainEmbedderRank) { // hijacking an EncoderDecoder model for embedding only
|
||||
|
||||
auto dimVocabs = options->get<std::vector<int>>("dim-vocabs");
|
||||
size_t fields = trainEmbedderRank ? dimVocabs.size() : 0;
|
||||
int dimVocab = dimVocabs[0];
|
||||
|
||||
Ptr<Options> newOptions;
|
||||
if(options->get<bool>("compute-similarity")) {
|
||||
if(options->get<bool>("compute-similarity", false)) {
|
||||
newOptions = options->with("usage", use,
|
||||
"original-type", type,
|
||||
"input-types", std::vector<std::string>({"sequence", "sequence"}),
|
||||
"dim-vocabs", std::vector<int>(2, dimVocab));
|
||||
} else if(trainEmbedderRank) {
|
||||
newOptions = options->with("usage", use,
|
||||
"original-type", type,
|
||||
"input-types", std::vector<std::string>(fields, "sequence"),
|
||||
"dim-vocabs", std::vector<int>(fields, dimVocab));
|
||||
} else {
|
||||
newOptions = options->with("usage", use,
|
||||
"original-type", type,
|
||||
@ -145,10 +154,15 @@ Ptr<IModel> createBaseModelByType(std::string type, usage use, Ptr<Options> opti
|
||||
}
|
||||
|
||||
auto res = New<EncoderPooler>(newOptions);
|
||||
if(options->get<bool>("compute-similarity")) {
|
||||
if(options->get<bool>("compute-similarity", false)) {
|
||||
res->push_back(models::encoder(newOptions->with("index", 0)).construct(graph));
|
||||
res->push_back(models::encoder(newOptions->with("index", 1)).construct(graph));
|
||||
res->push_back(New<SimPooler>(graph, newOptions->with("type", "sim-pooler")));
|
||||
} else if(trainEmbedderRank) {
|
||||
LOG(info, "Using {} input fields for embedder ranking training", fields);
|
||||
for(int i = 0; i < fields; ++i)
|
||||
res->push_back(models::encoder(newOptions->with("index", i)).construct(graph));
|
||||
res->push_back(New<SimPooler>(graph, newOptions->with("type", "sim-pooler")));
|
||||
} else {
|
||||
res->push_back(models::encoder(newOptions->with("index", 0)).construct(graph));
|
||||
if(type == "laser")
|
||||
@ -400,6 +414,8 @@ Ptr<ICriterionFunction> createCriterionFunctionFromOptions(Ptr<Options> options,
|
||||
return New<Trainer>(baseModel, New<MNISTCrossEntropyCost>());
|
||||
#endif
|
||||
#endif
|
||||
else if (std::dynamic_pointer_cast<EncoderPooler>(baseModel))
|
||||
return New<Trainer>(baseModel, New<EncoderPoolerRankCost>(options));
|
||||
else
|
||||
ABORT("Criterion function unknown for model type: {}", type);
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ public:
|
||||
|
||||
virtual ~PoolerBase() {}
|
||||
|
||||
virtual Expr apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, const std::vector<Ptr<EncoderState>>&) = 0;
|
||||
virtual std::vector<Expr> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, const std::vector<Ptr<EncoderState>>&) = 0;
|
||||
|
||||
template <typename T>
|
||||
T opt(const std::string& key) const {
|
||||
@ -48,7 +48,7 @@ public:
|
||||
MaxPooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: PoolerBase(graph, options) {}
|
||||
|
||||
Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
|
||||
std::vector<Expr> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
|
||||
ABORT_IF(encoderStates.size() != 1, "Pooler expects exactly one encoder state");
|
||||
|
||||
auto context = encoderStates[0]->getContext();
|
||||
@ -58,7 +58,7 @@ public:
|
||||
Expr logMask = (1.f - batchMask) * -9999.f;
|
||||
Expr maxPool = max(context * batchMask + logMask, /*axis=*/-3);
|
||||
|
||||
return maxPool;
|
||||
return {maxPool};
|
||||
}
|
||||
|
||||
void clear() override {}
|
||||
@ -73,7 +73,7 @@ public:
|
||||
SlicePooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: PoolerBase(graph, options) {}
|
||||
|
||||
Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
|
||||
std::vector<Expr> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
|
||||
ABORT_IF(encoderStates.size() != 1, "Pooler expects exactly one encoder state");
|
||||
|
||||
auto context = encoderStates[0]->getContext();
|
||||
@ -83,7 +83,7 @@ public:
|
||||
// @TODO: unify this better, this is currently hacky
|
||||
Expr slicePool = slice(context * batchMask, /*axis=*/-3, 0);
|
||||
|
||||
return slicePool;
|
||||
return {slicePool};
|
||||
}
|
||||
|
||||
void clear() override {}
|
||||
@ -98,8 +98,8 @@ public:
|
||||
SimPooler(Ptr<ExpressionGraph> graph, Ptr<Options> options)
|
||||
: PoolerBase(graph, options) {}
|
||||
|
||||
Expr apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
|
||||
ABORT_IF(encoderStates.size() != 2, "SimPooler expects exactly two encoder states");
|
||||
std::vector<Expr> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
|
||||
ABORT_IF(encoderStates.size() < 2, "SimPooler expects at least two encoder states not {}", encoderStates.size());
|
||||
|
||||
std::vector<Expr> vecs;
|
||||
for(auto encoderState : encoderStates) {
|
||||
@ -108,7 +108,7 @@ public:
|
||||
|
||||
Expr pool;
|
||||
auto type = options_->get<std::string>("original-type");
|
||||
if(type == "laser") {
|
||||
if(type == "laser" || type == "laser-sim") {
|
||||
// LASER models do a max pool here
|
||||
Expr logMask = (1.f - batchMask) * -9999.f;
|
||||
pool = max(context * batchMask + logMask, /*axis=*/-3);
|
||||
@ -119,17 +119,83 @@ public:
|
||||
// @TODO: make SimPooler take Pooler objects as arguments then it won't need to know this.
|
||||
ABORT("Don't know what type of pooler to use for model type {}", type);
|
||||
}
|
||||
|
||||
vecs.push_back(pool);
|
||||
}
|
||||
|
||||
auto scalars = scalar_product(vecs[0], vecs[1], /*axis*/-1);
|
||||
auto length1 = sqrt(sum(square(vecs[0]), /*axis=*/-1));
|
||||
auto length2 = sqrt(sum(square(vecs[1]), /*axis=*/-1));
|
||||
std::vector<Expr> outputs;
|
||||
bool trainRank = options_->hasAndNotEmpty("train-embedder-rank");
|
||||
|
||||
auto cosine = scalars / ( length1 * length2 );
|
||||
if(!trainRank) { // inference, compute one cosine similarity only
|
||||
ABORT_IF(vecs.size() != 2, "We are expecting two inputs for similarity computation");
|
||||
|
||||
return cosine;
|
||||
// efficiently compute vector length with bdot
|
||||
auto vnorm = [](Expr e) {
|
||||
int dimModel = e->shape()[-1];
|
||||
int dimBatch = e->shape()[-2];
|
||||
e = reshape(e, {dimBatch, 1, dimModel});
|
||||
return reshape(sqrt(bdot(e, e, false, true)), {dimBatch, 1});
|
||||
};
|
||||
|
||||
auto dotProduct = scalar_product(vecs[0], vecs[1], /*axis*/-1);
|
||||
auto length0 = vnorm(vecs[0]); // will be hashed and reused in the graph
|
||||
auto length1 = vnorm(vecs[1]);
|
||||
auto cosine = dotProduct / ( length0 * length1 );
|
||||
cosine = maximum(0, cosine); // clip to [0, 1] - should we actually do that?
|
||||
outputs.push_back(cosine);
|
||||
} else { // compute outputs for embedding similarity ranking
|
||||
if(vecs.size() == 2) { // implies we are sampling negative examples from the batch, since otherwise there is nothing to train
|
||||
LOG_ONCE(info, "Sampling negative examples from batch");
|
||||
|
||||
auto src = vecs[0];
|
||||
auto trg = vecs[1];
|
||||
|
||||
int dimModel = src->shape()[-1];
|
||||
int dimBatch = src->shape()[-2];
|
||||
|
||||
src = reshape(src, {dimBatch, dimModel});
|
||||
trg = reshape(trg, {dimBatch, dimModel});
|
||||
|
||||
// compute cosines between every batch entry, this produces the whole dimBatch x dimBatch matrix
|
||||
auto dotProduct = dot(src, trg, false, true); // [dimBatch, dimBatch] - computes dot product matrix
|
||||
|
||||
auto positiveMask = dotProduct->graph()->constant({dimBatch, dimBatch}, inits::eye()); // a mask for the diagonal (positive examples are on the diagonal)
|
||||
auto negativeMask = 1.f - positiveMask; // mask for all negative examples;
|
||||
|
||||
auto positive = sum(dotProduct * positiveMask, /*axis=*/-1); // we sum across last dim in order to get a column vector of positve examples (everything else is zero)
|
||||
outputs.push_back(positive);
|
||||
|
||||
auto negative1 = dotProduct * negativeMask; // get negative examples for src -> trg (in a row)
|
||||
outputs.push_back(negative1);
|
||||
|
||||
auto negative2 = transpose(negative1); // get negative examples for trg -> src via transpose so they are located in a row
|
||||
outputs.push_back(transpose(negative2));
|
||||
} else {
|
||||
LOG_ONCE(info, "Using provided {} negative examples", vecs.size() - 2);
|
||||
|
||||
// For inference and training with given set of negative examples provided in additional streams.
|
||||
// Assuming that enc0 is query, enc1 is positive example and remaining encoders are optional negative examples. Here we only use column vectors [dimBatch, 1]
|
||||
auto positive = scalar_product(vecs[0], vecs[1], /*axis*/-1);
|
||||
outputs.push_back(positive); // first tensor contains similarity between anchor and positive example
|
||||
|
||||
std::vector<Expr> dotProductsNegative1, dotProductsNegative2;
|
||||
for(int i = 2; i < vecs.size(); ++i) {
|
||||
// compute similarity with anchor
|
||||
auto negative1 = scalar_product(vecs[0], vecs[i], /*axis*/-1);
|
||||
dotProductsNegative1.push_back(negative1);
|
||||
|
||||
// for negative examples also add symmetric dot product with the positive example
|
||||
auto negative2 = scalar_product(vecs[1], vecs[i], /*axis*/-1);
|
||||
dotProductsNegative2.push_back(negative2);
|
||||
}
|
||||
auto negative1 = concatenate(dotProductsNegative1, /*axis=*/-1);
|
||||
outputs.push_back(negative1); // second tensor contains similarities between anchor and all negative example
|
||||
|
||||
auto negative2 = concatenate(dotProductsNegative2, /*axis=*/-1);
|
||||
outputs.push_back(negative2); // third tensor contains similarities between positive and all negative example (symmetric)
|
||||
}
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
void clear() override {}
|
||||
|
Loading…
Reference in New Issue
Block a user