don't transpose lastIndices. Works for lsh & sl

This commit is contained in:
Hieu Hoang 2021-06-11 22:55:26 +00:00
parent 49998217d9
commit 700dc7fdd1

View File

@ -54,7 +54,7 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
Expr Shortlist::getIndicesExpr(int batchSize, int beamSize) const {
int k = indicesExpr_->shape()[0];
Expr ones = indicesExpr_->graph()->constant({batchSize, beamSize, 1}, inits::ones(), Type::float32);
Expr ones = indicesExpr_->graph()->constant({beamSize, batchSize, 1}, inits::ones(), Type::float32);
Expr tmp = reshape(indicesExpr_, {1, k});
tmp = cast(tmp, Type::float32);