mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
debug
This commit is contained in:
parent
92c6c07786
commit
e07e0368c9
@ -63,23 +63,15 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
};
|
||||
|
||||
auto affineShortlist = [](Expr x, Expr W, Expr b, bool , bool ) {
|
||||
//std::cerr << "x=" << x->shape() << std::endl;
|
||||
//std::cerr << "W=" << W->shape() << std::endl;
|
||||
//std::cerr << "transA=" << transA << " transB=" << transB << std::endl;
|
||||
/*
|
||||
Expr ret = x * W;
|
||||
ret = sum(ret, 3);
|
||||
//const Shape &retShape = ret->shape();
|
||||
//std::cerr << "ret.1=" << retShape << std::endl;
|
||||
ret = transpose(ret, {0, 3, 2, 1});
|
||||
*/
|
||||
std::cerr << "x=" << x->shape() << std::endl;
|
||||
std::cerr << "W=" << W->shape() << std::endl;
|
||||
x = transpose(x, {0, 2, 1, 3});
|
||||
//std::cerr << "x=" << x->shape() << std::endl;
|
||||
//std::cerr << "W=" << W->shape() << std::endl;
|
||||
std::cerr << "x=" << x->shape() << std::endl;
|
||||
std::cerr << "W=" << W->shape() << std::endl;
|
||||
Expr ret = bdot(x, W, false, true);
|
||||
|
||||
//std::cerr << "ret.2=" << ret->shape() << std::endl;
|
||||
//std::cerr << std::endl;
|
||||
std::cerr << "ret.2=" << ret->shape() << std::endl;
|
||||
std::cerr << std::endl;
|
||||
return ret;
|
||||
};
|
||||
|
||||
@ -182,20 +174,28 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
// matrix
|
||||
Expr factorLogits;
|
||||
if(g == 0 && shortlist_) {
|
||||
//std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl;
|
||||
//std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl;
|
||||
factorLogits = affineShortlist(
|
||||
input1,
|
||||
factorWt,
|
||||
factorB,
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||
//std::cerr << "affineShortlist.factorLogits.1=" << factorLogits->shape() << std::endl;
|
||||
factorLogits = transpose(factorLogits, {0, 2, 1, 3});
|
||||
//std::cerr << "affineShortlist.factorLogits.2=" << factorLogits->shape() << std::endl;
|
||||
}
|
||||
else {
|
||||
//std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl;
|
||||
//std::cerr << "affineOrDot.factorWt=" << factorWt->shape() << std::endl;
|
||||
factorLogits = affineOrDot(
|
||||
input1,
|
||||
factorWt,
|
||||
factorB,
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||
//std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl;
|
||||
}
|
||||
|
||||
// optionally add lemma-dependent bias
|
||||
@ -270,6 +270,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
//std::cerr << "cachedShortLemmaEt.2=" << cachedShortLemmaEt->shape() << std::endl;
|
||||
factorSoftmax = transpose(factorSoftmax, {0, 2, 1, 3});
|
||||
//std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
|
||||
//std::cerr << "cachedShortLemmaEt.2=" << cachedShortLemmaEt->shape() << std::endl;
|
||||
|
||||
Expr e = bdot(factorSoftmax, cachedShortLemmaEt, false, true);
|
||||
//std::cerr << "e.1=" << e->shape() << std::endl;
|
||||
|
Loading…
Reference in New Issue
Block a user