remaining comments

This commit is contained in:
Marcin Junczys-Dowmunt 2021-07-03 12:13:26 -07:00
parent 8bfa6a44e3
commit 9772aa293f
4 changed files with 18 additions and 14 deletions

View File

@ -275,7 +275,7 @@ void FactoredVocab::constructGroupInfoFromFactorVocab() {
groupCounts[g]++;
}
// required by LSH shortlist
// required by LSH shortlist. Factored segmenter encodes the number of lemmas in the first factor group, this corresponds to actual surface forms
lemmaSize_ = groupCounts[0];
for (size_t g = 0; g < numGroups; g++) { // detect non-overlapping groups

View File

@ -19,8 +19,8 @@ const T* get(const void*& current, size_t num = 1) {
//////////////////////////////////////////////////////////////////////////////////////
Shortlist::Shortlist(const std::vector<WordIndex>& indices)
: indices_(indices)
, done_(false) {}
: indices_(indices),
initialized_(false) {}
Shortlist::~Shortlist() {}
@ -35,7 +35,7 @@ WordIndex Shortlist::tryForwardMap(WordIndex wIdx) const {
}
void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
if (done_) {
if (initialized_) {
return;
}
@ -49,7 +49,7 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k);
done_ = true;
initialized_ = true;
}
Expr Shortlist::getIndicesExpr() const {

View File

@ -29,13 +29,13 @@ protected:
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;
bool done_; // used by batch-level shortlist. Only initialize with 1st call then skip all subsequent calls for same batch
bool initialized_; // used by batch-level shortlist. Only initialize with 1st call then skip all subsequent calls for same batch
void createCachedTensors(Expr weights,
bool isLegacyUntransposedW,
Expr b,
Expr lemmaEt,
int k);
bool isLegacyUntransposedW,
Expr b,
Expr lemmaEt,
int k);
public:
static constexpr WordIndex npos{std::numeric_limits<WordIndex>::max()}; // used to identify invalid shortlist entries similar to std::string::npos
@ -77,10 +77,10 @@ private:
static std::mutex mutex_;
void createCachedTensors(Expr weights,
bool isLegacyUntransposedW,
Expr b,
Expr lemmaEt,
int k);
bool isLegacyUntransposedW,
Expr b,
Expr lemmaEt,
int k);
public:
LSHShortlist(int k, int nbits, size_t lemmaSize);

View File

@ -478,6 +478,10 @@ Expr bdot(Expr a,
bool transB = false,
float scalar = 1.f);
/**
* bdot_legacy is an old implemetation of bdot without correct broadcasting on the batch dimensions,
* to be removed once the behavior can be correctly replicated with normal bdot on 5 dimensions.
*/
Expr bdot_legacy(Expr a,
Expr b,
bool transA = false,