mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
start lsh
This commit is contained in:
parent
e518fc9666
commit
1784da0585
@ -2,6 +2,10 @@
|
||||
#include "microsoft/shortlist/utils/ParameterTree.h"
|
||||
#include "marian.h"
|
||||
|
||||
#if BLAS_FOUND
|
||||
#include "3rd_party/faiss/IndexLSH.h"
|
||||
#endif
|
||||
|
||||
namespace marian {
|
||||
namespace data {
|
||||
|
||||
@ -18,9 +22,9 @@ Shortlist::Shortlist(const std::vector<WordIndex>& indices)
|
||||
: indices_(indices) {}
|
||||
|
||||
// Read-only access to the stored shortlist entries (indices_).
const std::vector<WordIndex>& Shortlist::indices() const { return indices_; }
|
||||
WordIndex Shortlist::reverseMap(int idx) { return indices_[idx]; }
|
||||
// Returns the WordIndex stored at position idx of indices_. beamIdx is not used by this base implementation.
WordIndex Shortlist::reverseMap(size_t beamIdx, int idx) const { return indices_[idx]; }
|
||||
|
||||
WordIndex Shortlist::tryForwardMap(WordIndex wIdx) {
|
||||
WordIndex Shortlist::tryForwardMap(size_t beamIdx, WordIndex wIdx) const {
|
||||
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
|
||||
if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
|
||||
return (int)std::distance(indices_.begin(), first); // return coordinate if found
|
||||
@ -117,6 +121,110 @@ void Shortlist::broadcast(Expr weights,
|
||||
cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {2, 0, 1, 3});
|
||||
//std::cerr << "cachedShortLemmaEt.3_=" << cachedShortLemmaEt_->shape() << std::endl;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
// Definition of the static index_ member: starts out empty and is created on first use in
// LSHShortlist::filter(). Being static, it is shared by every LSHShortlist instance.
Ptr<faiss::IndexLSH> LSHShortlist::index_;
|
||||
|
||||
LSHShortlist::LSHShortlist(int k, int nbits)
|
||||
: Shortlist(std::vector<WordIndex>())
|
||||
, k_(k), nbits_(nbits) {
|
||||
//std::cerr << "LSHShortlist" << std::endl;
|
||||
/*
|
||||
for (int i = 0; i < k_; ++i) {
|
||||
indices_.push_back(i);
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
// BLAS_FOUND is expected to come from the build configuration. It must not be forced
// to 1 here: the faiss include at the top of this file is guarded by the same macro,
// so overriding it would enable the LSH code path without its declarations.
// Only provide a fallback so that `#if BLAS_FOUND` is well-defined when it is absent.
#ifndef BLAS_FOUND
#define BLAS_FOUND 0
#endif
|
||||
|
||||
/**
 * Maps a position inside one beam's block of the shortlist back to the stored WordIndex.
 *
 * indices_ is laid out as consecutive blocks of k_ entries, one block per beam, so the
 * flat position is k_ * beamIdx + idx.
 *
 * @param beamIdx which block of k_ entries to read from
 * @param idx     position inside that block (must be non-negative)
 * @return the entry of indices_ at the computed flat position
 */
WordIndex LSHShortlist::reverseMap(size_t beamIdx, int idx) const {
  assert(idx >= 0);
  // Do the arithmetic in size_t throughout instead of narrowing back into an int,
  // which also keeps the bounds check free of signed/unsigned comparisons.
  const size_t flatIdx = static_cast<size_t>(k_) * beamIdx + static_cast<size_t>(idx);
  assert(flatIdx < indices_.size());
  return indices_[flatIdx];
}
|
||||
|
||||
/**
 * Looks up the position of wIdx in the current shortlist.
 *
 * filter() sorts indices_ only within each beam's block of k_ entries, so the vector
 * as a whole is not guaranteed to be sorted once there is more than one beam. A binary
 * search over the full range would then violate its precondition, hence a linear scan.
 * When the vector happens to be globally sorted (e.g. a single beam) the result is the
 * same as a lower_bound lookup: the first position holding wIdx.
 *
 * @param beamIdx currently not used for the lookup
 *                NOTE(review): callers do not yet pass a meaningful beam index; once
 *                they do, the search could be restricted to that beam's block.
 * @param wIdx    word index to look for
 * @return position of the first occurrence of wIdx in indices_, or npos if absent
 *         @TODO: replace with std::optional once we switch to C++17?
 */
WordIndex LSHShortlist::tryForwardMap(size_t beamIdx, WordIndex wIdx) const {
  (void)beamIdx;
  //utils::Debug(indices_, "LSHShortlist::tryForwardMap indices_");
  const auto hit = std::find(indices_.begin(), indices_.end(), wIdx);
  if(hit == indices_.end())
    return npos;
  return static_cast<WordIndex>(std::distance(indices_.begin(), hit));
}
|
||||
|
||||
/**
 * Returns the expression node holding the shortlist indices created by filter().
 *
 * @param batchSize    not used by this implementation
 * @param currBeamSize expected size of the first dimension of the indices expression;
 *                     only checked in builds where assert is active
 */
Expr LSHShortlist::getIndicesExpr(int batchSize, int currBeamSize) const {
  // Sanity check only: the node was created with shape {currBeamSize, k_} in filter().
  assert(indicesExpr_->shape()[0] == currBeamSize);
  return indicesExpr_;
}
|
||||
|
||||
/**
 * Sets up the LSH-based shortlist for the current decoding step.
 *
 * Creates a lambda node (indicesExpr_) of shape {currBeamSize, k_} and type uint32. When
 * that node is evaluated it queries the shared faiss LSH index with `input` and stores
 * the k_ returned ids per query row in indices_, sorted within each beam's block of k_.
 * The node is then handed to broadcast() together with the weights, bias and lemmaEt.
 *
 * The index itself is static: it is built from `weights` the first time the lambda runs
 * and reused afterwards.
 * NOTE(review): no synchronization around the lazy construction of index_ is visible
 * here; confirm this is never evaluated concurrently from several threads.
 *
 * @param input                 query vectors; the first dimension is taken as the beam size
 * @param weights               matrix whose rows are added to the LSH index; the last
 *                              dimension is taken as the vector dimension
 * @param isLegacyUntransposedW forwarded to broadcast()
 * @param b                     forwarded to broadcast()
 * @param lemmaEt               forwarded to broadcast()
 *
 * Aborts when the graph lives on a GPU device, or when built without BLAS support.
 */
void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
#if BLAS_FOUND
  int currBeamSize = input->shape()[0];
  ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
           "LSH index (--output-approx-knn) currently not implemented for GPU");

  // The lambda captures `this`: the shortlist object must stay alive until the node has run.
  auto forward = [this, currBeamSize](Expr out, const std::vector<Expr>& inputs) {
    auto query  = inputs[0];
    auto values = inputs[1];
    int dim = values->shape()[-1];

    if(!index_) {
      LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_);
      index_.reset(new faiss::IndexLSH(dim, nbits_,
                                       /*rotate=*/dim != nbits_,
                                       /*train_thresholds*/false));
      // Derive the number of rows from the tensor itself instead of a model-specific
      // constant, so that train()/add() never read past the end of the weights buffer.
      // NOTE(review): this indexes every row of `weights`; if only a leading subset of
      // rows should be searchable (e.g. lemmas of a factored vocabulary), that count
      // has to be passed in explicitly.
      int vRows = static_cast<int>(values->shape().elements() / dim);
      index_->train(vRows, values->val()->data<float>());
      index_->add(vRows, values->val()->data<float>());
    }

    int qRows = query->shape().elements() / dim;
    std::vector<float> distances(qRows * k_);
    std::vector<faiss::Index::idx_t> ids(qRows * k_);

    index_->search(qRows, query->val()->data<float>(), k_,
                   distances.data(), ids.data());

    // Copy the ids returned by faiss into indices_, narrowing to WordIndex explicitly.
    indices_.clear();
    indices_.reserve(ids.size());
    for(auto id : ids)
      indices_.push_back(static_cast<WordIndex>(id));

    // Sort each beam's block of k_ entries independently; blocks are not merged.
    const size_t blockSize = static_cast<size_t>(k_);
    const size_t numBeams  = static_cast<size_t>(currBeamSize);
    for(size_t beamIdx = 0; beamIdx < numBeams; ++beamIdx) {
      auto blockBegin = indices_.begin() + blockSize * beamIdx;
      std::sort(blockBegin, blockBegin + blockSize);
    }

    out->val()->set(indices_);
  };

  Shape kShape({currBeamSize, k_});
  indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);

