This commit is contained in:
Hieu Hoang 2021-06-08 18:20:56 -07:00
parent 92c6c07786
commit e07e0368c9

View File

@ -63,23 +63,15 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
};
// Shortlist variant of the output projection: reorders the axes of x with
// transpose(x, {0, 2, 1, 3}) and returns bdot(x, W, false, true).
// The bias argument b and the two unnamed bool arguments are accepted but never
// used in this body (presumably so the call signature parallels affineOrDot — TODO confirm).
// NOTE(review): this text comes from a diff view whose +/- markers are missing, so the
// commented-out and the active std::cerr lines below are likely the old and new
// versions of the same statements — confirm against the real file before relying on it.
auto affineShortlist = [](Expr x, Expr W, Expr b, bool , bool ) {
//std::cerr << "x=" << x->shape() << std::endl;
//std::cerr << "W=" << W->shape() << std::endl;
//std::cerr << "transA=" << transA << " transB=" << transB << std::endl;
// Disabled alternative formulation kept in a block comment: product x * W,
// then sum(ret, 3), then transpose(ret, {0, 3, 2, 1}).
/*
Expr ret = x * W;
ret = sum(ret, 3);
//const Shape &retShape = ret->shape();
//std::cerr << "ret.1=" << retShape << std::endl;
ret = transpose(ret, {0, 3, 2, 1});
*/
// Debug trace: shapes of x and W as passed in, written to stderr on every call.
std::cerr << "x=" << x->shape() << std::endl;
std::cerr << "W=" << W->shape() << std::endl;
// Swap axes 1 and 2 of x (axis order {0, 2, 1, 3}) before the batched product.
x = transpose(x, {0, 2, 1, 3});
//std::cerr << "x=" << x->shape() << std::endl;
//std::cerr << "W=" << W->shape() << std::endl;
// Debug trace: shape of x after the transpose; W has not been reassigned since the previous print.
std::cerr << "x=" << x->shape() << std::endl;
std::cerr << "W=" << W->shape() << std::endl;
// The literal false/true are presumably bdot's transA/transB flags (i.e. W is used
// transposed) — confirm against bdot's declaration. The caller's bool arguments are not forwarded.
Expr ret = bdot(x, W, false, true);
//std::cerr << "ret.2=" << ret->shape() << std::endl;
//std::cerr << std::endl;
// Debug trace: shape of the product, followed by a blank separator line.
std::cerr << "ret.2=" << ret->shape() << std::endl;
std::cerr << std::endl;
return ret;
};
@ -182,20 +174,28 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
// matrix
Expr factorLogits;
if(g == 0 && shortlist_) {
//std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl;
//std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl;
factorLogits = affineShortlist(
input1,
factorWt,
factorB,
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
//std::cerr << "affineShortlist.factorLogits.1=" << factorLogits->shape() << std::endl;
factorLogits = transpose(factorLogits, {0, 2, 1, 3});
//std::cerr << "affineShortlist.factorLogits.2=" << factorLogits->shape() << std::endl;
}
else {
//std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl;
//std::cerr << "affineOrDot.factorWt=" << factorWt->shape() << std::endl;
factorLogits = affineOrDot(
input1,
factorWt,
factorB,
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
//std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl;
}
// optionally add lemma-dependent bias
@ -270,6 +270,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
//std::cerr << "cachedShortLemmaEt.2=" << cachedShortLemmaEt->shape() << std::endl;
factorSoftmax = transpose(factorSoftmax, {0, 2, 1, 3});
//std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
//std::cerr << "cachedShortLemmaEt.2=" << cachedShortLemmaEt->shape() << std::endl;
Expr e = bdot(factorSoftmax, cachedShortLemmaEt, false, true);
//std::cerr << "e.1=" << e->shape() << std::endl;