mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
lsh runs but crap output
This commit is contained in:
parent
1784da0585
commit
947301a817
@ -90,6 +90,7 @@ void Shortlist::broadcast(Expr weights,
|
||||
Expr lemmaEt,
|
||||
Expr indicesExprBC,
|
||||
int k) {
|
||||
std::cerr << "indicesExprBC.0=" << indicesExprBC->shape() << std::endl;
|
||||
int batchSize = indicesExprBC->shape()[0];
|
||||
int currBeamSize = indicesExprBC->shape()[1];
|
||||
//int numHypos = batchSize * currBeamSize;
|
||||
@ -155,17 +156,23 @@ WordIndex LSHShortlist::tryForwardMap(size_t beamIdx, WordIndex wIdx) const {
|
||||
}
|
||||
|
||||
Expr LSHShortlist::getIndicesExpr(int batchSize, int currBeamSize) const {
|
||||
assert(indicesExpr_->shape()[0] == currBeamSize);
|
||||
std::cerr << "batchSize=" << batchSize << " currBeamSize=" << currBeamSize << std::endl;
|
||||
std::cerr << "indicesExpr_=" << indicesExpr_->shape() << " " << indicesExpr_->val() << std::endl;
|
||||
assert(indicesExpr_->shape()[0] == batchSize);
|
||||
assert(indicesExpr_->shape()[1] == currBeamSize);
|
||||
return indicesExpr_;
|
||||
}
|
||||
|
||||
void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
|
||||
#if BLAS_FOUND
|
||||
int currBeamSize = input->shape()[0];
|
||||
ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
|
||||
"LSH index (--output-approx-knn) currently not implemented for GPU");
|
||||
|
||||
auto forward = [this, currBeamSize](Expr out, const std::vector<Expr>& inputs) {
|
||||
int currBeamSize = input->shape()[0];
|
||||
int batchSize = input->shape()[2];
|
||||
int numHypos = currBeamSize * batchSize;
|
||||
|
||||
auto forward = [this, numHypos](Expr out, const std::vector<Expr>& inputs) {
|
||||
auto query = inputs[0];
|
||||
auto values = inputs[1];
|
||||
int dim = values->shape()[-1];
|
||||
@ -193,16 +200,15 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
|
||||
indices_.push_back(id);
|
||||
}
|
||||
|
||||
for (size_t beamIdx = 0; beamIdx < currBeamSize; ++beamIdx) {
|
||||
size_t startIdx = k_ * beamIdx;
|
||||
for (size_t hypoIdx = 0; hypoIdx < numHypos; ++hypoIdx) {
|
||||
size_t startIdx = k_ * hypoIdx;
|
||||
size_t endIdx = startIdx + k_;
|
||||
std::sort(indices_.begin() + startIdx, indices_.begin() + endIdx);
|
||||
}
|
||||
out->val()->set(indices_);
|
||||
//std::cerr << "out=" << out->shape() << " " << out->val() << std::endl;
|
||||
};
|
||||
|
||||
Shape kShape({currBeamSize, k_});
|
||||
Shape kShape({batchSize, currBeamSize, k_});
|
||||
//std::cerr << "kShape=" << kShape << std::endl;
|
||||
|
||||
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
|
||||
|
Loading…
Reference in New Issue
Block a user