mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
cachedShortWt_ works
This commit is contained in:
parent
d41353eeb7
commit
e518fc9666
@ -32,8 +32,8 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
|
||||
//if (indicesExpr_) return;
|
||||
int currBeamSize = input->shape()[0];
|
||||
int batchSize = input->shape()[2];
|
||||
std::cerr << "currBeamSize=" << currBeamSize << std::endl;
|
||||
std::cerr << "batchSize=" << batchSize << std::endl;
|
||||
//std::cerr << "currBeamSize=" << currBeamSize << std::endl;
|
||||
//std::cerr << "batchSize=" << batchSize << std::endl;
|
||||
|
||||
auto forward = [this](Expr out, const std::vector<Expr>& inputs) {
|
||||
out->val()->set(indices_);
|
||||
@ -91,31 +91,12 @@ void Shortlist::broadcast(Expr weights,
|
||||
//int numHypos = batchSize * currBeamSize;
|
||||
//std::cerr << "batchSize=" << batchSize << std::endl;
|
||||
//std::cerr << "currBeamSize=" << currBeamSize << std::endl;
|
||||
std::cerr << "isLegacyUntransposedW=" << isLegacyUntransposedW << std::endl;
|
||||
ABORT_IF(!isLegacyUntransposedW, "Legacy untranspose W not yet tested");
|
||||
///*
|
||||
std::cerr << "weights=" << weights->shape() << std::endl;
|
||||
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indices());
|
||||
std::cerr << "cachedShortWt_=" << cachedShortWt_->shape() << std::endl;
|
||||
if (b) {
|
||||
ABORT("Bias not yet tested");
|
||||
cachedShortb_ = index_select(b, -1, indices());
|
||||
}
|
||||
//std::cerr << "isLegacyUntransposedW=" << isLegacyUntransposedW << std::endl;
|
||||
ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");
|
||||
|
||||
indicesExprBC = reshape(indicesExprBC, {indicesExprBC->shape().elements()});
|
||||
//std::cerr << "indicesExprBC.2=" << indicesExprBC->shape() << std::endl;
|
||||
|
||||
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExprBC);
|
||||
//std::cerr << "cachedShortLemmaEt.1_=" << cachedShortLemmaEt_->shape() << std::endl;
|
||||
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {cachedShortLemmaEt_->shape()[0], batchSize, currBeamSize, k});
|
||||
//std::cerr << "cachedShortLemmaEt.2_=" << cachedShortLemmaEt_->shape() << std::endl;
|
||||
cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {2, 0, 1, 3});
|
||||
//std::cerr << "cachedShortLemmaEt.3_=" << cachedShortLemmaEt_->shape() << std::endl;
|
||||
|
||||
return;
|
||||
//*/
|
||||
|
||||
|
||||
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indicesExprBC);
|
||||
//std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
|
||||
cachedShortWt_ = reshape(cachedShortWt_, {batchSize, currBeamSize, k, cachedShortWt_->shape()[1]});
|
||||
@ -124,10 +105,17 @@ void Shortlist::broadcast(Expr weights,
|
||||
//std::cerr << "cachedShortWt_.3=" << cachedShortWt_->shape() << std::endl;
|
||||
|
||||
if (b) {
|
||||
assert(false);
|
||||
ABORT("Bias not yet tested");
|
||||
cachedShortb_ = index_select(b, -1, indicesExprBC);
|
||||
cachedShortb_ = reshape(cachedShortb_, {currBeamSize, k, batchSize, cachedShortb_->shape()[1]}); // not tested
|
||||
}
|
||||
|
||||
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExprBC);
|
||||
//std::cerr << "cachedShortLemmaEt.1_=" << cachedShortLemmaEt_->shape() << std::endl;
|
||||
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {cachedShortLemmaEt_->shape()[0], batchSize, currBeamSize, k});
|
||||
//std::cerr << "cachedShortLemmaEt.2_=" << cachedShortLemmaEt_->shape() << std::endl;
|
||||
cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {2, 0, 1, 3});
|
||||
//std::cerr << "cachedShortLemmaEt.3_=" << cachedShortLemmaEt_->shape() << std::endl;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////////////////////
|
||||
QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options,
|
||||
|
@ -62,8 +62,19 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
return dot(x, W, transA, transB);
|
||||
};
|
||||
|
||||
auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
return affineOrDot(x, W, b, transA, transB);
|
||||
auto affineShortlist = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
//std::cerr << "x=" << x->shape() << std::endl;
|
||||
//std::cerr << "W=" << W->shape() << std::endl;
|
||||
//std::cerr << "transA=" << transA << " transB=" << transB << std::endl;
|
||||
|
||||
Expr ret = x * W;
|
||||
ret = sum(ret, 3);
|
||||
//const Shape &retShape = ret->shape();
|
||||
//std::cerr << "ret.1=" << retShape << std::endl;
|
||||
ret = transpose(ret, {0, 3, 2, 1});
|
||||
//ret = reshape(ret, {retShape[0], 1, 1, retShape[2]});
|
||||
//std::cerr << "ret.2=" << ret->shape() << std::endl;
|
||||
return ret;
|
||||
};
|
||||
|
||||
if(shortlist_) { // shortlisted versions of parameters are cached within one
|
||||
@ -164,20 +175,22 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
// @TODO: b_ should be a vector, not a matrix; but shortlists use cols(), which requires a
|
||||
// matrix
|
||||
Expr factorLogits;
|
||||
if(g == 0)
|
||||
factorLogits = affineOrLSH(
|
||||
if(g == 0 && shortlist_) {
|
||||
factorLogits = affineShortlist(
|
||||
input1,
|
||||
factorWt,
|
||||
factorB,
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||
else
|
||||
}
|
||||
else {
|
||||
factorLogits = affineOrDot(
|
||||
input1,
|
||||
factorWt,
|
||||
factorB,
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||
}
|
||||
|
||||
// optionally add lemma-dependent bias
|
||||
if(Plemma) { // [B... x U0]
|
||||
@ -272,14 +285,14 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
}
|
||||
return Logits(std::move(allLogits), factoredVocab_);
|
||||
} else if(shortlist_) {
|
||||
return Logits(affineOrLSH(input,
|
||||
return Logits(affineOrDot(input,
|
||||
shortlist_->getCachedShortWt(),
|
||||
shortlist_->getCachedShortb(),
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? false : true));
|
||||
} else {
|
||||
return Logits(
|
||||
affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
|
||||
affineOrDot(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user