mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
remaining comments
This commit is contained in:
parent
8bfa6a44e3
commit
9772aa293f
@ -275,7 +275,7 @@ void FactoredVocab::constructGroupInfoFromFactorVocab() {
|
||||
groupCounts[g]++;
|
||||
}
|
||||
|
||||
// required by LSH shortlist
|
||||
// required by LSH shortlist. Factored segmenter encodes the number of lemmas in the first factor group; this corresponds to actual surface forms
|
||||
lemmaSize_ = groupCounts[0];
|
||||
|
||||
for (size_t g = 0; g < numGroups; g++) { // detect non-overlapping groups
|
||||
|
@ -19,8 +19,8 @@ const T* get(const void*& current, size_t num = 1) {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////
|
||||
Shortlist::Shortlist(const std::vector<WordIndex>& indices)
|
||||
: indices_(indices)
|
||||
, done_(false) {}
|
||||
: indices_(indices),
|
||||
initialized_(false) {}
|
||||
|
||||
Shortlist::~Shortlist() {}
|
||||
|
||||
@ -35,7 +35,7 @@ WordIndex Shortlist::tryForwardMap(WordIndex wIdx) const {
|
||||
}
|
||||
|
||||
void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
|
||||
if (done_) {
|
||||
if (initialized_) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -49,7 +49,7 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
|
||||
|
||||
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
|
||||
createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k);
|
||||
done_ = true;
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
Expr Shortlist::getIndicesExpr() const {
|
||||
|
@ -29,13 +29,13 @@ protected:
|
||||
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
|
||||
Expr cachedShortb_; // these match the current value of shortlist_
|
||||
Expr cachedShortLemmaEt_;
|
||||
bool done_; // used by batch-level shortlist. Only initialize with 1st call then skip all subsequent calls for same batch
|
||||
bool initialized_; // used by batch-level shortlist. Only initialize with 1st call then skip all subsequent calls for same batch
|
||||
|
||||
void createCachedTensors(Expr weights,
|
||||
bool isLegacyUntransposedW,
|
||||
Expr b,
|
||||
Expr lemmaEt,
|
||||
int k);
|
||||
bool isLegacyUntransposedW,
|
||||
Expr b,
|
||||
Expr lemmaEt,
|
||||
int k);
|
||||
public:
|
||||
static constexpr WordIndex npos{std::numeric_limits<WordIndex>::max()}; // used to identify invalid shortlist entries similar to std::string::npos
|
||||
|
||||
@ -77,10 +77,10 @@ private:
|
||||
static std::mutex mutex_;
|
||||
|
||||
void createCachedTensors(Expr weights,
|
||||
bool isLegacyUntransposedW,
|
||||
Expr b,
|
||||
Expr lemmaEt,
|
||||
int k);
|
||||
bool isLegacyUntransposedW,
|
||||
Expr b,
|
||||
Expr lemmaEt,
|
||||
int k);
|
||||
|
||||
public:
|
||||
LSHShortlist(int k, int nbits, size_t lemmaSize);
|
||||
|
@ -478,6 +478,10 @@ Expr bdot(Expr a,
|
||||
bool transB = false,
|
||||
float scalar = 1.f);
|
||||
|
||||
/**
|
||||
* bdot_legacy is an old implementation of bdot without correct broadcasting on the batch dimensions,
|
||||
* to be removed once the behavior can be correctly replicated with normal bdot on 5 dimensions.
|
||||
*/
|
||||
Expr bdot_legacy(Expr a,
|
||||
Expr b,
|
||||
bool transA = false,
|
||||
|
Loading…
Reference in New Issue
Block a user