Merged PR 19864: add bias if it exists

Fixes backcompat with shortlist and bias.
This commit is contained in:
Hieu Hoang 2021-07-21 00:12:02 +00:00 committed by Martin Junczys-Dowmunt
parent 056c4bef5b
commit f6cb1b5c6a
2 changed files with 26 additions and 9 deletions

View File

@ -43,6 +43,7 @@ public:
Shortlist(const std::vector<WordIndex>& indices);
virtual ~Shortlist();
virtual bool isDynamic() const { return false; }
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const;
virtual WordIndex tryForwardMap(WordIndex wIdx) const;
@ -87,6 +88,8 @@ private:
public:
LSHShortlist(int k, int nbits, size_t lemmaSize, bool abortIfDynamic = false);
virtual bool isDynamic() const override { return true; }
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override;
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;

View File

@ -59,7 +59,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
/*
std::cerr << "affineOrDot.x=" << x->shape() << std::endl;
std::cerr << "affineOrDot.W=" << W->shape() << std::endl;
std::cerr << "affineOrDot.b=" << b->shape() << std::endl;
if (b) std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl;
*/
if(b)
@ -68,18 +68,32 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
return dot(x, W, transA, transB);
};
auto affineShortlist = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
/*
auto affineShortlist = [this](Expr x, Expr W, Expr b, bool transA, bool transB) {
/*
std::cerr << "affineShortlist.x=" << x->shape() << std::endl;
std::cerr << "affineShortlist.W=" << W->shape() << std::endl;
std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
if (b) std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl;
*/
ABORT_IF(!(!transA && transB), "affineShortlist. Must be transA==0 and transB==1");
ABORT_IF(b, "affineShortlist not tested with bias");
Expr ret = bdot(x, W, transA, transB);
//std::cerr << "ret=" << ret->shape() << std::endl;
//std::cerr << std::endl;
Expr ret;
if (b) {
// original shortlist. W always has 1 for beam & batch
ABORT_UNLESS(!shortlist_->isDynamic(), "affineShortlist. Bias not supported with LSH/dynamic shortlist"); // todo rename ABORT_UNLESS to ASSERT
ret = affine(x, W, b, transA, transB);
}
else if (shortlist_->isDynamic()) {
// LSH produces W entry for each beam and batch => need bdot()
ABORT_IF(!(!transA && transB), "affineShortlist. Only tested with transA==0 and transB==1");
ret = bdot(x, W, transA, transB);
}
else {
// original shortlist. W always has 1 for beam & batch
ret = dot(x, W, transA, transB);
}
//std::cerr << "ret.x=" << ret->shape() << std::endl;
return ret;
};