mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
batch idx nearly there
This commit is contained in:
parent
1a3e5ab58e
commit
daf853e7aa
@ -22,9 +22,9 @@ Shortlist::Shortlist(const std::vector<WordIndex>& indices)
|
||||
: indices_(indices) {}
|
||||
|
||||
const std::vector<WordIndex>& Shortlist::indices() const { return indices_; }
|
||||
WordIndex Shortlist::reverseMap(size_t beamIdx, int idx) const { return indices_[idx]; }
|
||||
WordIndex Shortlist::reverseMap(size_t batchIdx, size_t beamIdx, int idx) const { return indices_[idx]; }
|
||||
|
||||
WordIndex Shortlist::tryForwardMap(size_t beamIdx, WordIndex wIdx) const {
|
||||
WordIndex Shortlist::tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const {
|
||||
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
|
||||
if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
|
||||
return (int)std::distance(indices_.begin(), first); // return coordinate if found
|
||||
@ -139,13 +139,18 @@ LSHShortlist::LSHShortlist(int k, int nbits)
|
||||
|
||||
#define BLAS_FOUND 1
|
||||
|
||||
WordIndex LSHShortlist::reverseMap(size_t beamIdx, int idx) const {
|
||||
idx = k_ * beamIdx + idx;
|
||||
WordIndex LSHShortlist::reverseMap(size_t batchIdx, size_t beamIdx, int idx) const {
|
||||
std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
|
||||
int currBeamSize = indicesExpr_->shape()[1];
|
||||
std::cerr << "currBeamSize=" << currBeamSize << std::endl;
|
||||
std::cerr << "indices_=" << indices_.size() << std::endl;
|
||||
idx = (k_ * currBeamSize) * batchIdx + k_ * beamIdx + idx;
|
||||
std::cerr << "idx=" << idx << std::endl;
|
||||
assert(idx < indices_.size());
|
||||
return indices_[idx];
|
||||
}
|
||||
|
||||
WordIndex LSHShortlist::tryForwardMap(size_t beamIdx, WordIndex wIdx) const {
|
||||
WordIndex LSHShortlist::tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const {
|
||||
//utils::Debug(indices_, "LSHShortlist::tryForwardMap indices_");
|
||||
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
|
||||
bool found = first != indices_.end();
|
||||
|
@ -42,8 +42,8 @@ public:
|
||||
|
||||
Shortlist(const std::vector<WordIndex>& indices);
|
||||
|
||||
virtual WordIndex reverseMap(size_t beamIdx, int idx) const;
|
||||
virtual WordIndex tryForwardMap(size_t beamIdx, WordIndex wIdx) const;
|
||||
virtual WordIndex reverseMap(size_t batchIdx, size_t beamIdx, int idx) const;
|
||||
virtual WordIndex tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const;
|
||||
|
||||
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
|
||||
virtual Expr getIndicesExpr(int batchSize, int currBeamSize) const;
|
||||
@ -75,8 +75,8 @@ private:
|
||||
|
||||
public:
|
||||
LSHShortlist(int k, int nbits);
|
||||
virtual WordIndex reverseMap(size_t beamIdx, int idx) const override;
|
||||
virtual WordIndex tryForwardMap(size_t beamIdx, WordIndex wIdx) const override;
|
||||
virtual WordIndex reverseMap(size_t batchIdx, size_t beamIdx, int idx) const override;
|
||||
virtual WordIndex tryForwardMap(size_t batchIdx, size_t beamIdx, WordIndex wIdx) const override;
|
||||
|
||||
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
|
||||
virtual Expr getIndicesExpr(int batchSize,int currBeamSize) const override;
|
||||
|
@ -50,7 +50,6 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
|
||||
const auto beamHypIdx = (key / vocabSize) % nBestBeamSize;
|
||||
const auto currentBatchIdx = (key / vocabSize) / nBestBeamSize;
|
||||
const auto origBatchIdx = reverseBatchIdxMap.empty() ? currentBatchIdx : reverseBatchIdxMap[currentBatchIdx]; // map currentBatchIdx back into original position within starting maximal batch size, required to find correct beam
|
||||
|
||||
bool dropHyp = !dropBatchEntries.empty() && dropBatchEntries[origBatchIdx] && factorGroup == 0;
|
||||
|
||||
WordIndex wordIdx;
|
||||
@ -85,6 +84,12 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
|
||||
|
||||
// map wordIdx to word
|
||||
auto prevBeamHypIdx = beamHypIdx; // back pointer
|
||||
std::cerr << "currentBatchIdx=" << currentBatchIdx
|
||||
<< " origBatchIdx=" << origBatchIdx
|
||||
<< " beamHypIdx=" << beamHypIdx
|
||||
<< " prevBeamHypIdx=" << prevBeamHypIdx
|
||||
<< std::endl;
|
||||
|
||||
auto prevHyp = beam[prevBeamHypIdx];
|
||||
Word word;
|
||||
// If short list has been set, then wordIdx is an index into the short-listed word set,
|
||||
@ -94,7 +99,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
|
||||
// For factored decoding, the word is built over multiple decoding steps,
|
||||
// starting with the lemma, then adding factors one by one.
|
||||
if (factorGroup == 0) {
|
||||
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap(prevBeamHypIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
|
||||
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap(currentBatchIdx, prevBeamHypIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
|
||||
std::vector<size_t> factorIndices; factoredVocab->word2factors(word, factorIndices);
|
||||
//LOG(info, "{} + {} ({}) -> {} -> {}",
|
||||
// factoredVocab->decode(prevHyp->tracebackWords()),
|
||||
@ -115,7 +120,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
|
||||
}
|
||||
}
|
||||
else if (shortlist)
|
||||
word = Word::fromWordIndex(shortlist->reverseMap(prevBeamHypIdx, wordIdx));
|
||||
word = Word::fromWordIndex(shortlist->reverseMap(currentBatchIdx, prevBeamHypIdx, wordIdx));
|
||||
else
|
||||
word = Word::fromWordIndex(wordIdx);
|
||||
|
||||
@ -308,7 +313,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
|
||||
suppressed.erase(std::remove_if(suppressed.begin(),
|
||||
suppressed.end(),
|
||||
[&](WordIndex i) {
|
||||
return shortlist->tryForwardMap(3343, i) == data::Shortlist::npos; // TODO beamIdx
|
||||
return shortlist->tryForwardMap(4545, 3343, i) == data::Shortlist::npos; // TODO beamIdx
|
||||
}),
|
||||
suppressed.end());
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user