mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-05 01:31:46 +03:00
reverse batch beam argument order
This commit is contained in:
parent
dffbb47eea
commit
8c04f66474
@ -101,7 +101,7 @@ LSHShortlist::LSHShortlist(int k, int nbits)
|
||||
|
||||
//#define BLAS_FOUND 1
|
||||
|
||||
WordIndex LSHShortlist::reverseMap(int batchIdx, int beamIdx, int idx) const {
|
||||
WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
|
||||
//int currBeamSize = indicesExpr_->shape()[0];
|
||||
int currBatchSize = indicesExpr_->shape()[1];
|
||||
idx = (k_ * currBatchSize * beamIdx) + (k_ * batchIdx) + idx;
|
||||
|
@ -42,7 +42,7 @@ public:
|
||||
Shortlist(const std::vector<WordIndex>& indices);
|
||||
virtual ~Shortlist();
|
||||
|
||||
virtual WordIndex reverseMap(int batchIdx, int beamIdx, int idx) const;
|
||||
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const;
|
||||
virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const;
|
||||
|
||||
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
|
||||
@ -81,7 +81,7 @@ private:
|
||||
|
||||
public:
|
||||
LSHShortlist(int k, int nbits);
|
||||
virtual WordIndex reverseMap(int batchIdx, int beamIdx, int idx) const override;
|
||||
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override;
|
||||
virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const override;
|
||||
|
||||
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
|
||||
|
@ -101,7 +101,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
|
||||
// For factored decoding, the word is built over multiple decoding steps,
|
||||
// starting with the lemma, then adding factors one by one.
|
||||
if (factorGroup == 0) {
|
||||
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) currentBatchIdx, (int) prevBeamHypIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
|
||||
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) prevBeamHypIdx, (int) currentBatchIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
|
||||
std::vector<size_t> factorIndices; factoredVocab->word2factors(word, factorIndices);
|
||||
//LOG(info, "{} + {} ({}) -> {} -> {}",
|
||||
// factoredVocab->decode(prevHyp->tracebackWords()),
|
||||
@ -122,7 +122,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
|
||||
}
|
||||
}
|
||||
else if (shortlist)
|
||||
word = Word::fromWordIndex(shortlist->reverseMap((int) origBatchIdx, (int) prevBeamHypIdx, wordIdx));
|
||||
word = Word::fromWordIndex(shortlist->reverseMap((int) prevBeamHypIdx, (int) origBatchIdx, wordIdx));
|
||||
else
|
||||
word = Word::fromWordIndex(wordIdx);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user