This commit is contained in:
Hieu Hoang 2021-06-04 14:21:26 -07:00
parent 28e5e2260a
commit f19ebbae69

View File

@ -261,7 +261,16 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
//std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
//std::cerr << "cachedShortLemmaEt=" << cachedShortLemmaEt->shape() << std::endl;
Expr e = factorSoftmax * cachedShortLemmaEt;
//std::cerr << "e.1=" << e->shape() << std::endl;
/*
factorSoftmax= beam x 1 x batch x vocab
cachedShortLemmaEt= 1 x 10 x 1 x vocab
e= beam x 10 x batch x vocab
std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
std::cerr << "cachedShortLemmaEt=" << cachedShortLemmaEt->shape() << std::endl;
std::cerr << "e=" << e->shape() << std::endl;
std::cerr << std::endl;
*/
e = sum(e, 3);
//std::cerr << "e.2=" << e->shape() << std::endl;
e = transpose(e, {0, 3, 2, 1});