pass shortlist regression tests

This commit is contained in:
Hieu Hoang 2021-06-28 21:26:02 -07:00
parent cd292d3b32
commit 24c644bae0

View File

@ -68,13 +68,13 @@ void Shortlist::createCachedTensors(Expr weights,
cachedShortWt_ = reshape(cachedShortWt_, {1, 1, cachedShortWt_->shape()[0], cachedShortWt_->shape()[1]});
if (b) {
ABORT("Bias not yet tested");
cachedShortb_ = index_select(b, -1, indicesExpr_);
cachedShortb_ = reshape(cachedShortb_, {1, k, 1, cachedShortb_->shape()[1]}); // not tested
}
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExpr_);
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {1, 1, cachedShortLemmaEt_->shape()[0], k});
if (lemmaEt) {
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExpr_);
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {1, 1, cachedShortLemmaEt_->shape()[0], k});
}
}
///////////////////////////////////////////////////////////////////////////////////