This commit is contained in:
Hieu Hoang 2021-06-09 22:36:34 +00:00
parent 1e3db86a94
commit 6f0f534a4a

View File

@ -129,13 +129,14 @@ LSHShortlist::LSHShortlist(int k, int nbits)
//#define BLAS_FOUND 1
WordIndex LSHShortlist::reverseMap(int batchIdx, int beamIdx, int idx) const {
//std::cerr << "\nbatchIdx=" << batchIdx << " beamIdx=" << beamIdx << " idx=" << idx << std::endl;
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
std::cerr << "\nbatchIdx=" << batchIdx << " beamIdx=" << beamIdx << " idx=" << idx << std::endl;
std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
int currBatchSize = indicesExpr_->shape()[0];
int currBeamSize = indicesExpr_->shape()[1];
//std::cerr << "currBeamSize=" << currBeamSize << std::endl;
//std::cerr << "indices_=" << indices_.size() << std::endl;
std::cerr << "currBatchSize=" << currBatchSize << " currBeamSize=" << currBeamSize << std::endl;
std::cerr << "indices_=" << indices_.size() << std::endl;
idx = (k_ * currBeamSize) * batchIdx + k_ * beamIdx + idx;
//std::cerr << "idx=" << idx << std::endl;
std::cerr << "idx=" << idx << std::endl;
assert(idx < indices_.size());
return indices_[idx];
}
@ -236,14 +237,14 @@ void LSHShortlist::broadcast(Expr weights,
indicesExprBC = reshape(indicesExprBC, {indicesExprBC->shape().elements()});
//std::cerr << "indicesExprBC.2=" << indicesExprBC->shape() << std::endl;
std::cerr << "currBeamSize=" << currBeamSize << " batchSize=" << batchSize << std::endl;
std::cerr << "weights=" << weights->shape() << std::endl;
//std::cerr << "currBeamSize=" << currBeamSize << " batchSize=" << batchSize << std::endl;
//std::cerr << "weights=" << weights->shape() << std::endl;
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indicesExprBC);
std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
//std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
cachedShortWt_ = reshape(cachedShortWt_, {batchSize, currBeamSize, k, cachedShortWt_->shape()[1]});
std::cerr << "cachedShortWt_.2=" << cachedShortWt_->shape() << std::endl;
//std::cerr << "cachedShortWt_.2=" << cachedShortWt_->shape() << std::endl;
cachedShortWt_ = transpose(cachedShortWt_, {1, 0, 2, 3});
std::cerr << "cachedShortWt_.3=" << cachedShortWt_->shape() << std::endl;
//std::cerr << "cachedShortWt_.3=" << cachedShortWt_->shape() << std::endl;
if (b) {
ABORT("Bias not yet tested");