mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
start using bdot
This commit is contained in:
parent
bc4ad2408c
commit
0949a4c914
@ -66,13 +66,17 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
//std::cerr << "x=" << x->shape() << std::endl;
|
||||
//std::cerr << "W=" << W->shape() << std::endl;
|
||||
//std::cerr << "transA=" << transA << " transB=" << transB << std::endl;
|
||||
|
||||
/*
|
||||
Expr ret = x * W;
|
||||
ret = sum(ret, 3);
|
||||
//const Shape &retShape = ret->shape();
|
||||
//std::cerr << "ret.1=" << retShape << std::endl;
|
||||
ret = transpose(ret, {0, 3, 2, 1});
|
||||
//ret = reshape(ret, {retShape[0], 1, 1, retShape[2]});
|
||||
*/
|
||||
x = transpose(x, {0, 2, 1, 3});
|
||||
W = transpose(W, {0, 2, 1, 3});
|
||||
Expr ret = bdot(x, W, false, true);
|
||||
|
||||
//std::cerr << "ret.2=" << ret->shape() << std::endl;
|
||||
return ret;
|
||||
};
|
||||
@ -258,10 +262,10 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
const Shape &s = lemmaEt_->shape();
|
||||
cachedShortLemmaEt = reshape(lemmaEt_, {1, s[0], 1, s[1]});
|
||||
}
|
||||
//std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
|
||||
//std::cerr << "cachedShortLemmaEt=" << cachedShortLemmaEt->shape() << std::endl;
|
||||
Expr e = factorSoftmax * cachedShortLemmaEt;
|
||||
std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
|
||||
std::cerr << "cachedShortLemmaEt=" << cachedShortLemmaEt->shape() << std::endl;
|
||||
/*
|
||||
Expr e = factorSoftmax * cachedShortLemmaEt;
|
||||
factorSoftmax= beam x 1 x batch x vocab
|
||||
cachedShortLemmaEt= 1 x 10 x 1 x vocab
|
||||
e= beam x 10 x batch x vocab
|
||||
@ -270,11 +274,21 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
std::cerr << "cachedShortLemmaEt=" << cachedShortLemmaEt->shape() << std::endl;
|
||||
std::cerr << "e=" << e->shape() << std::endl;
|
||||
std::cerr << std::endl;
|
||||
*/
|
||||
e = sum(e, 3);
|
||||
//std::cerr << "e.2=" << e->shape() << std::endl;
|
||||
e = transpose(e, {0, 3, 2, 1});
|
||||
//std::cerr << "e.3=" << e->shape() << std::endl;
|
||||
*/
|
||||
factorSoftmax = transpose(factorSoftmax, {0, 2, 1, 3});
|
||||
cachedShortLemmaEt = transpose(cachedShortLemmaEt, {0, 2, 1, 3});
|
||||
std::cerr << "factorSoftmax=" << factorSoftmax->shape() << std::endl;
|
||||
std::cerr << "cachedShortLemmaEt=" << cachedShortLemmaEt->shape() << std::endl;
|
||||
|
||||
Expr e = bdot(factorSoftmax, cachedShortLemmaEt, false, true);
|
||||
std::cerr << "e.1=" << e->shape() << std::endl;
|
||||
const Shape &eShape = e->shape();
|
||||
e = reshape(e, {eShape[0], 1, eShape[1], eShape[3]});
|
||||
std::cerr << "e.3=" << e->shape() << std::endl;
|
||||
std::cerr << std::endl;
|
||||
|
||||
// project it back to regular hidden dim
|
||||
int inputDim = input1->shape()[-1];
|
||||
|
Loading…
Reference in New Issue
Block a user