start using bdot

Author: Hieu Hoang
Date:   2021-06-07 15:05:56 -07:00
parent bc4ad2408c
commit 0949a4c914


@@ -66,13 +66,17 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
//std::cerr << "x=" << x->shape() << std::endl;
//std::cerr << "W=" << W->shape() << std::endl;
//std::cerr << "transA=" << transA << " transB=" << transB << std::endl;
/*
Expr ret = x * W;
ret = sum(ret, 3);
//const Shape &retShape = ret->shape();
//std::cerr << "ret.1=" << retShape << std::endl;
ret = transpose(ret, {0, 3, 2, 1});
//ret = reshape(ret, {retShape[0], 1, 1, retShape[2]});
*/
x = transpose(x, {0, 2, 1, 3});
W = transpose(W, {0, 2, 1, 3});
Expr ret = bdot(x, W, false, true);
//std::cerr << "ret.2=" << ret->shape() << std::endl;
return ret;
};
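
For context: bdot(A, B, /*transA=*/false, /*transB=*/true) is a batched matrix product C[b] = A[b] * B[b]^T, so after the two transposes it performs the same contraction over the hidden dim that the commented-out "x * W" followed by sum(ret, 3) performed by broadcasting. A minimal standalone sketch of that equivalence in plain C++ (illustrative sizes and naive loops, not Marian's actual kernels):

    #include <cstdio>
    #include <vector>

    int main() {
      const int batch = 2, m = 3, n = 4, dim = 5;  // hypothetical sizes
      std::vector<float> A(batch * m * dim), B(batch * n * dim);
      for (size_t i = 0; i < A.size(); ++i) A[i] = 0.01f * (i % 7);
      for (size_t i = 0; i < B.size(); ++i) B[i] = 0.02f * (i % 5);

      // Old path (cf. the commented-out "x * W" plus sum(ret, 3)):
      // broadcasting materializes a batch x m x n x dim product tensor,
      // which is then reduced over the trailing dim axis.
      std::vector<float> P(batch * m * n * dim);
      std::vector<float> C1(batch * m * n, 0.f);
      for (int b = 0; b < batch; ++b)
        for (int i = 0; i < m; ++i)
          for (int j = 0; j < n; ++j)
            for (int k = 0; k < dim; ++k) {
              size_t p = ((size_t)((b * m + i) * n + j)) * dim + k;
              P[p] = A[(b * m + i) * dim + k] * B[(b * n + j) * dim + k];
              C1[(b * m + i) * n + j] += P[p];
            }

      // New path (cf. bdot(x, W, /*transA=*/false, /*transB=*/true)):
      // a batched GEMM C[b] = A[b] * B[b]^T fuses the multiply and the
      // reduction, never materializing the intermediate tensor.
      std::vector<float> C2(batch * m * n, 0.f);
      for (int b = 0; b < batch; ++b)
        for (int i = 0; i < m; ++i)
          for (int j = 0; j < n; ++j)
            for (int k = 0; k < dim; ++k)
              C2[(b * m + i) * n + j] +=
                  A[(b * m + i) * dim + k] * B[(b * n + j) * dim + k];

      std::printf("outputs match: %d\n", C1 == C2);  // prints 1
    }

Avoiding the materialized broadcast tensor is the apparent point of the rewrite: the fused batched GEMM does the same arithmetic in one pass.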
@@ -258,10 +262,10 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
const Shape &s = lemmaEt_->shape();
cachedShortLemmaEt = reshape(lemmaEt_, {1, s[0], 1, s[1]});
}
//std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
//std::cerr << "cachedShortLemmaEt=" << cachedShortLemmaEt->shape() << std::endl;
Expr e = factorSoftmax * cachedShortLemmaEt;
std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
std::cerr << "cachedShortLemmaEt=" << cachedShortLemmaEt->shape() << std::endl;
/*
Expr e = factorSoftmax * cachedShortLemmaEt;
factorSoftmax= beam x 1 x batch x vocab
cachedShortLemmaEt= 1 x 10 x 1 x vocab
e= beam x 10 x batch x vocab
@@ -270,11 +274,21 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
std::cerr << "cachedShortLemmaEt=" << cachedShortLemmaEt->shape() << std::endl;
std::cerr << "e=" << e->shape() << std::endl;
std::cerr << std::endl;
*/
e = sum(e, 3);
//std::cerr << "e.2=" << e->shape() << std::endl;
e = transpose(e, {0, 3, 2, 1});
//std::cerr << "e.3=" << e->shape() << std::endl;
*/
factorSoftmax = transpose(factorSoftmax, {0, 2, 1, 3});
cachedShortLemmaEt = transpose(cachedShortLemmaEt, {0, 2, 1, 3});
std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
std::cerr << "cachedShortLemmaEt=" << cachedShortLemmaEt->shape() << std::endl;
Expr e = bdot(factorSoftmax, cachedShortLemmaEt, false, true);
std::cerr << "e.1=" << e->shape() << std::endl;
const Shape &eShape = e->shape();
e = reshape(e, {eShape[0], 1, eShape[1], eShape[3]});
std::cerr << "e.3=" << e->shape() << std::endl;
std::cerr << std::endl;
// project it back to regular hidden dim
int inputDim = input1->shape()[-1];
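
The lemma-embedding product follows the same pattern, with one extra step: bdot leaves the result as beam x batch x 1 x 10, so the reshape reinserts the singleton axis to restore the beam x 1 x batch x 10 layout the old sum-and-transpose path produced. A shape trace under the shapes noted in the comments above (illustrative sizes; bdot's broadcasting over the leading axes is assumed from how it is used here):

    #include <algorithm>
    #include <array>
    #include <cstdio>

    using Shape4 = std::array<int, 4>;

    // Axis permutation, mirroring transpose(x, {0, 2, 1, 3}) above.
    static Shape4 transpose4(const Shape4& s, const std::array<int, 4>& p) {
      return {s[p[0]], s[p[1]], s[p[2]], s[p[3]]};
    }

    // Result shape of bdot(a, b, false, true): the last two axes are the
    // matrix dims (contraction over the shared last axis, since transB is
    // true); the leading two axes are assumed to broadcast.
    static Shape4 bdotShape(const Shape4& a, const Shape4& b) {
      return {std::max(a[0], b[0]), std::max(a[1], b[1]), a[2], b[2]};
    }

    int main() {
      Shape4 factorSoftmax      = {6, 1, 2, 32000};   // beam x 1 x batch x vocab
      Shape4 cachedShortLemmaEt = {1, 10, 1, 32000};  // 1 x 10 x 1 x vocab

      Shape4 A = transpose4(factorSoftmax,      {0, 2, 1, 3}); // 6 x 2 x 1 x 32000
      Shape4 B = transpose4(cachedShortLemmaEt, {0, 2, 1, 3}); // 1 x 1 x 10 x 32000
      Shape4 e = bdotShape(A, B);                              // 6 x 2 x 1 x 10

      // reshape(e, {eShape[0], 1, eShape[1], eShape[3]}) is legal here
      // because the axis being moved is a singleton; it restores the
      // beam x 1 x batch x 10 layout without touching memory order.
      Shape4 out = {e[0], 1, e[1], e[3]};                      // 6 x 1 x 2 x 10
      std::printf("e = %d x %d x %d x %d, out = %d x %d x %d x %d\n",
                  e[0], e[1], e[2], e[3], out[0], out[1], out[2], out[3]);
    }

Note that a reshape (rather than another transpose) suffices only because the swapped axes include a singleton; with a real factor axis in position 2, a transpose would be required.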