start lsh

This commit is contained in:
Hieu Hoang 2021-04-29 12:22:45 -07:00
parent e518fc9666
commit 1784da0585
4 changed files with 165 additions and 16 deletions

View File

@ -2,6 +2,10 @@
#include "microsoft/shortlist/utils/ParameterTree.h"
#include "marian.h"
#if BLAS_FOUND
#include "3rd_party/faiss/IndexLSH.h"
#endif
namespace marian {
namespace data {
@ -18,9 +22,9 @@ Shortlist::Shortlist(const std::vector<WordIndex>& indices)
: indices_(indices) {}
const std::vector<WordIndex>& Shortlist::indices() const { return indices_; }
WordIndex Shortlist::reverseMap(int idx) { return indices_[idx]; }
WordIndex Shortlist::reverseMap(size_t beamIdx, int idx) const { return indices_[idx]; }
WordIndex Shortlist::tryForwardMap(WordIndex wIdx) {
WordIndex Shortlist::tryForwardMap(size_t beamIdx, WordIndex wIdx) const {
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
return (int)std::distance(indices_.begin(), first); // return coordinate if found
@ -117,6 +121,110 @@ void Shortlist::broadcast(Expr weights,
cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {2, 0, 1, 3});
//std::cerr << "cachedShortLemmaEt.3_=" << cachedShortLemmaEt_->shape() << std::endl;
}
///////////////////////////////////////////////////////////////////////////////////
Ptr<faiss::IndexLSH> LSHShortlist::index_;
LSHShortlist::LSHShortlist(int k, int nbits)
: Shortlist(std::vector<WordIndex>())
, k_(k), nbits_(nbits) {
//std::cerr << "LSHShortlist" << std::endl;
/*
for (int i = 0; i < k_; ++i) {
indices_.push_back(i);
}
*/
}
#define BLAS_FOUND 1
WordIndex LSHShortlist::reverseMap(size_t beamIdx, int idx) const {
idx = k_ * beamIdx + idx;
assert(idx < indices_.size());
return indices_[idx];
}
WordIndex LSHShortlist::tryForwardMap(size_t beamIdx, WordIndex wIdx) const {
//utils::Debug(indices_, "LSHShortlist::tryForwardMap indices_");
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
bool found = first != indices_.end();
if(found && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
return (int)std::distance(indices_.begin(), first); // return coordinate if found
else
return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17?
}
Expr LSHShortlist::getIndicesExpr(int batchSize, int currBeamSize) const {
assert(indicesExpr_->shape()[0] == currBeamSize);
return indicesExpr_;
}
void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
#if BLAS_FOUND
int currBeamSize = input->shape()[0];
ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
"LSH index (--output-approx-knn) currently not implemented for GPU");
auto forward = [this, currBeamSize](Expr out, const std::vector<Expr>& inputs) {
auto query = inputs[0];
auto values = inputs[1];
int dim = values->shape()[-1];
if(!index_) {
//std::cerr << "build lsh index" << std::endl;
LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_);
index_.reset(new faiss::IndexLSH(dim, nbits_,
/*rotate=*/dim != nbits_,
/*train_thesholds*/false));
int vRows = 32121; //47960; //values->shape().elements() / dim;
index_->train(vRows, values->val()->data<float>());
index_->add( vRows, values->val()->data<float>());
}
int qRows = query->shape().elements() / dim;
std::vector<float> distances(qRows * k_);
std::vector<faiss::Index::idx_t> ids(qRows * k_);
index_->search(qRows, query->val()->data<float>(), k_,
distances.data(), ids.data());
indices_.clear();
for(auto id : ids) {
indices_.push_back(id);
}
for (size_t beamIdx = 0; beamIdx < currBeamSize; ++beamIdx) {
size_t startIdx = k_ * beamIdx;
size_t endIdx = startIdx + k_;
std::sort(indices_.begin() + startIdx, indices_.begin() + endIdx);
}
out->val()->set(indices_);
//std::cerr << "out=" << out->shape() << " " << out->val() << std::endl;
};
Shape kShape({currBeamSize, k_});
//std::cerr << "kShape=" << kShape << std::endl;
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
broadcast(weights, isLegacyUntransposedW, b, lemmaEt, indicesExpr_, k_);
#else
query; values;
ABORT("LSH output layer requires a CPU BLAS library");
#endif
}
LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits)
: k_(k), nbits_(nbits) {
//std::cerr << "LSHShortlistGenerator" << std::endl;
}
Ptr<Shortlist> LSHShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
return New<LSHShortlist>(k_, nbits_);
}
//////////////////////////////////////////////////////////////////////////////////////
QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options,
Ptr<const Vocab> srcVocab,
@ -242,16 +350,22 @@ Ptr<Shortlist> QuicksandShortlistGenerator::generate(Ptr<data::CorpusBatch> batc
Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
Ptr<const Vocab> srcVocab,
Ptr<const Vocab> trgVocab,
const std::vector<int> &lshOpts,
size_t srcIdx,
size_t trgIdx,
bool shared) {
std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist");
ABORT_IF(vals.empty(), "No path to shortlist given");
std::string fname = vals[0];
if(filesystem::Path(fname).extension().string() == ".bin") {
return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
} else {
return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
if (lshOpts.size() == 2) {
return New<LSHShortlistGenerator>(lshOpts[0], lshOpts[1]);
}
else {
std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist");
ABORT_IF(vals.empty(), "No path to shortlist given");
std::string fname = vals[0];
if(filesystem::Path(fname).extension().string() == ".bin") {
return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
} else {
return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
}
}
}

View File

@ -15,6 +15,10 @@
#include <algorithm>
#include <limits>
namespace faiss {
struct IndexLSH;
}
namespace marian {
namespace data {
@ -38,8 +42,8 @@ public:
Shortlist(const std::vector<WordIndex>& indices);
const std::vector<WordIndex>& indices() const;
WordIndex reverseMap(int idx);
WordIndex tryForwardMap(WordIndex wIdx);
virtual WordIndex reverseMap(size_t beamIdx, int idx) const;
virtual WordIndex tryForwardMap(size_t beamIdx, WordIndex wIdx) const;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
virtual Expr getIndicesExpr(int batchSize, int currBeamSize) const;
@ -61,6 +65,35 @@ public:
}
};
///////////////////////////////////////////////////////////////////////////////////
class LSHShortlist: public Shortlist {
private:
int k_;
int nbits_;
static Ptr<faiss::IndexLSH> index_;
public:
LSHShortlist(int k, int nbits);
virtual WordIndex reverseMap(size_t beamIdx, int idx) const override;
virtual WordIndex tryForwardMap(size_t beamIdx, WordIndex wIdx) const override;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
virtual Expr getIndicesExpr(int batchSize,int currBeamSize) const override;
};
class LSHShortlistGenerator : public ShortlistGenerator {
private:
int k_;
int nbits_;
public:
LSHShortlistGenerator(int k, int nbits);
Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
};
///////////////////////////////////////////////////////////////////////////////////
// Intended for use during training in the future, currently disabled
#if 0
@ -345,6 +378,7 @@ unless the extension is *.bin for which the Microsoft legacy binary shortlist is
Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
Ptr<const Vocab> srcVocab,
Ptr<const Vocab> trgVocab,
const std::vector<int> &lshOpts,
size_t srcIdx = 0,
size_t trgIdx = 1,
bool shared = false);

