Make Vocab const in beam search. Remove some trailing whitespace.

Ulrich Germann 2020-01-29 16:25:56 +00:00
parent cfdde151a1
commit 7228698b06


@@ -16,7 +16,7 @@ private:
Ptr<Options> options_;
std::vector<Ptr<Scorer>> scorers_;
size_t beamSize_;
- Ptr<Vocab> trgVocab_;
+ Ptr<const Vocab> trgVocab_;
static constexpr auto INVALID_PATH_SCORE = -9999; // (@TODO: change to -9999.0 once C++ allows that)
static constexpr auto PURGE_BATCH = true; // @TODO: diagnostic, to-be-removed once confirmed there are no issues.
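A note on the `Ptr<const Vocab>` change in this hunk: beam search only reads the target vocabulary, and since Marian's `Ptr` is an alias for `std::shared_ptr`, pointing at `const Vocab` lets the compiler enforce that. A minimal standalone sketch with a toy `Vocab` stand-in (not the real class):

```cpp
#include <cstddef>
#include <memory>
#include <string>

template <typename T>
using Ptr = std::shared_ptr<T>; // the alias Marian uses

// Toy stand-in for the real Vocab: one const (read-only) and one mutating member.
struct Vocab {
  std::string word(std::size_t id) const { return "w" + std::to_string(id); }
  void load(const std::string& path); // never called here, declaration suffices
};

int main() {
  Ptr<const Vocab> trgVocab = std::make_shared<Vocab>(); // conversion to const is implicit
  auto w = trgVocab->word(42);    // OK: const member function
  // trgVocab->load("vocab.spm"); // would not compile: load() is non-const
  (void)w;
  return 0;
}
```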
@@ -24,7 +24,7 @@ private:
public:
BeamSearch(Ptr<Options> options,
const std::vector<Ptr<Scorer>>& scorers,
- Ptr<Vocab> trgVocab)
+ const Ptr<const Vocab> trgVocab)
: options_(options),
scorers_(scorers),
beamSize_(options_->get<size_t>("beam-size")),
@@ -42,8 +42,8 @@ public:
const std::vector<bool>& dropBatchEntries, // [origDimBatch] - empty source batch entries are marked with true, should be cleared after first use.
const std::vector<IndexType>& batchIdxMap) const { // [origBatchIdx -> currentBatchIdx]
std::vector<float> align; // collects alignment information from the last executed time step
- if(options_->hasAndNotEmpty("alignment") && factorGroup == 0)
-   align = scorers_[0]->getAlignment(); // [beam depth * max src length * current batch size] -> P(s|t); use alignments from the first scorer, even if ensemble,
+ if(options_->hasAndNotEmpty("alignment") && factorGroup == 0)
+   align = scorers_[0]->getAlignment(); // [beam depth * max src length * current batch size] -> P(s|t); use alignments from the first scorer, even if ensemble,
const auto origDimBatch = beams.size(); // see function search for definition of origDimBatch and currentDimBatch etc.
Beams newBeams(origDimBatch); // return value of this function goes here. There are always origDimBatch beams.
@@ -56,7 +56,7 @@ public:
reverseBatchIdxMap.resize(batchIdxMap.size()); // adjust size if doing batch purging.
currentDimBatch = 0;
for(int i = 0; i < batchIdxMap.size(); ++i) {
- reverseBatchIdxMap[batchIdxMap[i]] = i; // reverse batch index mapping, multiple occurrences get overwritten with the last one,
+ reverseBatchIdxMap[batchIdxMap[i]] = i; // reverse batch index mapping, multiple occurrences get overwritten with the last one,
// which is expected due to down-shifting
if(!beams[i].empty())
currentDimBatch++;
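Why "the last one wins" is correct in the loop above is easier to see in isolation. A self-contained sketch of the reverse mapping (a simplification, assuming `IndexType` is an unsigned integer type, as in Marian):

```cpp
#include <cstddef>
#include <vector>

using IndexType = unsigned int;

// Invert [origBatchIdx -> currentBatchIdx]. After batch purging, several
// original indices can map to the same current slot; the later, down-shifted
// entry is the one that actually occupies it, so the last write must win.
std::vector<IndexType> reverseBatchIdx(const std::vector<IndexType>& batchIdxMap) {
  std::vector<IndexType> rev(batchIdxMap.size());
  for (std::size_t i = 0; i < batchIdxMap.size(); ++i)
    rev[batchIdxMap[i]] = (IndexType)i;
  return rev;
}

// Example: batchIdxMap = {0, 0, 1, 2} (original entry 0 finished, entries
// 1..3 shifted down one slot) gives rev[0] = 1, rev[1] = 2, rev[2] = 3:
// each current slot maps back to the original entry now living in it.
```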
@@ -143,12 +143,12 @@ public:
auto lval = states[j]->getLogProbs().getFactoredLogitsTensor(factorGroup); // [maxBeamSize, 1, currentDimBatch, dimFactorVocab]
// The flattening is based on the actual (current) batch size and the batch index computed with batch-pruning, as we are indexing into the pruned tensor
size_t flattenedLogitIndex = (beamHypIdx * currentDimBatch + currentBatchIdx) * vocabSize + wordIdx; // (beam idx, batch idx, word idx); note: beam and batch are transposed, compared to 'key'
// @TODO: use a function on shape() to index, or new method val->at({i1, i2, i3, i4}) with broadcasting
ABORT_IF(lval->shape() != Shape({(int)nBestBeamSize, 1, (int)currentDimBatch, (int)vocabSize}) &&
(beamHypIdx == 0 && lval->shape() != Shape({1, 1, (int)currentDimBatch, (int)vocabSize})),
"Unexpected shape of logits?? {} != {}", lval->shape(), Shape({(int)nBestBeamSize, 1, (int)currentDimBatch, (int)vocabSize}));
breakDown[j] += lval->get(flattenedLogitIndex);
}
hyp->setScoreBreakdown(breakDown);
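`flattenedLogitIndex` above is ordinary row-major indexing into a `[nBestBeamSize, 1, currentDimBatch, vocabSize]` tensor; the singleton axis contributes nothing. A tiny checkable sketch of the same arithmetic:

```cpp
#include <cassert>
#include <cstddef>

// Row-major offset of element (beamHypIdx, 0, currentBatchIdx, wordIdx)
// in a tensor of shape [nBestBeamSize, 1, currentDimBatch, vocabSize].
std::size_t flatIndex(std::size_t beamHypIdx, std::size_t currentBatchIdx,
                      std::size_t wordIdx, std::size_t currentDimBatch,
                      std::size_t vocabSize) {
  return (beamHypIdx * currentDimBatch + currentBatchIdx) * vocabSize + wordIdx;
}

int main() {
  // Shape [2, 1, 3, 5]: the last element (beam 1, batch 2, word 4) must land
  // at offset 29 of the 2*1*3*5 = 30-element buffer.
  assert(flatIndex(1, 2, 4, /*currentDimBatch=*/3, /*vocabSize=*/5) == 29);
  return 0;
}
```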
@@ -162,7 +162,7 @@ public:
newBeam.push_back(hyp);
}
// if factored vocab and this is not the first factor, we need to
// also propagate factored hypotheses that do not get expanded in this step because they don't have this factor
if (factorGroup > 0) {
@@ -214,7 +214,7 @@ public:
// in a single beam, i.e.:
// * [word1-batch1, word1-batch2, ..., word2-batch1, ...]
//
size_t origDimBatch = batch->size(); // number of sentences in batch
size_t batchWidth = batch->width(); // max src length
@@ -243,7 +243,7 @@ public:
for(auto beam : beams) {
Beam newBeam; // a beam of surviving hyps
for(auto hyp : beam)
- if(hyp->getWord() != trgEosId) // if this hyp is not finished,
+ if(hyp->getWord() != trgEosId) // if this hyp is not finished,
newBeam.push_back(hyp); // move over to beam of surviving hyps
if(PURGE_BATCH)
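The pruning loop above is the standard step of dropping finished hypotheses: a hyp whose last word is EOS has been collected into the n-best histories elsewhere, so only unfinished hyps are carried into the next step. A stripped-down sketch with placeholder types (`Hypothesis` here is a hypothetical stand-in):

```cpp
#include <memory>
#include <utility>
#include <vector>

using Word = unsigned int;
struct Hypothesis {                   // hypothetical stand-in for the real class
  Word word;
  Word getWord() const { return word; }
};
using Beam  = std::vector<std::shared_ptr<Hypothesis>>;
using Beams = std::vector<Beam>;

Beams pruneFinished(const Beams& beams, Word trgEosId) {
  Beams pruned;
  pruned.reserve(beams.size());
  for (const auto& beam : beams) {
    Beam newBeam;                     // a beam of surviving hyps, as above
    for (const auto& hyp : beam)
      if (hyp->getWord() != trgEosId) // not finished -> moves over
        newBeam.push_back(hyp);
    pruned.push_back(std::move(newBeam));
  }
  return pruned;
}
```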
@@ -298,8 +298,8 @@ public:
// create one beam per batch entry with sentence-start hypothesis
Beams beams(origDimBatch, Beam(beamSize_, Hypothesis::New())); // array [origDimBatch] of array [maxBeamSize] of Hypothesis, keeps full size through search.
// batch purging is determined from an empty sub-beam.
- std::vector<IndexType> batchIdxMap(origDimBatch); // Record at which batch entry a beam is looking.
-                                                   // By default that corresponds to position in array,
+ std::vector<IndexType> batchIdxMap(origDimBatch); // Record at which batch entry a beam is looking.
+                                                   // By default that corresponds to position in array,
// but shifts in the course of removing batch entries when they are finished.
const std::vector<bool> emptyBatchEntries; // used for recording if there are empty input batch entries
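The down-shifting that `batchIdxMap` undergoes is spread through the search loop; one consistent way to maintain it, shown as a hypothetical standalone sketch (not the actual Marian update):

```cpp
#include <cstddef>
#include <vector>

using IndexType = unsigned int;

// Recompute [origBatchIdx -> currentBatchIdx] from a finished-mask: active
// entries claim consecutive current slots, so everything after a finished
// entry shifts down by the number of finished entries before it.
void downShift(std::vector<IndexType>& batchIdxMap,
               const std::vector<bool>& finished) { // [origDimBatch]
  IndexType current = 0;
  for (std::size_t i = 0; i < batchIdxMap.size(); ++i) {
    batchIdxMap[i] = current; // finished entries point at the next live slot
    if (!finished[i])
      ++current;              // active entries occupy it and advance
  }
}

// Example: finished = {false, true, false} yields batchIdxMap = {0, 1, 1};
// original entry 2 now lives at current position 1, and positions 1 and 2
// collide exactly as the reverse-map comment earlier anticipates.
```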
@@ -359,7 +359,7 @@ public:
std::vector<IndexType> hypIndices; // [maxBeamSize, 1, currentDimBatch, 1] (flattened) tensor index ((beamHypIdx, batchIdx), flattened) of prev hyp that a hyp originated from
std::vector<Word> prevWords; // [maxBeamSize, 1, currentDimBatch, 1] (flattened) word that a hyp ended in, for advancing the decoder-model's history
Expr prevPathScores; // [maxBeamSize, 1, currentDimBatch, 1], path score that a hyp ended in (last axis will broadcast into vocab size when adding expandedPathScores)
bool anyCanExpand = false; // stays false if all hyps are invalid factor expansions
if(t == 0 && factorGroup == 0) { // no scores yet
prevPathScores = graph->constant({1, 1, 1, 1}, inits::fromValue(0));
@@ -373,7 +373,7 @@ public:
for(int currentBatchIdx = 0; currentBatchIdx < beams.size(); ++currentBatchIdx) // loop over batch entries (active sentences)
if(!beams[currentBatchIdx].empty() || !PURGE_BATCH) // for each beam check
batchIndices.push_back(prevBatchIdxMap[currentBatchIdx]); // which batch entries were active in previous step
std::vector<float> prevScores;
for(size_t beamHypIdx = 0; beamHypIdx < maxBeamSize; ++beamHypIdx) { // loop over globally maximal beam-size (maxBeamSize)
for(int origBatchIdx = 0; origBatchIdx < origDimBatch; ++origBatchIdx) { // loop over all batch entries (active and inactive)
@@ -390,11 +390,11 @@ public:
if(factorGroup == 0)
currentBatchIdx = prevBatchIdxMap[origBatchIdx]; // subselection may happen for factorGroup == 0
else
- currentBatchIdx = batchIdxMap[origBatchIdx]; // no subselection happens for factorGroup > 0,
-                                              // but we treat it like a next step, since a step
+ currentBatchIdx = batchIdxMap[origBatchIdx]; // no subselection happens for factorGroup > 0,
+                                              // but we treat it like a next step, since a step
// happened for factorGroup == 0
}
auto hypIndex = (IndexType)(hyp->getPrevStateIndex() * currentDimBatch + currentBatchIdx); // (beamHypIdx, batchIdx), flattened, for index_select() operation
hypIndices.push_back(hypIndex); // (beamHypIdx, batchIdx), flattened as said above.
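`hypIndices` later drives an `index_select()` that re-gathers the decoder state rows of the surviving hypotheses. A toy CPU version of that gather, under the assumed flattened `[maxBeamSize * currentDimBatch, dim]` layout (not the real tensor op):

```cpp
#include <cstddef>
#include <vector>

// Gather one 'dim'-wide row per flattened (beamHypIdx, batchIdx) index.
std::vector<float> gatherRows(const std::vector<float>& states, // [rows * dim]
                              const std::vector<unsigned>& hypIndices,
                              std::size_t dim) {
  std::vector<float> out;
  out.reserve(hypIndices.size() * dim);
  for (unsigned idx : hypIndices)
    out.insert(out.end(),
               states.begin() + idx * dim,
               states.begin() + (idx + 1) * dim);
  return out;
}
```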
@@ -409,7 +409,7 @@ public:
}
}
}
- if(factorGroup == 0)
+ if(factorGroup == 0)
currentDimBatch = (IndexType) batchIndices.size(); // keep batch size constant for all factor groups in a time step
prevPathScores = graph->constant({(int)maxBeamSize, 1, (int)currentDimBatch, 1}, inits::fromVector(prevScores));
}
@@ -494,7 +494,7 @@ public:
beams,
states, // used for keeping track of per-ensemble-member path score
batch, // only used for propagating alignment info
- factoredVocab, factorGroup,
+ factoredVocab, factorGroup,
emptyBatchEntries, // [origDimBatch] - empty source batch entries are marked with true
batchIdxMap); // used to create a reverse batch index map to recover original batch indices for this step
} // END FOR factorGroup = 0 .. numFactorGroups-1