filter once for shortlist

This commit is contained in:
Hieu Hoang 2021-06-04 13:39:03 -07:00
parent 7faebf77ca
commit 28e5e2260a
2 changed files with 9 additions and 2 deletions

View File

@ -19,7 +19,8 @@ const T* get(const void*& current, size_t num = 1) {
//////////////////////////////////////////////////////////////////////////////////////
Shortlist::Shortlist(const std::vector<WordIndex>& indices)
: indices_(indices) {}
: indices_(indices)
, done_(false) {}
Shortlist::~Shortlist() {}
@ -34,6 +35,10 @@ WordIndex Shortlist::tryForwardMap(int , int , WordIndex wIdx) const {
}
void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
if (done_) {
return;
}
//if (indicesExpr_) return;
int currBeamSize = input->shape()[0];
int batchSize = input->shape()[2];
@ -50,6 +55,7 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
Expr indicesExprBC = getIndicesExpr(batchSize, currBeamSize);
broadcast(weights, isLegacyUntransposedW, b, lemmaEt, indicesExprBC, k);
done_ = true;
}
Expr Shortlist::getIndicesExpr(int batchSize, int beamSize) const {

View File

@ -29,7 +29,8 @@ protected:
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
Expr cachedShortb_; // these match the current value of shortlist_
Expr cachedShortLemmaEt_;
bool done_;
virtual void broadcast(Expr weights,
bool isLegacyUntransposedW,
Expr b,