mirror of
https://github.com/marian-nmt/marian.git
synced 2024-10-05 19:17:10 +03:00
Merged PR 23767: More principled sampling and force-decoding
This PR adds correct force-decoding and more principled sampling; both should now work for ensembles, batches, and with beam search.
This commit is contained in:
parent
e13053a6f2
commit
76964791ad
@ -681,6 +681,11 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
|
||||
cli.add<std::string>("--alignment",
|
||||
"Return word alignment. Possible values: 0.0-1.0, hard, soft")
|
||||
->implicit_val("1");
|
||||
cli.add<bool>("--force-decode",
|
||||
"Use force-decoding of given prefixes. Forces decoding to follow vocab IDs from last stream in the batch (or the first stream, if there is only one). "
|
||||
"Use either as `./marian-decoder --force-decode --input source.txt prefixes.txt [...]` where inputs and prefixes align on line-level or as "
|
||||
"`paste source.txt prefixes.txt | ./marian-decoder --force-decode --tsv --tsv-fields 2 [...]` when reading from stdin."
|
||||
);
|
||||
cli.add<bool>("--word-scores",
|
||||
"Print word-level scores. One score per subword unit, not normalized even if --normalize");
|
||||
cli.add<std::string/*SchedulerPeriod*/>("--stat-freq",
|
||||
@ -709,9 +714,10 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
|
||||
cli.add<std::vector<float>>("--weights",
|
||||
"Scorer weights");
|
||||
cli.add<std::vector<std::string>>("--output-sampling",
|
||||
"Noise output layer with gumbel noise. Implicit default is 'full' for sampling from full distribution. "
|
||||
" Also accepts 'topk num' (e.g. topk 100) for top-100 sampling.")
|
||||
->implicit_val("full");
|
||||
"Noise output layer with gumbel noise. Implicit default is 'full 1.0' for sampling from full distribution"
|
||||
" with softmax temperature 1.0. Also accepts 'topk num temp' (e.g. topk 100 0.1) for top-100 sampling with"
|
||||
" temperature 0.1")
|
||||
->implicit_val("full 1.0");
|
||||
cli.add<std::vector<int>>("--output-approx-knn",
|
||||
"Use approximate knn search in output layer (currently only in transformer)")
|
||||
->implicit_val("100 1024");
|
||||
|
@ -347,11 +347,18 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate, size_t seed)
|
||||
|
||||
auto vocabDims = options_->get<std::vector<int>>("dim-vocabs");
|
||||
vocabDims.resize(numVocs, 0);
|
||||
for(size_t i = 0; i + 1 < numVocs; ++i) {
|
||||
|
||||
// when force-decoding we want the last vocab to be part of the batch,
|
||||
// hence we do not drop it from the input batch.
|
||||
bool forceDecoding = options_->get<bool>("force-decode", false);
|
||||
size_t shift = !forceDecoding ? 1 : 0;
|
||||
|
||||
for(size_t i = 0; i + shift < numVocs; ++i) {
|
||||
Ptr<Vocab> vocab = New<Vocab>(options_, i);
|
||||
vocabDims[i] = (int) vocab->load(vocabPaths[i], maxVocabs[i]);
|
||||
vocabs_.emplace_back(vocab);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: As above, this is not nice as it modifies the option object and needs to expose the changes
|
||||
// outside the corpus as models need to know about the vocabulary size; extract the vocab
|
||||
// creation functionality from the class.
|
||||
@ -368,10 +375,11 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate, size_t seed)
|
||||
}
|
||||
}
|
||||
|
||||
ABORT_IF(!tsv_ && vocabs_.size() != files_.size(),
|
||||
size_t numStreams = files_.size();
|
||||
ABORT_IF(!tsv_ && vocabs_.size() != numStreams,
|
||||
"Number of {} files ({}) and vocab files ({}) does not agree",
|
||||
training ? "corpus" : "input",
|
||||
files_.size(),
|
||||
numStreams,
|
||||
vocabs_.size());
|
||||
|
||||
// Handle guided alignment and data weighting files. Alignments and weights in TSV input were
|
||||
|
@ -1040,7 +1040,7 @@ struct GatherNodeOp : public NaryNodeOp {
|
||||
NodeOps backwardOps() override {
|
||||
return {NodeOp(
|
||||
// @TODO: rename to scatter
|
||||
Insert</*add=*/true>(child(0)->grad(), adj_, child(1)->val(), axis_))};
|
||||
Insert</*add=*/true>(child(0)->grad(), adj_, /*indices=*/child(1)->val(), axis_))};
|
||||
}
|
||||
|
||||
Shape newShape(Expr a, int axis, Expr indices) {
|
||||
@ -1097,7 +1097,7 @@ struct ScatterNodeOp : public NaryNodeOp {
|
||||
NodeOps forwardOps() override {
|
||||
return {NodeOp(
|
||||
CopyCast(val_, child(0)->val()); // @TODO: use normal copy
|
||||
Insert</*add=*/false>(val_, child(2)->val(), child(1)->val(), axis_)
|
||||
Insert</*add=*/false>(val_, /*source=*/child(2)->val(), /*indices=*/child(1)->val(), axis_)
|
||||
)};
|
||||
}
|
||||
|
||||
@ -1107,7 +1107,7 @@ struct ScatterNodeOp : public NaryNodeOp {
|
||||
|
||||
Shape newShape(Expr a, int axis, Expr indices, Expr source) {
|
||||
ABORT_IF(axis != -1, "only last dimensions");
|
||||
ABORT_IF(indices->shape() != source->shape(), "Shapes must match");
|
||||
// ABORT_IF(indices->shape() != source->shape(), "Shapes must match"); or broadcast
|
||||
|
||||
Shape shape = a->shape();
|
||||
// @TODO: do proper checking
|
||||
|
@ -10,40 +10,5 @@ Ptr<DecoderState> LogSoftmaxStep::apply(Ptr<DecoderState> state) {
|
||||
return state;
|
||||
}
|
||||
|
||||
// Rewrites the log-probabilities stored in `state`: the first unary function
// (applied to the lemma logits) adds Gumbel noise before the log-softmax, the
// second one (applied to the factor logits) is a plain log-softmax.
// Returns the same state object that was passed in.
Ptr<DecoderState> GumbelSoftmaxStep::apply(Ptr<DecoderState> state) {
  // lemma path: perturb with Gumbel noise, then normalize
  auto gumbelLogSoftmax = [](Expr logits) {
    Expr gumbelNoise = constant_like(logits, inits::gumbel());
    return logsoftmax(logits + gumbelNoise);
  };

  // factor path stays noise-free
  state->setLogProbs(
      state->getLogProbs().applyUnaryFunctions(gumbelLogSoftmax, logsoftmax));
  return state;
}
|
||||
|
||||
// Stores the number of top candidates that apply() keeps for sampling.
TopkGumbelSoftmaxStep::TopkGumbelSoftmaxStep(int k)
    : k_(k) {}
|
||||
|
||||
// Rewrites the log-probabilities stored in `state` for top-k sampling:
// only the k_ highest lemma logits receive Gumbel noise and are re-normalized,
// every other vocabulary position is filled with the lowest representable
// score. Factor logits get a plain log-softmax. Returns the same state object.
Ptr<DecoderState> TopkGumbelSoftmaxStep::apply(Ptr<DecoderState> state) {
  const int numCandidates = k_;

  auto topkGumbelLogSoftmax = [numCandidates](Expr logits) {
    // background tensor shaped like the logits, holding only invalid path scores
    float lowestScore = NumericLimits<float>(logits->value_type()).lowest;
    Expr background = constant_like(logits, inits::fromValue(lowestScore));

    // pick the k best entries along the last axis
    Expr topVals, topIdx;
    std::tie(topVals, topIdx) = topk(logits, numCandidates, /*axis=*/-1, /*descending=*/true);

    // to inspect the probability mass covered by the selection, uncomment:
    // debug(sum(gather(softmax(logits), -1, topIdx), -1), "sum");

    // noise goes onto the selected values only; the normalized result is what
    // the arg-max in beam search later samples from
    Expr gumbelNoise = constant_like(topVals, inits::gumbel());
    Expr noisyTop = logsoftmax(topVals + gumbelNoise);

    // write the perturbed values back at their original vocabulary positions
    return scatter(background, -1, topIdx, noisyTop);
  };

  // factors are normalized without noise
  state->setLogProbs(
      state->getLogProbs().applyUnaryFunctions(topkGumbelLogSoftmax, logsoftmax));
  return state;
}
|
||||
|
||||
} // namespace models
|
||||
} // namespace marian
|
||||
|
@ -297,32 +297,6 @@ public:
|
||||
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
|
||||
};
|
||||
|
||||
// Gumbel-max noising for sampling during translation.
|
||||
// Produces accurate sampling with beam=1. Turn on
|
||||
// with --output-sampling [full] during translation
|
||||
// with marian-decoder for samnpling from the full
|
||||
// softmax distribution.
|
||||
class GumbelSoftmaxStep : public ILogProbStep {
|
||||
public:
|
||||
virtual ~GumbelSoftmaxStep() {}
|
||||
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
|
||||
};
|
||||
|
||||
|
||||
// Gumbel-max noising for top-k sampling during translation.
|
||||
// Produces accurate sampling with beam=1. Turn on
|
||||
// with --output-sampling topk [10] during translation
|
||||
// with marian-decoder for top-10 sampling.
|
||||
class TopkGumbelSoftmaxStep : public ILogProbStep {
|
||||
private:
|
||||
int k_{1};
|
||||
|
||||
public:
|
||||
TopkGumbelSoftmaxStep(int k);
|
||||
virtual ~TopkGumbelSoftmaxStep() {}
|
||||
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override;
|
||||
};
|
||||
|
||||
// class to wrap an IEncoderDecoder and a ILogProbStep that are executed in sequence,
|
||||
// wrapped again in the IEncoderDecoder interface
|
||||
// @TODO: seems we are conflating an interface definition with its implementation?
|
||||
|
@ -370,28 +370,7 @@ Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
|
||||
// add (log)softmax if requested
|
||||
if (use == usage::translation) {
|
||||
if(std::dynamic_pointer_cast<EncoderDecoder>(baseModel)) {
|
||||
if(options->hasAndNotEmpty("output-sampling")) {
|
||||
auto sampling = options->get<std::vector<std::string>>("output-sampling", {});
|
||||
std::string method = sampling.size() > 0 ? sampling[0] : "full";
|
||||
|
||||
if(method == "0") { /*for backwards-compat when output-sampling: false in yaml file*/
|
||||
// do normal decoding
|
||||
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
|
||||
} else if(method == "full" || method == "1" /*for backwards-compat when output-sampling: true in yaml file*/) {
|
||||
LOG(info, "Output sampling from the full softmax distribution");
|
||||
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
|
||||
} else if(method == "topk") {
|
||||
int k = sampling.size() > 1 ? std::stoi(sampling[1]) : 10;
|
||||
if(k == 1)
|
||||
LOG(info, "Output sampling with k=1 is equivalent to beam search with beam size 1");
|
||||
LOG(info, "Output sampling via top-{} sampling", k);
|
||||
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<TopkGumbelSoftmaxStep>(k));
|
||||
} else {
|
||||
ABORT("Unknown sampling method: {}", method);
|
||||
}
|
||||
} else {
|
||||
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
|
||||
}
|
||||
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
|
||||
}
|
||||
#ifdef COMPILE_EXAMPLES
|
||||
// note: 'usage::translation' here means 'inference'
|
||||
|
@ -199,6 +199,8 @@ void CopyCastFrom(Tensor out, const T* in, int length) {
|
||||
#endif
|
||||
} else if(out->type() == Type::float64) {
|
||||
CopyCastTo<add>(out->data<double>(), in, length);
|
||||
} else if(out->type() == Type::uint32) {
|
||||
CopyCastTo<add>(out->data<uint32_t>(), in, length);
|
||||
} else {
|
||||
ABORT("CopyCastTo to type {} not implemented", out->type());
|
||||
}
|
||||
@ -313,6 +315,8 @@ void Concatenate1(Tensor out, const std::vector<Tensor>& inputs) {
|
||||
} else if(out->type() == Type::float16) {
|
||||
gInsertCols<false><<<blocks, threads>>>(out->data<half>(), in->data<half>(), rows, cols_in, cols_out, cols_in, offset, 0);
|
||||
#endif
|
||||
} else if(out->type() == Type::uint32) {
|
||||
gInsertCols<false><<<blocks, threads>>>(out->data<uint32_t>(), in->data<uint32_t>(), rows, cols_in, cols_out, cols_in, offset, 0);
|
||||
} else {
|
||||
ABORT("Concatenate1 not implemented for type {}", out->type());
|
||||
}
|
||||
@ -392,6 +396,14 @@ void Concatenate2(Tensor out, Tensor in1, Tensor in2) {
|
||||
in2->data<half>(),
|
||||
rowStride2);
|
||||
#endif
|
||||
} else if(out->type() == Type::uint32) {
|
||||
gJoin2<<<blocks, threads>>>(out->data<uint32_t>(),
|
||||
rowBatch,
|
||||
cols,
|
||||
in1->data<uint32_t>(),
|
||||
rowStride1,
|
||||
in2->data<uint32_t>(),
|
||||
rowStride2);
|
||||
} else {
|
||||
ABORT("Concatenate2 not implemented for type {}", out->type());
|
||||
}
|
||||
|
@ -1,10 +1,9 @@
|
||||
#include "translator/beam_search.h"
|
||||
|
||||
#include "data/factored_vocab.h"
|
||||
#include "translator/helpers.h"
|
||||
#include "translator/nth_element.h"
|
||||
#include "data/shortlist.h"
|
||||
#include "common/utils.h"
|
||||
#include "data/factored_vocab.h"
|
||||
#include "data/shortlist.h"
|
||||
#include "translator/beam_search.h"
|
||||
#include "translator/helpers.h"
|
||||
#include "translator/sampling.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
@ -316,6 +315,8 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
|
||||
suppressedWordIndices = graph->indices(suppressed);
|
||||
}
|
||||
|
||||
auto distMod = New<DistModifier>(options_, batch, INVALID_PATH_SCORE);
|
||||
|
||||
// the decoding process updates the following state information in each output time step:
|
||||
// - beams: array [origDimBatch] of array [maxBeamSize] of Hypothesis
|
||||
// - current output time step's set of active hypotheses, aka active search space
|
||||
@ -413,9 +414,9 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
|
||||
|
||||
//**********************************************************************
|
||||
// compute expanded path scores with word prediction probs from all scorers
|
||||
auto expandedPathScores = prevPathScores; // will become [maxBeamSize, 1, currDimBatch, dimVocab]
|
||||
Expr logProbs;
|
||||
Expr stepScores;
|
||||
for(size_t i = 0; i < scorers_.size(); ++i) {
|
||||
Expr logProbs;
|
||||
if (factorGroup == 0) {
|
||||
// compute output probabilities for current output time step
|
||||
// - uses hypIndices[index in beam, 1, batch index, 1] to reorder scorer state to reflect the top-N in beams[][]
|
||||
@ -449,10 +450,19 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
|
||||
logProbs = states[i]->getLogProbs().getFactoredLogits(factorGroup, /*shortlist=*/ nullptr, hypIndices, maxBeamSize); // [maxBeamSize, 1, currentDimBatch, dimVocab]
|
||||
}
|
||||
// expand all hypotheses, [maxBeamSize, 1, currentDimBatch, 1] -> [maxBeamSize, 1, currentDimBatch, dimVocab]
|
||||
expandedPathScores = expandedPathScores + scorers_[i]->getWeight() * logProbs;
|
||||
if(i == 0)
|
||||
stepScores = scorers_[i]->getWeight() * logProbs;
|
||||
else
|
||||
stepScores = stepScores + scorers_[i]->getWeight() * logProbs;
|
||||
}
|
||||
|
||||
if(factorGroup == 0) {
|
||||
stepScores = distMod->force(stepScores, (int)t, (int)maxBeamSize, batchIndices);
|
||||
stepScores = distMod->sample(stepScores);
|
||||
}
|
||||
|
||||
// make beams continuous
|
||||
auto expandedPathScores = prevPathScores + stepScores; // will become [maxBeamSize, 1, currDimBatch, dimVocab]
|
||||
expandedPathScores = swapAxes(expandedPathScores, 0, 2); // -> [currentDimBatch, 1, maxBeamSize, dimVocab]
|
||||
|
||||
// perform NN computation
|
||||
@ -463,6 +473,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
|
||||
|
||||
//**********************************************************************
|
||||
// suppress specific symbols if not at right positions
|
||||
// @TODO: move this to DistributionModifier
|
||||
if(suppressedWordIndices && factorGroup == 0)
|
||||
suppressWords(expandedPathScores, suppressedWordIndices);
|
||||
|
||||
@ -477,6 +488,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
|
||||
/*out*/ nBestPathScores,
|
||||
/*out*/ nBestKeys,
|
||||
/*first=*/t == 0 && factorGroup == 0); // @TODO: this is only used for checking presently, and should be removed altogether
|
||||
|
||||
// Now, nBestPathScores contain N-best expandedPathScores for each batch and beam,
|
||||
// and nBestKeys for each their original location (batchIdx, beamHypIdx, word).
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
#include "marian.h"
|
||||
#include "translator/history.h"
|
||||
#include "translator/scorers.h"
|
||||
#include "translator/nth_element.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
|
141
src/translator/sampling.h
Normal file
141
src/translator/sampling.h
Normal file
@ -0,0 +1,141 @@
|
||||
namespace marian {
|
||||
|
||||
class DistModifier {
|
||||
private:
|
||||
Ptr<Options> options_;
|
||||
bool forceDecode_{false};
|
||||
bool sampling_{false};
|
||||
std::string samplingMethod_;
|
||||
int topk_{10};
|
||||
float temperature_{1.f};
|
||||
|
||||
Ptr<data::CorpusBatch> batch_;
|
||||
float invalidPathScore_;
|
||||
|
||||
Expr forceBatch_;
|
||||
|
||||
public:
|
||||
DistModifier(Ptr<Options> options, Ptr<data::CorpusBatch> batch, float invalidPathScore) :
|
||||
options_(options), forceDecode_(options_->get<bool>("force-decode", false)),
|
||||
batch_(batch), invalidPathScore_(invalidPathScore) {
|
||||
|
||||
if(options_->hasAndNotEmpty("output-sampling")) {
|
||||
sampling_ = true;
|
||||
auto samplingOpts = options_->get<std::vector<std::string>>("output-sampling", {});
|
||||
samplingMethod_ = samplingOpts.size() > 0 ? samplingOpts[0] : "full";
|
||||
if(samplingMethod_ == "0") { // for backcompat with boolean values
|
||||
sampling_ = false;
|
||||
samplingMethod_ = "";
|
||||
} else if(samplingMethod_ == "1") { // for backcompat with boolean values
|
||||
sampling_ = true;
|
||||
samplingMethod_ = "full";
|
||||
}
|
||||
|
||||
if(samplingMethod_ == "full") {
|
||||
if(samplingOpts.size() > 1)
|
||||
temperature_ = std::stof(samplingOpts[1]);
|
||||
}
|
||||
|
||||
if(samplingMethod_ == "topk") {
|
||||
if(samplingOpts.size() > 1)
|
||||
topk_ = std::stoi(samplingOpts[1]);
|
||||
if(samplingOpts.size() > 2)
|
||||
temperature_ = std::stof(samplingOpts[2]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Expr force(Expr scores, int pos, int beamSize, std::vector<IndexType>& batchIndices) {
|
||||
// we check the last field of the batch for force-decoding content
|
||||
int dimTime = (int)batch_->back()->batchWidth();
|
||||
if(!forceDecode_ || pos >= dimTime) // nothing to force-decode, just return original scores
|
||||
return scores;
|
||||
|
||||
LOG_ONCE(info, "Force-decoding with given prefixes");
|
||||
// if we get here, then we have to do force-decoding. We do this by "softly" modifying the scores and passing the
|
||||
// result to the normal top-k/beam search. "Softly" here means we add masking terms rather than making hard selections
|
||||
// which preserves the original tensor layout.
|
||||
// This allows for beam-search and batched force-decoding with different length prefixes in a batch
|
||||
// (way harder to do with actual index manipulation). We then return modified (masked) probabilities to the beam-search
|
||||
// which then continues as normal on the modified distribution.
|
||||
|
||||
if(!forceBatch_) {
|
||||
// turn the batch into a cached tensor that lives in the computation graph
|
||||
std::vector<WordIndex> forceWords;
|
||||
for(auto& word : batch_->back()->data())
|
||||
forceWords.push_back(word.toWordIndex());
|
||||
|
||||
int dimBatch = (int)batch_->back()->batchSize();
|
||||
forceBatch_ = scores->graph()->constant({1, dimTime, dimBatch, 1}, inits::fromVector(forceWords), Type::uint32); // [1, dimTime, dimBatch, 1]
|
||||
}
|
||||
|
||||
// if we remove batch entries during decoding (finished decoding) then adjust here
|
||||
if(forceBatch_->shape()[-2] != batchIndices.size())
|
||||
forceBatch_ = index_select(forceBatch_, -2, batchIndices);
|
||||
|
||||
// get vocab index and probability for force-decoded tokens for the current time step
|
||||
Expr forceIndices = slice(forceBatch_, /*axis=*/-3, pos); // [1, 1, dimBatch, 1]
|
||||
Expr forceVals = gather(scores, /*axis=*/-1, forceIndices); // [1, 1, dimBatch, 1]
|
||||
|
||||
// create dummy indices and values for beam entries other then the force-decoded value. This is required to ensure that the beam
|
||||
// does not collapse for hyps outside the forced hyps and can still do full beam-search once we finish force-decoding for a batch
|
||||
// entry. We initialize randomly (they are not going to be used anyway due to very low prob) and shift by 1 to have 0 at first postion.
|
||||
int dimVocab = scores->shape()[-1];
|
||||
auto graph = scores->graph();
|
||||
// we start at 256 to skip over suppressed special words in SentencePiece @TODO: this should be somehow inferred.
|
||||
Expr dummyIndices = shift(graph->constant({1, 1, 1, beamSize}, inits::uniform(256.f, (float)dimVocab)), {0, 0, 0, 1}, 0.f);
|
||||
// we use a range of invalidPathScore_ to invalidPathScore_ / 2 to make sure that the probabilities stay low, but larger than invalidPathScore_ itself.
|
||||
Expr dummyVals = shift(graph->constant({1, 1, 1, beamSize}, inits::uniform(invalidPathScore_, invalidPathScore_ / 2.f)), {0, 0, 0, 1}, 0.f);
|
||||
|
||||
// here we add the force-decoded entries back into the zeroed positions
|
||||
dummyIndices = cast(cast(dummyIndices, Type::float32) + cast(forceIndices, Type::float32), Type::uint32);
|
||||
dummyVals = dummyVals + forceVals;
|
||||
|
||||
// create a tensor of the same size as the original logits, initialize with invalidPathScore and then scatter the force-decoded and
|
||||
// dummy values into the correct positions.
|
||||
Expr forcedScores = constant_like(scores, inits::fromValue(invalidPathScore_));
|
||||
forcedScores = scatter(forcedScores, -1, dummyIndices, dummyVals);
|
||||
|
||||
// for entries that have finished force-decoding (the batch has eosId as vocab id) use the original logits for the whole batch entry
|
||||
// via interpolating by a selector. In marian eosId is used for padding, so this works everywhere and eos for unfinished hyps means
|
||||
// free decoding or sampling.
|
||||
WordIndex eosId = batch_->back()->vocab()->getEosId().toWordIndex();
|
||||
auto interpol = eq(cast(forceIndices, scores->value_type()), (float)eosId);
|
||||
return interpol * scores + (1.f - interpol) * forcedScores;
|
||||
}
|
||||
|
||||
Expr sample(Expr scores) {
|
||||
if(sampling_) {
|
||||
if(temperature_ != 1.f)
|
||||
scores = scores / temperature_;
|
||||
|
||||
if(samplingMethod_ == "full") {
|
||||
LOG_ONCE(info, "Output sampling from the full softmax distribution with temperature {}", temperature_);
|
||||
return logsoftmax(scores + constant_like(scores, inits::gumbel()));
|
||||
} else if(samplingMethod_ == "topk") {
|
||||
if(topk_ == 1)
|
||||
LOG_ONCE(info, "Output sampling with k=1 is equivalent to beam search with beam size 1");
|
||||
LOG_ONCE(info, "Output sampling via top-{} sampling with temperature {}", topk_, temperature_);
|
||||
|
||||
Expr invalidLogits = constant_like(scores, inits::fromValue(invalidPathScore_));
|
||||
|
||||
// select top-k values
|
||||
Expr val, idx;
|
||||
std::tie(val, idx) = topk(scores, topk_, /*axis=*/-1, /*descending=*/true);
|
||||
|
||||
// Add Gumbel noise to top-k values only and compute logsoftmax, used for argmax sampling later in beam-search
|
||||
Expr gumbelVal = logsoftmax(val + constant_like(val, inits::gumbel()));
|
||||
|
||||
// Scatter gumbelled values back into logits to fill with usable values
|
||||
return scatter(invalidLogits, -1, idx, gumbelVal);
|
||||
} else {
|
||||
ABORT("Unknown sampling method: {}", samplingMethod_);
|
||||
}
|
||||
} else { // no sampling
|
||||
return scores;
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user