From 8c04f6647422ce66b9ee7e9c6ff7b1c03c5f0a80 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Tue, 15 Jun 2021 00:10:08 +0000 Subject: [PATCH] reverse batch beam argument order --- src/data/shortlist.cpp | 2 +- src/data/shortlist.h | 4 ++-- src/translator/beam_search.cpp | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/data/shortlist.cpp b/src/data/shortlist.cpp index 3eaa15f7..8bc25178 100644 --- a/src/data/shortlist.cpp +++ b/src/data/shortlist.cpp @@ -101,7 +101,7 @@ LSHShortlist::LSHShortlist(int k, int nbits) //#define BLAS_FOUND 1 -WordIndex LSHShortlist::reverseMap(int batchIdx, int beamIdx, int idx) const { +WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const { //int currBeamSize = indicesExpr_->shape()[0]; int currBatchSize = indicesExpr_->shape()[1]; idx = (k_ * currBatchSize * beamIdx) + (k_ * batchIdx) + idx; diff --git a/src/data/shortlist.h b/src/data/shortlist.h index 0ed6eae3..2b8953bd 100644 --- a/src/data/shortlist.h +++ b/src/data/shortlist.h @@ -42,7 +42,7 @@ public: Shortlist(const std::vector& indices); virtual ~Shortlist(); - virtual WordIndex reverseMap(int batchIdx, int beamIdx, int idx) const; + virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const; virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const; virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt); @@ -81,7 +81,7 @@ private: public: LSHShortlist(int k, int nbits); - virtual WordIndex reverseMap(int batchIdx, int beamIdx, int idx) const override; + virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override; virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const override; virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override; diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp index b1584007..eda288a4 100644 --- a/src/translator/beam_search.cpp +++ b/src/translator/beam_search.cpp @@ -101,7 +101,7 @@ Beams BeamSearch::toHyps(const std::vector& nBestKeys, // [current // For factored decoding, the word is built over multiple decoding steps, // starting with the lemma, then adding factors one by one. if (factorGroup == 0) { - word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) currentBatchIdx, (int) prevBeamHypIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0 + word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0 std::vector factorIndices; factoredVocab->word2factors(word, factorIndices); //LOG(info, "{} + {} ({}) -> {} -> {}", // factoredVocab->decode(prevHyp->tracebackWords()), @@ -122,7 +122,7 @@ Beams BeamSearch::toHyps(const std::vector& nBestKeys, // [current } } else if (shortlist) - word = Word::fromWordIndex(shortlist->reverseMap((int) origBatchIdx, (int) prevBeamHypIdx, wordIdx)); + word = Word::fromWordIndex(shortlist->reverseMap((int) prevBeamHypIdx, (int) origBatchIdx, wordIdx)); else word = Word::fromWordIndex(wordIdx);