mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
pass shortlist regression tests
This commit is contained in:
parent
cd292d3b32
commit
24c644bae0
@ -68,13 +68,13 @@ void Shortlist::createCachedTensors(Expr weights,
|
||||
cachedShortWt_ = reshape(cachedShortWt_, {1, 1, cachedShortWt_->shape()[0], cachedShortWt_->shape()[1]});
|
||||
|
||||
if (b) {
|
||||
ABORT("Bias not yet tested");
|
||||
cachedShortb_ = index_select(b, -1, indicesExpr_);
|
||||
cachedShortb_ = reshape(cachedShortb_, {1, k, 1, cachedShortb_->shape()[1]}); // not tested
|
||||
}
|
||||
|
||||
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExpr_);
|
||||
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {1, 1, cachedShortLemmaEt_->shape()[0], k});
|
||||
if (lemmaEt) {
|
||||
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExpr_);
|
||||
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {1, 1, cachedShortLemmaEt_->shape()[0], k});
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
|
Loading…
Reference in New Issue
Block a user