lsh runs but crap output

Hieu Hoang 2021-04-29 12:54:35 -07:00
parent 1784da0585
commit 947301a817


@@ -90,6 +90,7 @@ void Shortlist::broadcast(Expr weights,
Expr lemmaEt,
Expr indicesExprBC,
int k) {
std::cerr << "indicesExprBC.0=" << indicesExprBC->shape() << std::endl;
int batchSize = indicesExprBC->shape()[0];
int currBeamSize = indicesExprBC->shape()[1];
//int numHypos = batchSize * currBeamSize;
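Note: per the hunk header this first hunk only adds the std::cerr trace of indicesExprBC's shape; batchSize and currBeamSize are then read from its first two dimensions, which is consistent with the batch-major {batchSize, currBeamSize, k} layout this commit introduces for indicesExpr_ below.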
@@ -155,17 +156,23 @@ WordIndex LSHShortlist::tryForwardMap(size_t beamIdx, WordIndex wIdx) const {
}
Expr LSHShortlist::getIndicesExpr(int batchSize, int currBeamSize) const {
assert(indicesExpr_->shape()[0] == currBeamSize);
std::cerr << "batchSize=" << batchSize << " currBeamSize=" << currBeamSize << std::endl;
std::cerr << "indicesExpr_=" << indicesExpr_->shape() << " " << indicesExpr_->val() << std::endl;
assert(indicesExpr_->shape()[0] == batchSize);
assert(indicesExpr_->shape()[1] == currBeamSize);
return indicesExpr_;
}
void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
#if BLAS_FOUND
int currBeamSize = input->shape()[0];
ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
"LSH index (--output-approx-knn) currently not implemented for GPU");
auto forward = [this, currBeamSize](Expr out, const std::vector<Expr>& inputs) {
int currBeamSize = input->shape()[0];
int batchSize = input->shape()[2];
int numHypos = currBeamSize * batchSize;
auto forward = [this, numHypos](Expr out, const std::vector<Expr>& inputs) {
auto query = inputs[0];
auto values = inputs[1];
int dim = values->shape()[-1];
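Note: filter() now derives currBeamSize from dimension 0 of its input and batchSize from dimension 2, and the forward lambda captures their product numHypos rather than currBeamSize alone, so the code further down can treat every (batch entry, beam entry) pair as its own hypothesis.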
@@ -193,16 +200,15 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
indices_.push_back(id);
}
for (size_t beamIdx = 0; beamIdx < currBeamSize; ++beamIdx) {
size_t startIdx = k_ * beamIdx;
for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
size_t startIdx = k_ * hypoIdx;
size_t endIdx = startIdx + k_;
std::sort(indices_.begin() + startIdx, indices_.begin() + endIdx);
}
out->val()->set(indices_);
//std::cerr << "out=" << out->shape() << " " << out->val() << std::endl;
};
Shape kShape({currBeamSize, k_});
Shape kShape({batchSize, currBeamSize, k_});
//std::cerr << "kShape=" << kShape << std::endl;
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
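For orientation, here is a minimal standalone C++ sketch (not Marian code; the sizes are made-up example values) of the layout implied by the new three-dimensional kShape: the flat index buffer holds numHypos = batchSize * currBeamSize groups of k entries, and each group is sorted independently, mirroring the loop above.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Example sizes only; in the commit these come from the input expression.
  int batchSize = 2, currBeamSize = 3, k = 4;
  int numHypos = batchSize * currBeamSize;

  // Pretend these are the nearest-neighbour word indices returned by the
  // LSH search, k per hypothesis, in arbitrary order.
  std::vector<uint32_t> indices = {
      7, 3, 9, 1,    5, 2, 8, 0,    6, 4, 11, 10,
      15, 13, 12, 14,    19, 17, 18, 16,    23, 21, 20, 22};

  // Sort each hypothesis' k entries in place, one slice per
  // (batch entry, beam entry) pair, as the new loop does.
  for (int hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
    auto begin = indices.begin() + k * hypoIdx;
    std::sort(begin, begin + k);
  }

  // Print as a {batchSize, currBeamSize, k} tensor to show the layout.
  for (int b = 0; b < batchSize; ++b)
    for (int s = 0; s < currBeamSize; ++s) {
      std::cout << "batch " << b << " beam " << s << ":";
      for (int i = 0; i < k; ++i)
        std::cout << " " << indices[(b * currBeamSize + s) * k + i];
      std::cout << "\n";
    }
  return 0;
}

Compiled with a plain g++ call, the sketch prints each hypothesis' k indices in ascending order, which is the invariant the per-hypothesis sort above is meant to establish.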