View File

@ -94,7 +94,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
// For factored decoding, the word is built over multiple decoding steps,
// starting with the lemma, then adding factors one by one.
if (factorGroup == 0) {
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap(wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap(prevBeamHypIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
std::vector<size_t> factorIndices; factoredVocab->word2factors(word, factorIndices);
//LOG(info, "{} + {} ({}) -> {} -> {}",
// factoredVocab->decode(prevHyp->tracebackWords()),
@ -115,7 +115,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
}
}
else if (shortlist)
word = Word::fromWordIndex(shortlist->reverseMap(wordIdx));
word = Word::fromWordIndex(shortlist->reverseMap(prevBeamHypIdx, wordIdx));
else
word = Word::fromWordIndex(wordIdx);
@ -308,7 +308,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
suppressed.erase(std::remove_if(suppressed.begin(),
suppressed.end(),
[&](WordIndex i) {
return shortlist->tryForwardMap(i) == data::Shortlist::npos;
return shortlist->tryForwardMap(3343, i) == data::Shortlist::npos; // TODO beamIdx
}),
suppressed.end());

View File

@ -62,8 +62,9 @@ public:
trgVocab_->load(vocabs.back());
auto srcVocab = corpus_->getVocabs()[0];
if(options_->hasAndNotEmpty("shortlist"))
shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn");
if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist"))
shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabs.front() == vocabs.back());
auto devices = Config::getDevices(options_);
numDevices_ = devices.size();