changes for review

This commit is contained in:
Hieu Hoang 2021-06-18 10:18:31 -07:00
parent a332e550a5
commit cd292d3b32
9 changed files with 19 additions and 35 deletions

View File

@@ -63,7 +63,7 @@ std::string findReplace(const std::string& in, const std::string& what, const st
double parseDouble(std::string s);
double parseNumber(std::string s);
// prints vector values with a custom label.
template<class T>
void Debug(const T *arr, size_t size, const std::string &str) {
std::cerr << str << ":" << size << ": ";

View File

@@ -274,7 +274,10 @@ void FactoredVocab::constructGroupInfoFromFactorVocab() {
groupRanges_[g].second = u + 1;
groupCounts[g]++;
}
// required by LSH shortlist
lemmaSize_ = groupCounts[0];
for (size_t g = 0; g < numGroups; g++) { // detect non-overlapping groups
LOG(info, "[vocab] Factor group '{}' has {} members", groupPrefixes_[g], groupCounts[g]);
if (groupCounts[g] == 0) { // factor group is unused --@TODO: once this is not hard-coded, this is an error condition

View File

@@ -24,9 +24,9 @@ Shortlist::Shortlist(const std::vector<WordIndex>& indices)
Shortlist::~Shortlist() {}
WordIndex Shortlist::reverseMap(int , int , int idx) const { return indices_[idx]; }
WordIndex Shortlist::reverseMap(int /*beamIdx*/, int /*batchIdx*/, int idx) const { return indices_[idx]; }
WordIndex Shortlist::tryForwardMap(int , int , WordIndex wIdx) const {
WordIndex Shortlist::tryForwardMap(WordIndex wIdx) const {
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
return (int)std::distance(indices_.begin(), first); // return coordinate if found
@@ -83,15 +83,8 @@ Ptr<faiss::IndexLSH> LSHShortlist::index_;
LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize)
: Shortlist(std::vector<WordIndex>())
, k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
/*
for (int i = 0; i < k_; ++i) {
indices_.push_back(i);
}
*/
}
//#define BLAS_FOUND 1
WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
//int currBeamSize = indicesExpr_->shape()[0];
int currBatchSize = indicesExpr_->shape()[1];
@@ -100,15 +93,6 @@ WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
return indices_[idx];
}
WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const {
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
bool found = first != indices_.end();
if(found && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
return (int)std::distance(indices_.begin(), first); // return coordinate if found
else
return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17?
}
Expr LSHShortlist::getIndicesExpr() const {
return indicesExpr_;
}
@@ -128,7 +112,6 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
int dim = values->shape()[-1];
if(!index_) {
//std::cerr << "build lsh index" << std::endl;
LOG(info, "Building LSH index for vector dim {} and with hash size {} bits", dim, nbits_);
index_.reset(new faiss::IndexLSH(dim, nbits_,
/*rotate=*/dim != nbits_,
@@ -199,7 +182,6 @@ void LSHShortlist::createCachedTensors(Expr weights,
LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits, size_t lemmaSize)
: k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
//std::cerr << "LSHShortlistGenerator" << std::endl;
}
Ptr<Shortlist> LSHShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {

View File

@@ -29,7 +29,7 @@ protected:
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;
bool done_;
bool done_; // used by batch-level shortlist. Only initialize with 1st call then skip all subsequent calls for same batch
void createCachedTensors(Expr weights,
bool isLegacyUntransposedW,
@@ -43,7 +43,7 @@ public:
virtual ~Shortlist();
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const;
virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const;
virtual WordIndex tryForwardMap(WordIndex wIdx) const;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
virtual Expr getIndicesExpr() const;
@@ -66,12 +66,14 @@ public:
};
///////////////////////////////////////////////////////////////////////////////////
// implements SLIDE for faster inference.
// https://arxiv.org/pdf/1903.03129.pdf
class LSHShortlist: public Shortlist {
private:
int k_;
int nbits_;
size_t lemmaSize_;
static Ptr<faiss::IndexLSH> index_;
int k_; // number of candidates returned from each input
int nbits_; // length of hash
size_t lemmaSize_; // vocab size
static Ptr<faiss::IndexLSH> index_; // LSH index to store all possible candidates
void createCachedTensors(Expr weights,
bool isLegacyUntransposedW,
@@ -82,7 +84,6 @@ private:
public:
LSHShortlist(int k, int nbits, size_t lemmaSize);
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override;
virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const override;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
virtual Expr getIndicesExpr() const override;

View File

@@ -133,7 +133,7 @@ size_t Vocab::lemmaSize() const {
return vImpl_->lemmaSize();
}
// number of vocabulary items
// type of vocabulary items
std::string Vocab::type() const { return vImpl_->type(); }
// return EOS symbol id

View File

@@ -61,7 +61,7 @@ public:
// number of vocabulary items
size_t size() const;
// number of vocabulary items
// number of lemma items. Same as size() except in factored models
size_t lemmaSize() const;
// number of vocabulary items

View File

@@ -247,8 +247,6 @@ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<
std::vector<float> Logits::getFactorMasksMultiDim(size_t factorGroup, Expr indicesExpr)
const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
//std::cerr << "indicesExpr=" << indicesExpr->shape() << std::endl;
//int batchSize
int batchSize = indicesExpr->shape()[0];
int currBeamSize = indicesExpr->shape()[1];
int numHypos = batchSize * currBeamSize;

View File

@@ -56,12 +56,12 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
lazyConstruct(input->shape()[-1]);
auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
/*
std::cerr << "affineOrDot.x=" << x->shape() << std::endl;
std::cerr << "affineOrDot.W=" << W->shape() << std::endl;
std::cerr << "affineOrDot.b=" << b->shape() << std::endl;
std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl;
*/
if(b)
return affine(x, W, b, transA, transB);
else

View File

@@ -315,7 +315,7 @@ Histories BeamSearch::search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch>
suppressed.erase(std::remove_if(suppressed.begin(),
suppressed.end(),
[&](WordIndex i) {
return shortlist->tryForwardMap(4545, 3343, i) == data::Shortlist::npos; // TODO beamIdx
return shortlist->tryForwardMap(i) == data::Shortlist::npos; // TODO beamIdx
}),
suppressed.end());