  broadcast(weights, isLegacyUntransposedW, b, lemmaEt, indicesExpr_, k_);

#else
  // Reference the actual parameters of this function so the BLAS-less build compiles
  // without unused-parameter diagnostics.
  (void)input; (void)weights; (void)isLegacyUntransposedW; (void)b; (void)lemmaEt;
  ABORT("LSH output layer requires a CPU BLAS library");
#endif
}
|
||||
|
||||
/**
 * Stores the two LSH settings that are handed to every LSHShortlist this generator creates.
 *
 * @param k     passed on as the `k` argument of LSHShortlist
 * @param nbits passed on as the `nbits` argument of LSHShortlist
 */
LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits)
    : k_(k),
      nbits_(nbits) {}
|
||||
|
||||
/**
 * Creates a fresh LSHShortlist configured with this generator's k_ and nbits_.
 *
 * @param batch not consulted by this implementation
 * @return a newly allocated LSHShortlist
 */
Ptr<Shortlist> LSHShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
  Ptr<Shortlist> shortlist = New<LSHShortlist>(k_, nbits_);
  return shortlist;
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////
|
||||
QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options,
|
||||
Ptr<const Vocab> srcVocab,
|
||||
@ -242,16 +350,22 @@ Ptr<Shortlist> QuicksandShortlistGenerator::generate(Ptr<data::CorpusBatch> batc
|
||||
Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
|
||||
Ptr<const Vocab> srcVocab,
|
||||
Ptr<const Vocab> trgVocab,
|
||||
const std::vector<int> &lshOpts,
|
||||
size_t srcIdx,
|
||||
size_t trgIdx,
|
||||
bool shared) {
|
||||
std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist");
|
||||
ABORT_IF(vals.empty(), "No path to shortlist given");
|
||||
std::string fname = vals[0];
|
||||
if(filesystem::Path(fname).extension().string() == ".bin") {
|
||||
return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
|
||||
} else {
|
||||
return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
|
||||
if (lshOpts.size() == 2) {
|
||||
return New<LSHShortlistGenerator>(lshOpts[0], lshOpts[1]);
|
||||
}
|
||||
else {
|
||||
std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist");
|
||||
ABORT_IF(vals.empty(), "No path to shortlist given");
|
||||
std::string fname = vals[0];
|
||||
if(filesystem::Path(fname).extension().string() == ".bin") {
|
||||
return New<QuicksandShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
|
||||
} else {
|
||||
return New<LexicalShortlistGenerator>(options, srcVocab, trgVocab, srcIdx, trgIdx, shared);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -15,6 +15,10 @@
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
|
||||
// Forward declaration so that LSHShortlist can hold a Ptr<faiss::IndexLSH>
// without this header having to include the faiss header itself.
namespace faiss { struct IndexLSH; }
|
||||
|
||||
namespace marian {
|
||||
namespace data {
|
||||
|
||||
@ -38,8 +42,8 @@ public:
|
||||
Shortlist(const std::vector<WordIndex>& indices);
|
||||
|
||||
const std::vector<WordIndex>& indices() const;
|
||||
WordIndex reverseMap(int idx);
|
||||
WordIndex tryForwardMap(WordIndex wIdx);
|
||||
virtual WordIndex reverseMap(size_t beamIdx, int idx) const;
|
||||
virtual WordIndex tryForwardMap(size_t beamIdx, WordIndex wIdx) const;
|
||||
|
||||
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
|
||||
virtual Expr getIndicesExpr(int batchSize, int currBeamSize) const;
|
||||
@ -61,6 +65,35 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
class LSHShortlist: public Shortlist {
|
||||
private:
|
||||
int k_;
|
||||
int nbits_;
|
||||
|
||||
static Ptr<faiss::IndexLSH> index_;
|
||||
|
||||
public:
|
||||
LSHShortlist(int k, int nbits);
|
||||
virtual WordIndex reverseMap(size_t beamIdx, int idx) const override;
|
||||
virtual WordIndex tryForwardMap(size_t beamIdx, WordIndex wIdx) const override;
|
||||
|
||||
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
|
||||
virtual Expr getIndicesExpr(int batchSize,int currBeamSize) const override;
|
||||
|
||||
};
|
||||
|
||||
class LSHShortlistGenerator : public ShortlistGenerator {
|
||||
private:
|
||||
int k_;
|
||||
int nbits_;
|
||||
|
||||
public:
|
||||
LSHShortlistGenerator(int k, int nbits);
|
||||
Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Intended for use during training in the future, currently disabled
|
||||
#if 0
|
||||
@ -345,6 +378,7 @@ unless the extension is *.bin for which the Microsoft legacy binary shortlist is
|
||||
Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
|
||||
Ptr<const Vocab> srcVocab,
|
||||
Ptr<const Vocab> trgVocab,
|
||||
const std::vector<int> &lshOpts,
|
||||
size_t srcIdx = 0,
|
||||
size_t trgIdx = 1,
|
||||
bool shared = false);
|
||||
|
@ -94,7 +94,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
|
||||
// For factored decoding, the word is built over multiple decoding steps,
|
||||
// starting with the lemma, then adding factors one by one.
|
||||
if (factorGroup == 0) {
|
||||
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap(wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
|
||||
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap(prevBeamHypIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
|
||||
std::vector<size_t> factorIndices; factoredVocab->word2factors(word, factorIndices);
|
||||
//LOG(info, "{} + {} ({}) -> {} -> {}",
|
||||
// factoredVocab->decode(prevHyp->tracebackWords()),
|
||||
@ -115,7 +115,7 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
|
||||
}
|
||||
}
|
||||
else if (shortlist)
|
||||
word = Word::fromWordIndex(shortlist->reverseMap(wordIdx));
|
||||
word = Word::fromWordIndex(shortlist->reverseMap(prevBeamHypIdx, wordIdx));
|
||||
else
|
||||
word = Word::fromWordIndex(wordIdx);
|
||||
|
||||
@ -308,7 +308,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
|
||||
suppressed.erase(std::remove_if(suppressed.begin(),
|
||||
suppressed.end(),
|
||||
[&](WordIndex i) {
|
||||
return shortlist->tryForwardMap(i) == data::Shortlist::npos;
|
||||
return shortlist->tryForwardMap(3343, i) == data::Shortlist::npos; // TODO beamIdx
|
||||
}),
|
||||
suppressed.end());
|
||||
|
||||
|
@ -62,8 +62,9 @@ public:
|
||||
trgVocab_->load(vocabs.back());
|
||||
auto srcVocab = corpus_->getVocabs()[0];
|
||||
|
||||
if(options_->hasAndNotEmpty("shortlist"))
|
||||
shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
|
||||
std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn");
|
||||
if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist"))
|
||||
shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabs.front() == vocabs.back());
|
||||
|
||||
auto devices = Config::getDevices(options_);
|
||||
numDevices_ = devices.size();
|
||||
|
Loading…
Reference in New Issue
Block a user