mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
Merge branch 'hihoan/lsh7' of vs-ssh.visualstudio.com:v3/machinetranslation/Marian/marian-dev into hihoan/lsh7
This commit is contained in:
commit
6981b21f4e
@ -52,7 +52,7 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
|
||||
done_ = true;
|
||||
}
|
||||
|
||||
Expr Shortlist::getIndicesExpr(int batchSize, int beamSize) const {
|
||||
Expr Shortlist::getIndicesExpr() const {
|
||||
int k = indicesExpr_->shape()[0];
|
||||
Expr out = reshape(indicesExpr_, {1, 1, k});
|
||||
return out;
|
||||
@ -63,13 +63,8 @@ void Shortlist::createCachedTensors(Expr weights,
|
||||
Expr b,
|
||||
Expr lemmaEt,
|
||||
int k) {
|
||||
//std::cerr << "isLegacyUntransposedW=" << isLegacyUntransposedW << std::endl;
|
||||
ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");
|
||||
|
||||
//std::cerr << "currBeamSize=" << currBeamSize << " batchSize=" << batchSize << std::endl;
|
||||
//std::cerr << "weights=" << weights->shape() << std::endl;
|
||||
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indicesExpr_);
|
||||
//std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
|
||||
cachedShortWt_ = reshape(cachedShortWt_, {1, 1, cachedShortWt_->shape()[0], cachedShortWt_->shape()[1]});
|
||||
|
||||
if (b) {
|
||||
@ -78,11 +73,8 @@ void Shortlist::createCachedTensors(Expr weights,
|
||||
cachedShortb_ = reshape(cachedShortb_, {1, k, 1, cachedShortb_->shape()[1]}); // not tested
|
||||
}
|
||||
|
||||
//std::cerr << "lemmaEt.1_=" << lemmaEt->shape() << std::endl;
|
||||
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExpr_);
|
||||
//std::cerr << "cachedShortLemmaEt.1_=" << cachedShortLemmaEt_->shape() << std::endl;
|
||||
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {1, 1, cachedShortLemmaEt_->shape()[0], k});
|
||||
//std::cerr << "cachedShortLemmaEt.2_=" << cachedShortLemmaEt_->shape() << std::endl;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
@ -110,7 +102,6 @@ WordIndex LSHShortlist::reverseMap(int beamIdx, int batchIdx, int idx) const {
|
||||
}
|
||||
|
||||
WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const {
|
||||
//utils::Debug(indices_, "LSHShortlist::tryForwardMap indices_");
|
||||
auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
|
||||
bool found = first != indices_.end();
|
||||
if(found && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
|
||||
@ -119,16 +110,10 @@ WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const {
|
||||
return npos; // return npos if not found, @TODO: replace with std::optional once we switch to C++17?
|
||||
}
|
||||
|
||||
Expr LSHShortlist::getIndicesExpr(int batchSize, int currBeamSize) const {
|
||||
assert(indicesExpr_->shape()[0] == currBeamSize);
|
||||
assert(indicesExpr_->shape()[1] == batchSize);
|
||||
Expr LSHShortlist::getIndicesExpr() const {
|
||||
return indicesExpr_;
|
||||
//Expr ret = transpose(indicesExpr_, {1, 0, 2});
|
||||
//return ret;
|
||||
}
|
||||
|
||||
#define BLAS_FOUND 1
|
||||
|
||||
void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
|
||||
#if BLAS_FOUND
|
||||
ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
|
||||
@ -175,11 +160,7 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
|
||||
};
|
||||
|
||||
Shape kShape({currBeamSize, batchSize, k_});
|
||||
|
||||
//std::cerr << "input=" << input->shape() << std::endl;
|
||||
//std::cerr << "weights=" << weights->shape() << std::endl;
|
||||
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
|
||||
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
|
||||
|
||||
createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k_);
|
||||
|
||||
@ -196,10 +177,6 @@ void LSHShortlist::createCachedTensors(Expr weights,
|
||||
int k) {
|
||||
int currBeamSize = indicesExpr_->shape()[0];
|
||||
int batchSize = indicesExpr_->shape()[1];
|
||||
//int numHypos = batchSize * currBeamSize;
|
||||
//std::cerr << "batchSize=" << batchSize << std::endl;
|
||||
//std::cerr << "currBeamSize=" << currBeamSize << std::endl;
|
||||
//std::cerr << "isLegacyUntransposedW=" << isLegacyUntransposedW << std::endl;
|
||||
ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");
|
||||
|
||||
Expr indicesExprFlatten = reshape(indicesExpr_, {indicesExpr_->shape().elements()});
|
||||
|
@ -46,7 +46,7 @@ public:
|
||||
virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const;
|
||||
|
||||
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt);
|
||||
virtual Expr getIndicesExpr(int batchSize, int currBeamSize) const;
|
||||
virtual Expr getIndicesExpr() const;
|
||||
virtual Expr getCachedShortWt() const { return cachedShortWt_; }
|
||||
virtual Expr getCachedShortb() const { return cachedShortb_; }
|
||||
virtual Expr getCachedShortLemmaEt() const { return cachedShortLemmaEt_; }
|
||||
@ -85,7 +85,7 @@ public:
|
||||
virtual WordIndex tryForwardMap(int batchIdx, int beamIdx, WordIndex wIdx) const override;
|
||||
|
||||
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
|
||||
virtual Expr getIndicesExpr(int batchSize,int currBeamSize) const override;
|
||||
virtual Expr getIndicesExpr() const override;
|
||||
|
||||
};
|
||||
|
||||
|
@ -117,28 +117,22 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
|
||||
out->val()->set(masks);
|
||||
};
|
||||
|
||||
int currBeamSize = sel->shape()[0];
|
||||
int batchSize = sel->shape()[2];
|
||||
Expr lastIndices = shortlist->getIndicesExpr(batchSize, currBeamSize);
|
||||
//std::cerr << "lastIndices=" << lastIndices->shape() << std::endl;
|
||||
//int currBeamSize = sel->shape()[0];
|
||||
//int batchSize = sel->shape()[2];
|
||||
Expr lastIndices = shortlist->getIndicesExpr();
|
||||
//assert(lastIndices->shape()[0] == currBeamSize || lastIndices->shape()[0] == 1);
|
||||
//assert(lastIndices->shape()[1] == batchSize || lastIndices->shape()[1] == 1);
|
||||
|
||||
factorMasks = lambda({lastIndices}, lastIndices->shape(), Type::float32, forward);
|
||||
//std::cerr << "factorMasks.1=" << factorMasks->shape() << std::endl;
|
||||
|
||||
const Shape &s = factorMasks->shape();
|
||||
factorMasks = reshape(factorMasks, {s[0], 1, s[1], s[2]});
|
||||
//std::cerr << "factorMasks.3=" << factorMasks->shape() << std::endl;
|
||||
}
|
||||
factorMaxima = cast(factorMaxima, sel->value_type());
|
||||
factorMasks = cast(factorMasks, sel->value_type());
|
||||
//std::cerr << "factorMaxima=" << factorMaxima->shape() << std::endl;
|
||||
//std::cerr << "factorMasks.4=" << factorMasks->shape() << std::endl;
|
||||
//std::cerr << "sel.1=" << sel->shape() << std::endl;
|
||||
|
||||
Expr tmp = factorMaxima * factorMasks;
|
||||
//std::cerr << "tmp=" << tmp->shape() << std::endl;
|
||||
sel = sel + tmp; // those lemmas that don't have a factor
|
||||
//std::cerr << "sel.2=" << sel->shape() << std::endl;
|
||||
//std::cerr << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user