mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
filter once for shortlist
This commit is contained in:
parent
7faebf77ca
commit
28e5e2260a
@ -19,7 +19,8 @@ const T* get(const void*& current, size_t num = 1) {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////
|
||||
Shortlist::Shortlist(const std::vector<WordIndex>& indices)
|
||||
: indices_(indices) {}
|
||||
: indices_(indices)
|
||||
, done_(false) {}
|
||||
|
||||
Shortlist::~Shortlist() {}
|
||||
|
||||
@ -34,6 +35,10 @@ WordIndex Shortlist::tryForwardMap(int , int , WordIndex wIdx) const {
|
||||
}
|
||||
|
||||
void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
|
||||
if (done_) {
|
||||
return;
|
||||
}
|
||||
|
||||
//if (indicesExpr_) return;
|
||||
int currBeamSize = input->shape()[0];
|
||||
int batchSize = input->shape()[2];
|
||||
@ -50,6 +55,7 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
|
||||
|
||||
Expr indicesExprBC = getIndicesExpr(batchSize, currBeamSize);
|
||||
broadcast(weights, isLegacyUntransposedW, b, lemmaEt, indicesExprBC, k);
|
||||
done_ = true;
|
||||
}
|
||||
|
||||
Expr Shortlist::getIndicesExpr(int batchSize, int beamSize) const {
|
||||
|
@ -29,7 +29,8 @@ protected:
|
||||
Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
|
||||
Expr cachedShortb_; // these match the current value of shortlist_
|
||||
Expr cachedShortLemmaEt_;
|
||||
|
||||
bool done_;
|
||||
|
||||
virtual void broadcast(Expr weights,
|
||||
bool isLegacyUntransposedW,
|
||||
Expr b,
|
||||
|
Loading…
Reference in New Issue
Block a user