reverse batch beam argument order

This commit is contained in:
Hieu Hoang 2021-06-15 00:10:08 +00:00
parent dffbb47eea
commit 8c04f66474
3 changed files with 5 additions and 5 deletions

View File

@ -101,7 +101,7 @@ LSHShortlist::LSHShortlist(int k, int nbits)
//#define BLAS_FOUND 1
WordIndex LSHShortlist::reverseMap(int batchIdx, int beamIdx, int idx) const {
WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
//int currBeamSize = indicesExpr_->shape()[0];
int currBatchSize = indicesExpr_->shape()[1];
idx = (k_ * currBatchSize * beamIdx) + (k_ * batchIdx) + idx;

View File

@ -42,7 +42,7 @@ public:
Shortlist(const std::vector<WordIndex>& indices);
virtual ~Shortlist();
virtual WordIndex reverseMap(int batchIdx, int beamIdx, int idx) const;
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const;
virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
@ -81,7 +81,7 @@ private:
public:
LSHShortlist(int k, int nbits);
virtual WordIndex reverseMap(int batchIdx, int beamIdx, int idx) const override;
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override;
virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const override;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;

View File

@ -101,7 +101,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
// For factored decoding, the word is built over multiple decoding steps,
// starting with the lemma, then adding factors one by one.
if (factorGroup == 0) {
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) currentBatchIdx, (int) prevBeamHypIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
std::vector<size_t> factorIndices; factoredVocab->word2factors(word, factorIndices);
//LOG(info, "{} + {} ({}) -> {} -> {}",
// factoredVocab->decode(prevHyp->tracebackWords()),
@ -122,7 +122,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
}
}
else if (shortlist)
word = Word::fromWordIndex(shortlist->reverseMap((int) origBatchIdx, (int) prevBeamHypIdx, wordIdx));
word = Word::fromWordIndex(shortlist->reverseMap((int) prevBeamHypIdx, (int) origBatchIdx, wordIdx));
else
word = Word::fromWordIndex(wordIdx);