clean up bias

This commit is contained in:
Hieu Hoang 2021-06-16 11:19:23 -07:00
parent 892554129e
commit 9b4a845cc7
2 changed files with 18 additions and 14 deletions

View File

@ -184,9 +184,9 @@ void LSHShortlist::createCachedTensors(Expr weights,
cachedShortWt_ = reshape(cachedShortWt_, {currBeamSize, batchSize, k, cachedShortWt_->shape()[1]});
if (b) {
ABORT("Bias not yet tested");
ABORT("Bias not supported with LSH");
cachedShortb_ = index_select(b, -1, indicesExprFlatten);
cachedShortb_ = reshape(cachedShortb_, {currBeamSize, k, batchSize, cachedShortb_->shape()[1]}); // not tested
cachedShortb_ = reshape(cachedShortb_, {currBeamSize, batchSize, k, cachedShortb_->shape()[0]}); // not tested
}
if (lemmaEt) {

View File

@ -56,16 +56,28 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
lazyConstruct(input->shape()[-1]);
auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
  // Multiplies x by W and adds the bias b when one is supplied; with an empty b
  // this is a plain product. transA/transB are passed through untouched.
  /*
  Shape tracing, kept disabled. b may be empty (see the check below), so guard
  the b line before switching this back on.
  std::cerr << "affineOrDot.x=" << x->shape() << std::endl;
  std::cerr << "affineOrDot.W=" << W->shape() << std::endl;
  std::cerr << "affineOrDot.b=" << b->shape() << std::endl;
  std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl;
  */
  return b ? affine(x, W, b, transA, transB)
           : dot(x, W, transA, transB);
};
auto affineShortlist = [](Expr x, Expr W, Expr b, bool , bool ) {
//std::cerr << "x=" << x->shape() << std::endl;
//std::cerr << "W=" << W->shape() << std::endl;
Expr ret = bdot(x, W, false, true);
auto affineShortlist = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
/*
std::cerr << "affineShortlist.x=" << x->shape() << std::endl;
std::cerr << "affineShortlist.W=" << W->shape() << std::endl;
std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl;
*/
ABORT_IF(!(!transA && transB), "Must be transA==0 and transB==1");
ABORT_IF(b, "affineShortlist not tested with bias");
Expr ret = bdot(x, W, transA, transB);
//std::cerr << "ret.2=" << ret->shape() << std::endl;
//std::cerr << std::endl;
@ -171,11 +183,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
// matrix
Expr factorLogits;
if(g == 0 && shortlist_) {
//std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl;
Expr tmp = transpose(input1, {0, 2, 1, 3});
//std::cerr << "tmp=" << tmp->shape() << std::endl;
//std::cerr << "x=" << x->shape() << std::endl;
//std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl;
factorLogits = affineShortlist(
tmp,
factorWt,
@ -183,18 +191,14 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
factorLogits = transpose(factorLogits, {0, 2, 1, 3});
//std::cerr << "affineShortlist.factorLogits=" << factorLogits->shape() << std::endl << std::endl;
}
else {
//std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl;
//std::cerr << "affineOrDot.factorWt=" << factorWt->shape() << std::endl;
factorLogits = affineOrDot(
input1,
factorWt,
factorB,
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
//std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl << std::endl;
}
// optionally add lemma-dependent bias