Don't manually broadcast the lemma embeddings

This commit is contained in:
Hieu Hoang 2021-06-09 20:57:02 +00:00
parent 79dbde7efc
commit 4b9082bc39

View File

@ -114,7 +114,7 @@ void Shortlist::broadcast(Expr weights,
std::cerr << "currBeamSize=" << currBeamSize << " batchSize=" << batchSize << std::endl;
std::cerr << "weights=" << weights->shape() << std::endl;
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indicesExpr_);
std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
//std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
cachedShortWt_ = reshape(cachedShortWt_, {1, 1, cachedShortWt_->shape()[0], cachedShortWt_->shape()[1]});
if (b) {
@ -123,12 +123,11 @@ void Shortlist::broadcast(Expr weights,
cachedShortb_ = reshape(cachedShortb_, {1, k, 1, cachedShortb_->shape()[1]}); // not tested
}
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExprBC);
//std::cerr << "cachedShortLemmaEt.1_=" << cachedShortLemmaEt_->shape() << std::endl;
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {cachedShortLemmaEt_->shape()[0], batchSize, currBeamSize, k});
//std::cerr << "cachedShortLemmaEt.2_=" << cachedShortLemmaEt_->shape() << std::endl;
cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {2, 1, 0, 3});
//std::cerr << "cachedShortLemmaEt.3_=" << cachedShortLemmaEt_->shape() << std::endl;
std::cerr << "lemmaEt.1_=" << lemmaEt->shape() << std::endl;
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExpr_);
std::cerr << "cachedShortLemmaEt.1_=" << cachedShortLemmaEt_->shape() << std::endl;
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {1, 1, cachedShortLemmaEt_->shape()[0], k});
std::cerr << "cachedShortLemmaEt.2_=" << cachedShortLemmaEt_->shape() << std::endl;
}
///////////////////////////////////////////////////////////////////////////////////