batch idx nearly there

This commit is contained in:
Hieu Hoang 2021-04-29 13:43:31 -07:00
parent 1a3e5ab58e
commit daf853e7aa
3 changed files with 23 additions and 13 deletions

View File

@ -22,9 +22,9 @@ Shortlist::Shortlist(const std::vector<WordIndex>& indices)
: indices_(indices) {}
const std::vector<WordIndex>& Shortlist::indices() const { return indices_; }
WordIndex Shortlist::reverseMap(size_t beamIdx, int idx) const { return indices_[idx]; }
WordIndex Shortlist::reverseMap(size_t batchIdx, size_t beamIdx, int idx) const { return indices_[idx]; }
WordIndex Shortlist::tryForwardMap(size_t beamIdx, WordIndex wIdx) const {
WordIndex Shortlist::tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const {
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
return (int)std::distance(indices_.begin(), first); // return coordinate if found
@ -139,13 +139,18 @@ LSHShortlist::LSHShortlist(int k, int nbits)
#define BLAS_FOUND 1
WordIndex LSHShortlist::reverseMap(size_t beamIdx, int idx) const {
idx = k_ * beamIdx + idx;
WordIndex LSHShortlist::reverseMap(size_t batchIdx, size_t beamIdx, int idx) const {
std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
int currBeamSize = indicesExpr_->shape()[1];
std::cerr << "currBeamSize=" << currBeamSize << std::endl;
std::cerr << "indices_=" << indices_.size() << std::endl;
idx = (k_ * currBeamSize) * batchIdx + k_ * beamIdx + idx;
std::cerr << "idx=" << idx << std::endl;
assert(idx < indices_.size());
return indices_[idx];
}
WordIndex LSHShortlist::tryForwardMap(size_t beamIdx, WordIndex wIdx) const {
WordIndex LSHShortlist::tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const {
//utils::Debug(indices_, "LSHShortlist::tryForwardMap indices_");
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
bool found = first != indices_.end();

View File

@ -42,8 +42,8 @@ public:
Shortlist(const std::vector<WordIndex>& indices);
virtual WordIndex reverseMap(size_t beamIdx, int idx) const;
virtual WordIndex tryForwardMap(size_t beamIdx, WordIndex wIdx) const;
virtual WordIndex reverseMap(size_t batchIdx, size_t beamIdx, int idx) const;
virtual WordIndex tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
virtual Expr getIndicesExpr(int batchSize, int currBeamSize) const;
@ -75,8 +75,8 @@ private:
public:
LSHShortlist(int k, int nbits);
virtual WordIndex reverseMap(size_t beamIdx, int idx) const override;
virtual WordIndex tryForwardMap(size_t beamIdx, WordIndex wIdx) const override;
virtual WordIndex reverseMap(size_t batchIdx, size_t beamIdx, int idx) const override;
virtual WordIndex tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const override;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
virtual Expr getIndicesExpr(int batchSize,int currBeamSize) const override;

View File

@ -50,7 +50,6 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
const auto beamHypIdx = (key / vocabSize) % nBestBeamSize;
const auto currentBatchIdx = (key / vocabSize) / nBestBeamSize;
const auto origBatchIdx = reverseBatchIdxMap.empty() ? currentBatchIdx : reverseBatchIdxMap[currentBatchIdx]; // map currentBatchIdx back into original position within starting maximal batch size, required to find correct beam
bool dropHyp = !dropBatchEntries.empty() && dropBatchEntries[origBatchIdx] && factorGroup == 0;
WordIndex wordIdx;
@ -85,6 +84,12 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
// map wordIdx to word
auto prevBeamHypIdx = beamHypIdx; // back pointer
std::cerr << "currentBatchIdx=" << currentBatchIdx
<< " origBatchIdx=" << origBatchIdx
<< " beamHypIdx=" << beamHypIdx
<< " prevBeamHypIdx=" << prevBeamHypIdx
<< std::endl;
auto prevHyp = beam[prevBeamHypIdx];
Word word;
// If short list has been set, then wordIdx is an index into the short-listed word set,
@ -94,7 +99,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
// For factored decoding, the word is built over multiple decoding steps,
// starting with the lemma, then adding factors one by one.
if (factorGroup == 0) {
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap(prevBeamHypIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap(currentBatchIdx, prevBeamHypIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
std::vector<size_t> factorIndices; factoredVocab->word2factors(word, factorIndices);
//LOG(info, "{} + {} ({}) -> {} -> {}",
// factoredVocab->decode(prevHyp->tracebackWords()),
@ -115,7 +120,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
}
}
else if (shortlist)
word = Word::fromWordIndex(shortlist->reverseMap(prevBeamHypIdx, wordIdx));
word = Word::fromWordIndex(shortlist->reverseMap(currentBatchIdx, prevBeamHypIdx, wordIdx));
else
word = Word::fromWordIndex(wordIdx);
@ -308,7 +313,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
suppressed.erase(std::remove_if(suppressed.begin(),
suppressed.end(),
[&](WordIndex i) {
return shortlist->tryForwardMap(3343, i) == data::Shortlist::npos; // TODO beamIdx
return shortlist->tryForwardMap(4545, 3343, i) == data::Shortlist::npos; // TODO beamIdx
}),
suppressed.end());