Merged PR 23767: More principled sampling and force-decoding

This PR adds correct force-decoding and more principled sampling; both now work with ensembles, batches, and beam search.
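
Example invocations, taken from the new option help texts (file names illustrative):

    # force-decoding of given prefixes, aligned line-by-line with the source
    ./marian-decoder --force-decode --input source.txt prefixes.txt [...]
    paste source.txt prefixes.txt | ./marian-decoder --force-decode --tsv --tsv-fields 2 [...]

    # sampling: full softmax with temperature 1.0 (implicit default), or top-100 with temperature 0.1
    ./marian-decoder --output-sampling [...]
    ./marian-decoder --output-sampling topk 100 0.1 [...]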
Marcin Junczys-Dowmunt, 2022-09-16 22:53:08 +00:00
parent e13053a6f2, commit 76964791ad
10 changed files with 200 additions and 102 deletions

src/common/config_parser.cpp

@@ -681,6 +681,11 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
cli.add<std::string>("--alignment",
"Return word alignment. Possible values: 0.0-1.0, hard, soft")
->implicit_val("1");
cli.add<bool>("--force-decode",
"Use force-decoding of given prefixes. Forces decoding to follow vocab IDs from last stream in the batch (or the first stream, if there is only one). "
"Use either as `./marian-decoder --force-decode --input source.txt prefixes.txt [...]` where inputs and prefixes align on line-level or as "
"`paste source.txt prefixes.txt | ./marian-decoder --force-decode --tsv --tsv-fields 2 [...]` when reading from stdin."
);
cli.add<bool>("--word-scores",
"Print word-level scores. One score per subword unit, not normalized even if --normalize");
cli.add<std::string/*SchedulerPeriod*/>("--stat-freq",
@@ -709,9 +714,10 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
cli.add<std::vector<float>>("--weights",
"Scorer weights");
cli.add<std::vector<std::string>>("--output-sampling",
"Noise output layer with gumbel noise. Implicit default is 'full' for sampling from full distribution. "
" Also accepts 'topk num' (e.g. topk 100) for top-100 sampling.")
->implicit_val("full");
"Noise output layer with gumbel noise. Implicit default is 'full 1.0' for sampling from full distribution"
" with softmax temperature 1.0. Also accepts 'topk num temp' (e.g. topk 100 0.1) for top-100 sampling with"
" temperature 0.1")
->implicit_val("full 1.0");
cli.add<std::vector<int>>("--output-approx-knn",
"Use approximate knn search in output layer (currently only in transformer)")
->implicit_val("100 1024");

src/data/corpus_base.cpp

@@ -347,11 +347,18 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate, size_t seed)
auto vocabDims = options_->get<std::vector<int>>("dim-vocabs");
vocabDims.resize(numVocs, 0);
for(size_t i = 0; i + 1 < numVocs; ++i) {
// when force-decoding we want the last vocab to be part of the batch,
// hence we do not drop it from the input batch.
bool forceDecoding = options_->get<bool>("force-decode", false);
size_t shift = !forceDecoding ? 1 : 0;
for(size_t i = 0; i + shift < numVocs; ++i) {
Ptr<Vocab> vocab = New<Vocab>(options_, i);
vocabDims[i] = (int) vocab->load(vocabPaths[i], maxVocabs[i]);
vocabs_.emplace_back(vocab);
}
}
// TODO: As above, this is not nice as it modifies the option object and needs to expose the changes
// outside the corpus as models need to know about the vocabulary size; extract the vocab
// creation functionality from the class.
@@ -368,10 +375,11 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate, size_t seed)
}
}
ABORT_IF(!tsv_ && vocabs_.size() != files_.size(),
size_t numStreams = files_.size();
ABORT_IF(!tsv_ && vocabs_.size() != numStreams,
"Number of {} files ({}) and vocab files ({}) does not agree",
training ? "corpus" : "input",
files_.size(),
numStreams,
vocabs_.size());
// Handle guided alignment and data weighting files. Alignments and weights in TSV input were
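
Note (not part of the diff): this change is what makes the prefixes visible downstream. With `--input source.txt prefixes.txt` there are two streams; keeping the last vocabulary loaded means the prefix stream stays in the batch, and DistModifier later reads it via `batch->back()` (see src/translator/sampling.h below). A minimal sketch of the shift logic:

    // sketch: numVocs == 2 for `--input source.txt prefixes.txt`
    bool forceDecoding = options_->get<bool>("force-decode", false);
    size_t shift = forceDecoding ? 0 : 1; // keep the last (prefix) stream when forcing
    for(size_t i = 0; i + shift < numVocs; ++i) {
      // load vocabulary i; with shift == 0 the prefix stream's vocab is loaded too,
      // so its word IDs are available in the batch at translation time
    }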

src/graph/node_operators_binary.h

@@ -1040,7 +1040,7 @@ struct GatherNodeOp : public NaryNodeOp {
NodeOps backwardOps() override {
return {NodeOp(
// @TODO: rename to scatter
Insert</*add=*/true>(child(0)->grad(), adj_, child(1)->val(), axis_))};
Insert</*add=*/true>(child(0)->grad(), adj_, /*indices=*/child(1)->val(), axis_))};
}
Shape newShape(Expr a, int axis, Expr indices) {
@@ -1097,7 +1097,7 @@ struct ScatterNodeOp : public NaryNodeOp {
NodeOps forwardOps() override {
return {NodeOp(
CopyCast(val_, child(0)->val()); // @TODO: use normal copy
Insert</*add=*/false>(val_, child(2)->val(), child(1)->val(), axis_)
Insert</*add=*/false>(val_, /*source=*/child(2)->val(), /*indices=*/child(1)->val(), axis_)
)};
}
@@ -1107,7 +1107,7 @@ struct ScatterNodeOp : public NaryNodeOp {
Shape newShape(Expr a, int axis, Expr indices, Expr source) {
ABORT_IF(axis != -1, "only the last dimension is supported");
ABORT_IF(indices->shape() != source->shape(), "Shapes must match");
// ABORT_IF(indices->shape() != source->shape(), "Shapes must match"); or broadcast
Shape shape = a->shape();
// @TODO: do proper checking
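
For context (an illustrative sketch, not part of the diff): the contract of the last-axis gather/scatter that these nodes implement, written out on plain arrays:

    #include <cstddef>
    #include <vector>

    // gather  (axis = -1):            out[r][j] = in[r][idx[r][j]]
    // scatter (axis = -1, add=false): out[r][idx[r][j]] = src[r][j]
    void gatherLastAxis(const std::vector<float>& in, size_t colsIn,
                        const std::vector<int>& idx, size_t colsIdx,
                        size_t rows, std::vector<float>& out) {
      out.resize(rows * colsIdx);
      for(size_t r = 0; r < rows; ++r)
        for(size_t j = 0; j < colsIdx; ++j)
          out[r * colsIdx + j] = in[r * colsIn + idx[r * colsIdx + j]];
    }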

src/models/costs.cpp

@@ -10,40 +10,5 @@ Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) {
return state;
}
Ptr<DecoderState> GumbelSoftmaxStep::apply(Ptr<DecoderState> state) {
state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
[](Expr logits) { // lemma gets gumbelled
return logsoftmax(logits + constant_like(logits, inits::gumbel()));
},
logsoftmax)); // factors don't
return state;
}
TopkGumbelSoftmaxStep::TopkGumbelSoftmaxStep(int k) : k_{k} {}
Ptr<DecoderState> TopkGumbelSoftmaxStep::apply(Ptr<DecoderState> state) {
state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
[=](Expr logits) { // lemma gets gumbelled
// create logits-sized tensor consisting only of invalid path scores
float invalidPathScore = NumericLimits<float>(logits->value_type()).lowest;
Expr invalidLogits = constant_like(logits, inits::fromValue(invalidPathScore));
// select top-k values
Expr val, idx;
std::tie(val, idx) = topk(logits, k_, /*axis=*/-1, /*descending=*/true);
// uncomment below to display probability mass in top-k selection
// debug(sum(gather(softmax(logits), -1, idx), -1), "sum");
// Add Gumbel noise to top-k values only and compute logsoftmax, used for argmax sampling later in beam-search
Expr gumbelVal = logsoftmax(val + constant_like(val, inits::gumbel()));
// Scatter gumbelled values back into logits to fill with usable values
return scatter(invalidLogits, -1, idx, gumbelVal);
},
logsoftmax)); // factors don't
return state;
}
} // namespace models
} // namespace marian
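
The code removed above (and re-created in src/translator/sampling.h below) relies on the Gumbel-max trick: adding i.i.d. Gumbel(0,1) noise to log-probabilities and taking the argmax draws an exact sample from the softmax distribution, which is why beam search with beam size 1 over the noised scores performs sampling. A self-contained sketch, independent of Marian:

    #include <algorithm>
    #include <cmath>
    #include <random>
    #include <vector>

    // Gumbel-max: argmax_i(logits[i] + g_i) with g_i ~ Gumbel(0,1) samples index i
    // with probability softmax(logits)[i]; g = -log(-log(u)) for u ~ Uniform(0,1).
    int sampleGumbelMax(const std::vector<float>& logits, std::mt19937& rng) {
      std::uniform_real_distribution<float> uniform(1e-9f, 1.f);
      std::vector<float> noised(logits.size());
      for(size_t i = 0; i < logits.size(); ++i)
        noised[i] = logits[i] - std::log(-std::log(uniform(rng)));
      return (int)(std::max_element(noised.begin(), noised.end()) - noised.begin());
    }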

src/models/costs.h

@@ -297,32 +297,6 @@ public:
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
};
// Gumbel-max noising for sampling during translation.
// Produces accurate sampling with beam=1. Turn on
// with --output-sampling [full] during translation
// with marian-decoder for sampling from the full
// softmax distribution.
class GumbelSoftmaxStep : public ILogProbStep {
public:
virtual ~GumbelSoftmaxStep() {}
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
};
// Gumbel-max noising for top-k sampling during translation.
// Produces accurate sampling with beam=1. Turn on
// with --output-sampling topk [10] during translation
// with marian-decoder for top-10 sampling.
class TopkGumbelSoftmaxStep : public ILogProbStep {
private:
int k_{1};
public:
TopkGumbelSoftmaxStep(int k);
virtual ~TopkGumbelSoftmaxStep() {}
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
};
// class to wrap an IEncoderDecoder and a ILogProbStep that are executed in sequence,
// wrapped again in the IEncoderDecoder interface
// @TODO: seems we are conflating an interface definition with its implementation?

src/models/model_factory.cpp

@@ -370,28 +370,7 @@ Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
// add (log)softmax if requested
if (use == usage::translation) {
if(std::dynamic_pointer_cast<EncoderDecoder>(baseModel)) {
if(options->hasAndNotEmpty("output-sampling")) {
auto sampling = options->get<std::vector<std::string>>("output-sampling", {});
std::string method = sampling.size() > 0 ? sampling[0] : "full";
if(method == "0") { /*for backwards-compat when output-sampling: false in yaml file*/
// do normal decoding
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
} else if(method == "full" || method == "1" /*for backwards-compat when output-sampling: true in yaml file*/) {
LOG(info, "Output sampling from the full softmax distribution");
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
} else if(method == "topk") {
int k = sampling.size() > 1 ? std::stoi(sampling[1]) : 10;
if(k == 1)
LOG(info, "Output sampling with k=1 is equivalent to beam search with beam size 1");
LOG(info, "Output sampling via top-{} sampling", k);
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<TopkGumbelSoftmaxStep>(k));
} else {
ABORT("Unknown sampling method: {}", method);
}
} else {
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
}
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
}
#ifdef COMPILE_EXAMPLES
// note: 'usage::translation' here means 'inference'
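
Note (not part of the diff): the branching removed here is not lost; parsing of `--output-sampling` and the Gumbel noising now live in DistModifier (src/translator/sampling.h below), which beam search instantiates per batch. A sketch of the new wiring, with names taken from the diff:

    // translation always wraps the model in a plain LogSoftmaxStep now:
    auto encdec = std::dynamic_pointer_cast<EncoderDecoder>(baseModel);
    return New<Stepwise>(encdec, New<LogSoftmaxStep>());
    // per batch, inside BeamSearch::search (see beam_search.cpp below):
    //   auto distMod = New<DistModifier>(options_, batch, INVALID_PATH_SCORE);
    //   stepScores  = distMod->force(stepScores, t, maxBeamSize, batchIndices);
    //   stepScores  = distMod->sample(stepScores);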

src/tensors/gpu/tensor_operators.cu

@@ -199,6 +199,8 @@ void CopyCastFrom(Tensor out, const T* in, int length) {
#endif
} else if(out->type() == Type::float64) {
CopyCastTo<add>(out->data<double>(), in, length);
} else if(out->type() == Type::uint32) {
CopyCastTo<add>(out->data<uint32_t>(), in, length);
} else {
ABORT("CopyCastTo to type {} not implemented", out->type());
}
@@ -313,6 +315,8 @@ void Concatenate1(Tensor out, const std::vector<Tensor>& inputs) {
} else if(out->type() == Type::float16) {
gInsertCols<false><<<blocks, threads>>>(out->data<half>(), in->data<half>(), rows, cols_in, cols_out, cols_in, offset, 0);
#endif
} else if(out->type() == Type::uint32) {
gInsertCols<false><<<blocks, threads>>>(out->data<uint32_t>(), in->data<uint32_t>(), rows, cols_in, cols_out, cols_in, offset, 0);
} else {
ABORT("Concatenate1 not implemented for type {}", out->type());
}
@@ -392,6 +396,14 @@ void Concatenate2(Tensor out, Tensor in1, Tensor in2) {
in2->data<half>(),
rowStride2);
#endif
} else if(out->type() == Type::uint32) {
gJoin2<<<blocks, threads>>>(out->data<uint32_t>(),
rowBatch,
cols,
in1->data<uint32_t>(),
rowStride1,
in2->data<uint32_t>(),
rowStride2);
} else {
ABORT("Concatenate2 not implemented for type {}", out->type());
}
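
Why uint32 suddenly matters here (a note, not part of the diff): force-decoding caches the prefix batch as a `Type::uint32` constant of word indices (see `forceBatch_` in src/translator/sampling.h below), and such index tensors now flow through the generic copy/concatenate paths:

    // from sampling.h: the cached prefix tensor is uint32, so CopyCast and
    // Concatenate must handle that element type on the GPU as well
    forceBatch_ = scores->graph()->constant({1, dimTime, dimBatch, 1},
                                            inits::fromVector(forceWords), Type::uint32);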

src/translator/beam_search.cpp

@@ -1,10 +1,9 @@
#include "translator/beam_search.h"
#include "data/factored_vocab.h"
#include "translator/helpers.h"
#include "translator/nth_element.h"
#include "data/shortlist.h"
#include "common/utils.h"
#include "data/factored_vocab.h"
#include "data/shortlist.h"
#include "translator/beam_search.h"
#include "translator/helpers.h"
#include "translator/sampling.h"
namespace marian {
@@ -316,6 +315,8 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
suppressedWordIndices = graph->indices(suppressed);
}
auto distMod = New<DistModifier>(options_, batch, INVALID_PATH_SCORE);
// the decoding process updates the following state information in each output time step:
// - beams: array [origDimBatch] of array [maxBeamSize] of Hypothesis
// - current output time step's set of active hypotheses, aka active search space
@@ -413,9 +414,9 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
//**********************************************************************
// compute expanded path scores with word prediction probs from all scorers
auto expandedPathScores = prevPathScores; // will become [maxBeamSize, 1, currDimBatch, dimVocab]
Expr logProbs;
Expr stepScores;
for(size_t i = 0; i < scorers_.size(); ++i) {
Expr logProbs;
if (factorGroup == 0) {
// compute output probabilities for current output time step
// - uses hypIndices[index in beam, 1, batch index, 1] to reorder scorer state to reflect the top-N in beams[][]
@@ -449,10 +450,19 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
logProbs = states[i]->getLogProbs().getFactoredLogits(factorGroup, /*shortlist=*/ nullptr, hypIndices, maxBeamSize); // [maxBeamSize, 1, currentDimBatch, dimVocab]
}
// expand all hypotheses, [maxBeamSize, 1, currentDimBatch, 1] -> [maxBeamSize, 1, currentDimBatch, dimVocab]
expandedPathScores = expandedPathScores + scorers_[i]->getWeight() * logProbs;
if(i == 0)
stepScores = scorers_[i]->getWeight() * logProbs;
else
stepScores = stepScores + scorers_[i]->getWeight() * logProbs;
}
if(factorGroup == 0) {
stepScores = distMod->force(stepScores, (int)t, (int)maxBeamSize, batchIndices);
stepScores = distMod->sample(stepScores);
}
// make beams continuous
auto expandedPathScores = prevPathScores + stepScores; // will become [maxBeamSize, 1, currDimBatch, dimVocab]
expandedPathScores = swapAxes(expandedPathScores, 0, 2); // -> [currentDimBatch, 1, maxBeamSize, dimVocab]
// perform NN computation
@@ -463,6 +473,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
//**********************************************************************
// suppress specific symbols if not at right positions
// @TODO: move this to DistributionModifier
if(suppressedWordIndices && factorGroup == 0)
suppressWords(expandedPathScores, suppressedWordIndices);
@@ -477,6 +488,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
/*out*/ nBestPathScores,
/*out*/ nBestKeys,
/*first=*/t == 0 && factorGroup == 0); // @TODO: this is only used for checking presently, and should be removed altogether
// Now, nBestPathScores contain N-best expandedPathScores for each batch and beam,
// and nBestKeys for each their original location (batchIdx, beamHypIdx, word).
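
One design point worth spelling out (a sketch, not part of the diff): scorer log-probs are first accumulated into `stepScores`, so the modifier sees the combined ensemble distribution exactly once per time step, before it is added onto the running path scores. That is what makes force-decoding and sampling work for ensembles:

    // per output step, for factorGroup == 0 (sketch): combine weighted scorer
    // log-probs once, modify once, then add onto running path scores
    Expr combineAndModify(const std::vector<Ptr<Scorer>>& scorers,
                          const std::vector<Expr>& logProbsPerScorer,
                          Ptr<DistModifier> distMod, Expr prevPathScores,
                          int t, int maxBeamSize, std::vector<IndexType>& batchIndices) {
      Expr stepScores;
      for(size_t i = 0; i < scorers.size(); ++i)
        stepScores = (i == 0) ? scorers[i]->getWeight() * logProbsPerScorer[i]
                              : stepScores + scorers[i]->getWeight() * logProbsPerScorer[i];
      stepScores = distMod->force(stepScores, t, maxBeamSize, batchIndices); // mask to prefixes
      stepScores = distMod->sample(stepScores);                              // gumbel/top-k noise
      return prevPathScores + stepScores;                                    // beam top-k follows
    }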

src/translator/beam_search.h

@@ -3,6 +3,7 @@
#include "marian.h"
#include "translator/history.h"
#include "translator/scorers.h"
#include "translator/nth_element.h"
namespace marian {

src/translator/sampling.h (new file, 141 lines)

@@ -0,0 +1,141 @@
namespace marian {
class DistModifier {
private:
Ptr<Options> options_;
bool forceDecode_{false};
bool sampling_{false};
std::string samplingMethod_;
int topk_{10};
float temperature_{1.f};
Ptr<data::CorpusBatch> batch_;
float invalidPathScore_;
Expr forceBatch_;
public:
DistModifier(Ptr<Options> options, Ptr<data::CorpusBatch> batch, float invalidPathScore) :
options_(options), forceDecode_(options_->get<bool>("force-decode", false)),
batch_(batch), invalidPathScore_(invalidPathScore) {
if(options_->hasAndNotEmpty("output-sampling")) {
sampling_ = true;
auto samplingOpts = options_->get<std::vector<std::string>>("output-sampling", {});
samplingMethod_ = samplingOpts.size() > 0 ? samplingOpts[0] : "full";
if(samplingMethod_ == "0") { // for backcompat with boolean values
sampling_ = false;
samplingMethod_ = "";
} else if(samplingMethod_ == "1") { // for backcompat with boolean values
sampling_ = true;
samplingMethod_ = "full";
}
if(samplingMethod_ == "full") {
if(samplingOpts.size() > 1)
temperature_ = std::stof(samplingOpts[1]);
}
if(samplingMethod_ == "topk") {
if(samplingOpts.size() > 1)
topk_ = std::stoi(samplingOpts[1]);
if(samplingOpts.size() > 2)
temperature_ = std::stof(samplingOpts[2]);
}
}
}
Expr force(Expr scores, int pos, int beamSize, std::vector<IndexType>& batchIndices) {
// we check the last field of the batch for force-decoding content
int dimTime = (int)batch_->back()->batchWidth();
if(!forceDecode_ || pos >= dimTime) // nothing to force-decode, just return original scores
return scores;
LOG_ONCE(info, "Force-decoding with given prefixes");
// if we get here, then we have to do force-decoding. We do this by "softly" modifying the scores and passing the
// result to the normal top-k/beam search. "Softly" here means we add masking terms rather than making hard selections,
// which preserves the original tensor layout.
// This allows for beam search and batched force-decoding with different-length prefixes in a batch
// (much harder to do with actual index manipulation). We then return the modified (masked) probabilities to the beam search,
// which continues as normal on the modified distribution.
if(!forceBatch_) {
// turn the batch into a cached tensor that lives in the computation graph
std::vector<WordIndex> forceWords;
for(auto& word : batch_->back()->data())
forceWords.push_back(word.toWordIndex());
int dimBatch = (int)batch_->back()->batchSize();
forceBatch_ = scores->graph()->constant({1, dimTime, dimBatch, 1}, inits::fromVector(forceWords), Type::uint32); // [1, dimTime, dimBatch, 1]
}
// if we remove batch entries during decoding (finished decoding) then adjust here
if(forceBatch_->shape()[-2] != batchIndices.size())
forceBatch_ = index_select(forceBatch_, -2, batchIndices);
// get vocab index and probability for force-decoded tokens for the current time step
Expr forceIndices = slice(forceBatch_, /*axis=*/-3, pos); // [1, 1, dimBatch, 1]
Expr forceVals = gather(scores, /*axis=*/-1, forceIndices); // [1, 1, dimBatch, 1]
// create dummy indices and values for beam entries other than the force-decoded value. This is required to ensure that the beam
// does not collapse for hyps outside the forced hyps and can still do full beam-search once we finish force-decoding for a batch
// entry. We initialize randomly (they are not going to be used anyway due to very low prob) and shift by 1 to have 0 at the first position.
int dimVocab = scores->shape()[-1];
auto graph = scores->graph();
// we start at 256 to skip over suppressed special words in SentencePiece @TODO: this should be somehow inferred.
Expr dummyIndices = shift(graph->constant({1, 1, 1, beamSize}, inits::uniform(256.f, (float)dimVocab)), {0, 0, 0, 1}, 0.f);
// we use a range of invalidPathScore_ to invalidPathScore_ / 2 to make sure that the probabilities stay low, but larger than invalidPathScore_ itself.
Expr dummyVals = shift(graph->constant({1, 1, 1, beamSize}, inits::uniform(invalidPathScore_, invalidPathScore_ / 2.f)), {0, 0, 0, 1}, 0.f);
// here we add the force-decoded entries back into the zeroed positions
dummyIndices = cast(cast(dummyIndices, Type::float32) + cast(forceIndices, Type::float32), Type::uint32);
dummyVals = dummyVals + forceVals;
// create a tensor of the same size as the original logits, initialize with invalidPathScore and then scatter the force-decoded and
// dummy values into the correct positions.
Expr forcedScores = constant_like(scores, inits::fromValue(invalidPathScore_));
forcedScores = scatter(forcedScores, -1, dummyIndices, dummyVals);
// for entries that have finished force-decoding (the batch holds eosId as the vocab id) use the original logits for the whole batch entry
// via interpolation with a selector. In Marian, eosId doubles as the padding symbol, so this works at every position; eos for an
// unfinished hypothesis means decoding continues freely, or by sampling.
WordIndex eosId = batch_->back()->vocab()->getEosId().toWordIndex();
auto interpol = eq(cast(forceIndices, scores->value_type()), (float)eosId);
return interpol * scores + (1.f - interpol) * forcedScores;
}
Expr sample(Expr scores) {
if(sampling_) {
if(temperature_ != 1.f)
scores = scores / temperature_;
if(samplingMethod_ == "full") {
LOG_ONCE(info, "Output sampling from the full softmax distribution with temperature {}", temperature_);
return logsoftmax(scores + constant_like(scores, inits::gumbel()));
} else if(samplingMethod_ == "topk") {
if(topk_ == 1)
LOG_ONCE(info, "Output sampling with k=1 is equivalent to beam search with beam size 1");
LOG_ONCE(info, "Output sampling via top-{} sampling with temperature {}", topk_, temperature_);
Expr invalidLogits = constant_like(scores, inits::fromValue(invalidPathScore_));
// select top-k values
Expr val, idx;
std::tie(val, idx) = topk(scores, topk_, /*axis=*/-1, /*descending=*/true);
// Add Gumbel noise to top-k values only and compute logsoftmax, used for argmax sampling later in beam-search
Expr gumbelVal = logsoftmax(val + constant_like(val, inits::gumbel()));
// Scatter gumbelled values back into logits to fill with usable values
return scatter(invalidLogits, -1, idx, gumbelVal);
} else {
ABORT("Unknown sampling method: {}", samplingMethod_);
}
} else { // no sampling
return scores;
}
}
};
}
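
For completeness, the distributions sample() draws from, given scores s, temperature T, and i.i.d. Gumbel(0,1) noise g_i (a sketch of the math, not from the PR):

    i^\ast = \arg\max_i \left( s_i / T + g_i \right), \quad g_i \sim \mathrm{Gumbel}(0,1)
    \;\Rightarrow\; \Pr(i^\ast = i) = \frac{\exp(s_i / T)}{\sum_j \exp(s_j / T)}

Restricting the noising to the k largest scores, as the topk branch does, samples from the renormalized top-k distribution (the top-k set is unchanged by dividing by T, since that is monotonic):

    \Pr(i^\ast = i) = \frac{\exp(s_i / T)}{\sum_{j \in \mathrm{topk}(s,\,k)} \exp(s_j / T)}, \qquad i \in \mathrm{topk}(s,\,k)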