mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
clean up bias
This commit is contained in:
parent
892554129e
commit
9b4a845cc7
@ -184,9 +184,9 @@ void LSHShortlist::createCachedTensors(Expr weights,
|
||||
cachedShortWt_ = reshape(cachedShortWt_, {currBeamSize, batchSize, k, cachedShortWt_->shape()[1]});
|
||||
|
||||
if (b) {
|
||||
ABORT("Bias not yet tested");
|
||||
ABORT("Bias not supported with LSH");
|
||||
cachedShortb_ = index_select(b, -1, indicesExprFlatten);
|
||||
cachedShortb_ = reshape(cachedShortb_, {currBeamSize, k, batchSize, cachedShortb_->shape()[1]}); // not tested
|
||||
cachedShortb_ = reshape(cachedShortb_, {currBeamSize, batchSize, k, cachedShortb_->shape()[0]}); // not tested
|
||||
}
|
||||
|
||||
if (lemmaEt) {
|
||||
|
@ -56,16 +56,28 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
lazyConstruct(input->shape()[-1]);
|
||||
|
||||
auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
/*
|
||||
std::cerr << "affineOrDot.x=" << x->shape() << std::endl;
|
||||
std::cerr << "affineOrDot.W=" << W->shape() << std::endl;
|
||||
std::cerr << "affineOrDot.b=" << b->shape() << std::endl;
|
||||
std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl;
|
||||
*/
|
||||
if(b)
|
||||
return affine(x, W, b, transA, transB);
|
||||
else
|
||||
return dot(x, W, transA, transB);
|
||||
};
|
||||
|
||||
auto affineShortlist = [](Expr x, Expr W, Expr b, bool , bool ) {
|
||||
//std::cerr << "x=" << x->shape() << std::endl;
|
||||
//std::cerr << "W=" << W->shape() << std::endl;
|
||||
Expr ret = bdot(x, W, false, true);
|
||||
auto affineShortlist = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
/*
|
||||
std::cerr << "affineShortlist.x=" << x->shape() << std::endl;
|
||||
std::cerr << "affineShortlist.W=" << W->shape() << std::endl;
|
||||
std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
|
||||
std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl;
|
||||
*/
|
||||
ABORT_IF(!(!transA && transB), "Must be transA==0 and transB==1");
|
||||
ABORT_IF(b, "affineShortlist not tested with bias");
|
||||
Expr ret = bdot(x, W, transA, transB);
|
||||
|
||||
//std::cerr << "ret.2=" << ret->shape() << std::endl;
|
||||
//std::cerr << std::endl;
|
||||
@ -171,11 +183,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
// matrix
|
||||
Expr factorLogits;
|
||||
if(g == 0 && shortlist_) {
|
||||
//std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl;
|
||||
Expr tmp = transpose(input1, {0, 2, 1, 3});
|
||||
//std::cerr << "tmp=" << tmp->shape() << std::endl;
|
||||
//std::cerr << "x=" << x->shape() << std::endl;
|
||||
//std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl;
|
||||
factorLogits = affineShortlist(
|
||||
tmp,
|
||||
factorWt,
|
||||
@ -183,18 +191,14 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||
factorLogits = transpose(factorLogits, {0, 2, 1, 3});
|
||||
//std::cerr << "affineShortlist.factorLogits=" << factorLogits->shape() << std::endl << std::endl;
|
||||
}
|
||||
else {
|
||||
//std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl;
|
||||
//std::cerr << "affineOrDot.factorWt=" << factorWt->shape() << std::endl;
|
||||
factorLogits = affineOrDot(
|
||||
input1,
|
||||
factorWt,
|
||||
factorB,
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||
//std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl << std::endl;
|
||||
}
|
||||
|
||||
// optionally add lemma-dependent bias
|
||||
|
Loading…
Reference in New Issue
Block a user