get lemma size from vocab class

Hieu Hoang 2021-06-15 16:54:17 -07:00
parent 8c04f66474
commit 488a532bdf
8 changed files with 30 additions and 14 deletions
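In short: LSHShortlist::filter() used to train and fill its FAISS index with a hard-coded row count (32121), which was only correct for one particular vocabulary. This commit adds lemmaSize() to the vocabulary interface (defaulting to size(), overridden by FactoredVocab) and threads it through LSHShortlistGenerator into LSHShortlist. A minimal standalone sketch of the new wiring follows; the class and member names mirror the diff, while the numbers and simplified bodies are illustrative only:

// Illustrative sketch of the call chain after this commit (not marian's real code).
#include <cstddef>
#include <iostream>
#include <memory>

struct IVocab {
  virtual ~IVocab() = default;
  virtual size_t size() const = 0;
  // Base-class default: for non-factored vocabs the lemma count is the vocab size.
  virtual size_t lemmaSize() const { return size(); }
};

struct FactoredVocab : IVocab {
  size_t lemmaSize_ = 32121;                        // set from groupCounts[0] in the real code
  size_t size() const override { return 100000; }   // dummy total vocab size
  size_t lemmaSize() const override { return lemmaSize_; }
};

struct LSHShortlist {
  int k_, nbits_;
  size_t lemmaSize_;                                // rows to index; no longer hard-coded
  LSHShortlist(int k, int nbits, size_t lemmaSize)
    : k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {}
  void filter() const {
    // stands in for index_->train(lemmaSize_, ...) / index_->add(lemmaSize_, ...)
    std::cout << "LSH index over " << lemmaSize_ << " lemma rows\n";
  }
};

int main() {
  std::shared_ptr<IVocab> trgVocab = std::make_shared<FactoredVocab>();
  LSHShortlist shortlist(/*k=*/100, /*nbits=*/1024, trgVocab->lemmaSize()); // was: int vRows = 32121;
  shortlist.filter();
}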

View File

@@ -244,6 +244,10 @@ void FactoredVocab::rCompleteVocab(std::vector<size_t>& factorIndices, size_t g)
   }
 }
 
+size_t FactoredVocab::lemmaSize() const {
+  return lemmaSize_;
+}
+
 void FactoredVocab::constructGroupInfoFromFactorVocab() {
   // form groups
   size_t numGroups = groupPrefixes_.size();
@@ -270,6 +274,7 @@ void FactoredVocab::constructGroupInfoFromFactorVocab() {
     groupRanges_[g].second = u + 1;
     groupCounts[g]++;
   }
+  lemmaSize_ = groupCounts[0];
   for (size_t g = 0; g < numGroups; g++) { // detect non-overlapping groups
     LOG(info, "[vocab] Factor group '{}' has {} members", groupPrefixes_[g], groupCounts[g]);
     if (groupCounts[g] == 0) { // factor group is unused --@TODO: once this is not hard-coded, this is an error condition

View File

@@ -46,6 +46,7 @@ public:
   // factor-specific. These methods are consumed by Output and Embedding.
   size_t factorVocabSize() const { return factorVocab_.size(); } // total number of factors across all types
   size_t virtualVocabSize() const { return factorShape_.elements<size_t>(); } // valid WordIndex range (representing all factor combinations including gaps); virtual and huge
+  virtual size_t lemmaSize() const override;
   CSRData csr_rows(const Words& words) const; // sparse matrix for summing up factors from the concatenated embedding matrix for each word
@@ -116,6 +117,7 @@ private:
   Word eosId_{};
   Word unkId_{};
   WordLUT vocab_;
+  size_t lemmaSize_;
   // factors
   char factorSeparator_ = '|'; // separator symbol for parsing factored words

View File

@@ -88,9 +88,9 @@ void Shortlist::createCachedTensors(Expr weights,
 ///////////////////////////////////////////////////////////////////////////////////
 Ptr<faiss::IndexLSH> LSHShortlist::index_;
 
-LSHShortlist::LSHShortlist(int k, int nbits)
+LSHShortlist::LSHShortlist(int k, int nbits, size_t lemmaSize)
 : Shortlist(std::vector<WordIndex>())
-  , k_(k), nbits_(nbits) {
+  , k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
   //std::cerr << "LSHShortlist" << std::endl;
   /*
   for (int i = 0; i < k_; ++i) {
@@ -149,9 +149,8 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
     index_.reset(new faiss::IndexLSH(dim, nbits_,
                                      /*rotate=*/dim != nbits_,
                                      /*train_thesholds*/false));
-    int vRows = 32121; //47960; //values->shape().elements() / dim;
-    index_->train(vRows, values->val()->data<float>());
-    index_->add( vRows, values->val()->data<float>());
+    index_->train(lemmaSize_, values->val()->data<float>());
+    index_->add( lemmaSize_, values->val()->data<float>());
   }
 
   int qRows = query->shape().elements() / dim;
@@ -220,13 +219,13 @@ void LSHShortlist::createCachedTensors(Expr weights,
   cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {1, 2, 0, 3});
 }
 
-LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits)
-  : k_(k), nbits_(nbits) {
+LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits, size_t lemmaSize)
+  : k_(k), nbits_(nbits), lemmaSize_(lemmaSize) {
   //std::cerr << "LSHShortlistGenerator" << std::endl;
 }
 
 Ptr<Shortlist> LSHShortlistGenerator::generate(Ptr<data::CorpusBatch> batch) const {
-  return New<LSHShortlist>(k_, nbits_);
+  return New<LSHShortlist>(k_, nbits_, lemmaSize_);
 }
 
 //////////////////////////////////////////////////////////////////////////////////////
@@ -359,7 +358,8 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
                                                  size_t trgIdx,
                                                  bool shared) {
   if (lshOpts.size() == 2) {
-    return New<LSHShortlistGenerator>(lshOpts[0], lshOpts[1]);
+    size_t lemmaSize = trgVocab->lemmaSize();
+    return New<LSHShortlistGenerator>(lshOpts[0], lshOpts[1], lemmaSize);
   }
   else {
     std::vector<std::string> vals = options->get<std::vector<std::string>>("shortlist");
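The core fix is the filter() hunk above: the row count passed to FAISS must match the number of lemma embedding rows actually being indexed. A hedged, self-contained sketch of the underlying FAISS calls, with dummy data standing in for marian's lemma embedding matrix:

#include <faiss/IndexLSH.h>
#include <vector>

int main() {
  int dim = 256, nbits = 1024;           // embedding width and LSH code size (dummy values)
  size_t lemmaSize = 32121;              // row count now supplied by Vocab::lemmaSize()
  std::vector<float> values(lemmaSize * dim, 0.0f); // placeholder lemma embeddings, row-major
  faiss::IndexLSH index(dim, nbits,
                        /*rotate_data=*/dim != nbits,
                        /*train_thresholds=*/false);
  index.train(lemmaSize, values.data()); // rows must equal the indexed matrix height
  index.add(lemmaSize, values.data());
  return 0;
}

Too large a count would read past the end of the embedding buffer; too small a count silently drops lemmas from the shortlist, which is why deriving it from the vocabulary beats hard-coding.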

View File

@@ -70,7 +70,7 @@ class LSHShortlist: public Shortlist {
 private:
   int k_;
   int nbits_;
+  size_t lemmaSize_;
 
   static Ptr<faiss::IndexLSH> index_;
 
   void createCachedTensors(Expr weights,
@@ -80,7 +80,7 @@ private:
                            int k);
 
 public:
-  LSHShortlist(int k, int nbits);
+  LSHShortlist(int k, int nbits, size_t lemmaSize);
 
   virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override;
   virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const override;
@@ -93,9 +93,9 @@ class LSHShortlistGenerator : public ShortlistGenerator {
 private:
   int k_;
   int nbits_;
+  size_t lemmaSize_;
 
 public:
-  LSHShortlistGenerator(int k, int nbits);
+  LSHShortlistGenerator(int k, int nbits, size_t lemmaSize);
   Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override;
 };

View File

@@ -129,6 +129,10 @@ std::string Vocab::surfaceForm(const Words& sentence) const {
 // number of vocabulary items
 size_t Vocab::size() const { return vImpl_->size(); }
 
+size_t Vocab::lemmaSize() const {
+  return vImpl_->lemmaSize();
+}
+
 // number of vocabulary items
 std::string Vocab::type() const { return vImpl_->type(); }

View File

@@ -61,6 +61,9 @@ public:
   // number of vocabulary items
   size_t size() const;
 
+  // number of vocabulary items
+  size_t lemmaSize() const;
+
   // number of vocabulary items
   std::string type() const;

View File

@@ -39,6 +39,7 @@ public:
   virtual const std::string& operator[](Word id) const = 0;
   virtual size_t size() const = 0;
+  virtual size_t lemmaSize() const { return size(); }
   virtual std::string type() const = 0;
   virtual Word getEosId() const = 0;

View File

@@ -63,8 +63,9 @@ public:
     auto srcVocab = corpus_->getVocabs()[0];
 
     std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn");
-    if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist"))
+    if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) {
       shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabs.front() == vocabs.back());
+    }
 
     auto devices = Config::getDevices(options_);
     numDevices_ = devices.size();