cachedShortWt_ works

Hieu Hoang 2021-04-29 11:59:25 -07:00
parent d41353eeb7
commit e518fc9666
2 changed files with 32 additions and 31 deletions


@@ -32,8 +32,8 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
//if (indicesExpr_) return;
int currBeamSize = input->shape()[0];
int batchSize = input->shape()[2];
std::cerr << "currBeamSize=" << currBeamSize << std::endl;
std::cerr << "batchSize=" << batchSize << std::endl;
//std::cerr << "currBeamSize=" << currBeamSize << std::endl;
//std::cerr << "batchSize=" << batchSize << std::endl;
auto forward = [this](Expr out, const std::vector<Expr>& inputs) {
out->val()->set(indices_);
@@ -91,31 +91,12 @@ void Shortlist::broadcast(Expr weights,
//int numHypos = batchSize * currBeamSize;
//std::cerr << "batchSize=" << batchSize << std::endl;
//std::cerr << "currBeamSize=" << currBeamSize << std::endl;
std::cerr << "isLegacyUntransposedW=" << isLegacyUntransposedW << std::endl;
ABORT_IF(!isLegacyUntransposedW, "Legacy untransposed W not yet tested");
///*
std::cerr << "weights=" << weights->shape() << std::endl;
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indices());
std::cerr << "cachedShortWt_=" << cachedShortWt_->shape() << std::endl;
if (b) {
ABORT("Bias not yet tested");
cachedShortb_ = index_select(b, -1, indices());
}
//std::cerr << "isLegacyUntransposedW=" << isLegacyUntransposedW << std::endl;
ABORT_IF(isLegacyUntransposedW, "Legacy untransposed W not yet tested");
indicesExprBC = reshape(indicesExprBC, {indicesExprBC->shape().elements()});
//std::cerr << "indicesExprBC.2=" << indicesExprBC->shape() << std::endl;
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExprBC);
//std::cerr << "cachedShortLemmaEt.1_=" << cachedShortLemmaEt_->shape() << std::endl;
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {cachedShortLemmaEt_->shape()[0], batchSize, currBeamSize, k});
//std::cerr << "cachedShortLemmaEt.2_=" << cachedShortLemmaEt_->shape() << std::endl;
cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {2, 0, 1, 3});
//std::cerr << "cachedShortLemmaEt.3_=" << cachedShortLemmaEt_->shape() << std::endl;
return;
//*/
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indicesExprBC);
//std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
cachedShortWt_ = reshape(cachedShortWt_, {batchSize, currBeamSize, k, cachedShortWt_->shape()[1]});
@@ -124,10 +105,17 @@ void Shortlist::broadcast(Expr weights,
//std::cerr << "cachedShortWt_.3=" << cachedShortWt_->shape() << std::endl;
if (b) {
assert(false);
ABORT("Bias not yet tested");
cachedShortb_ = index_select(b, -1, indicesExprBC);
cachedShortb_ = reshape(cachedShortb_, {currBeamSize, k, batchSize, cachedShortb_->shape()[1]}); // not tested
}
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExprBC);
//std::cerr << "cachedShortLemmaEt.1_=" << cachedShortLemmaEt_->shape() << std::endl;
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {cachedShortLemmaEt_->shape()[0], batchSize, currBeamSize, k});
//std::cerr << "cachedShortLemmaEt.2_=" << cachedShortLemmaEt_->shape() << std::endl;
cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {2, 0, 1, 3});
//std::cerr << "cachedShortLemmaEt.3_=" << cachedShortLemmaEt_->shape() << std::endl;
}
//////////////////////////////////////////////////////////////////////////////////////
QuicksandShortlistGenerator::QuicksandShortlistGenerator(Ptr<Options> options,
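
For orientation, here is a plain-CPU sketch of the gather that the index_select/reshape calls in the hunks above perform: one weight row is copied per shortlist entry, and the result is then viewed as [batchSize, currBeamSize, k, dim]. The names, row-major layout, and sizes are assumptions chosen for illustration, not taken from the commit.

#include <cstddef>
#include <vector>

// Sketch only: gather one row of the output weight matrix per shortlist index.
// weights is assumed row-major [vocab x dim]; flatIndices holds
// batchSize * currBeamSize * k vocabulary ids (the flattened indicesExprBC).
std::vector<float> gatherShortlistRows(const std::vector<float>& weights,
                                       std::size_t dim,
                                       const std::vector<int>& flatIndices) {
  std::vector<float> out(flatIndices.size() * dim);
  for (std::size_t i = 0; i < flatIndices.size(); ++i)
    for (std::size_t d = 0; d < dim; ++d)
      out[i * dim + d] = weights[flatIndices[i] * dim + d];
  // Logically this buffer is then reshaped to [batchSize, currBeamSize, k, dim],
  // which is what cachedShortWt_ holds before any further transposes.
  return out;
}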


@@ -62,8 +62,19 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
return dot(x, W, transA, transB);
};
auto affineOrLSH = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
return affineOrDot(x, W, b, transA, transB);
auto affineShortlist = [this, affineOrDot](Expr x, Expr W, Expr b, bool transA, bool transB) {
//std::cerr << "x=" << x->shape() << std::endl;
//std::cerr << "W=" << W->shape() << std::endl;
//std::cerr << "transA=" << transA << " transB=" << transB << std::endl;
Expr ret = x * W;
ret = sum(ret, 3);
//const Shape &retShape = ret->shape();
//std::cerr << "ret.1=" << retShape << std::endl;
ret = transpose(ret, {0, 3, 2, 1});
//ret = reshape(ret, {retShape[0], 1, 1, retShape[2]});
//std::cerr << "ret.2=" << ret->shape() << std::endl;
return ret;
};
if(shortlist_) { // shortlisted versions of parameters are cached within one
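
The affineShortlist lambda above replaces the GEMM with a broadcasted elementwise product followed by a sum over axis 3; per shortlist entry that reduction is exactly a dot product between the decoder state and the gathered weight row. A minimal standalone sketch of that equivalence, with names and sizes chosen for illustration only:

#include <cstddef>
#include <vector>

// Sketch only: multiplying elementwise and summing over the hidden dimension
// yields the same logit an affine/dot path would compute for one shortlist entry.
float dotViaMulAndSum(const std::vector<float>& x,   // one decoder state, [dim]
                      const std::vector<float>& w) { // one gathered weight row, [dim]
  float acc = 0.f;
  for (std::size_t d = 0; d < x.size(); ++d)
    acc += x[d] * w[d];   // x * W, then sum over the last axis
  return acc;             // one shortlisted logit
}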
@@ -164,20 +175,22 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
// @TODO: b_ should be a vector, not a matrix; but shortlists use cols() in, which requires a
// matrix
Expr factorLogits;
if(g == 0)
factorLogits = affineOrLSH(
if(g == 0 && shortlist_) {
factorLogits = affineShortlist(
input1,
factorWt,
factorB,
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
else
}
else {
factorLogits = affineOrDot(
input1,
factorWt,
factorB,
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
}
// optionally add lemma-dependent bias
if(Plemma) { // [B... x U0]
@@ -272,14 +285,14 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
}
return Logits(std::move(allLogits), factoredVocab_);
} else if(shortlist_) {
return Logits(affineOrLSH(input,
return Logits(affineOrDot(input,
shortlist_->getCachedShortWt(),
shortlist_->getCachedShortb(),
false,
/*transB=*/isLegacyUntransposedW ? false : true));
} else {
return Logits(
affineOrLSH(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
affineOrDot(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
}
}
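
The last hunk drops the affineOrLSH wrapper in favor of plain affineOrDot for both the shortlisted and the full output projection. As a reminder of what that fallback computes, here is a minimal sketch assuming a row-major [vocabOut x dim] weight matrix and an optional bias; it illustrates the affine-vs-dot choice and is not Marian's implementation.

#include <cstddef>
#include <vector>

// Sketch only: y = x * Wt^T, plus b when a bias is present (affine), otherwise
// just the dot product. Wt is assumed row-major [vocabOut x dim].
std::vector<float> affineOrDotSketch(const std::vector<float>& x,   // [dim]
                                     const std::vector<float>& Wt,  // [vocabOut * dim]
                                     const std::vector<float>* b,   // optional [vocabOut]
                                     std::size_t dim) {
  const std::size_t vocabOut = Wt.size() / dim;
  std::vector<float> y(vocabOut, 0.f);
  for (std::size_t v = 0; v < vocabOut; ++v) {
    for (std::size_t d = 0; d < dim; ++d)
      y[v] += x[d] * Wt[v * dim + d];   // dot(x, W, /*transB=*/true)
    if (b) y[v] += (*b)[v];             // affine when a bias is given
  }
  return y;
}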