mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 19864: add bias if it exists
Fixes backcompat with shortlist and bias.
This commit is contained in:
parent
056c4bef5b
commit
f6cb1b5c6a
@ -43,6 +43,7 @@ public:
|
||||
Shortlist(const std::vector<WordIndex>& indices);
|
||||
virtual ~Shortlist();
|
||||
|
||||
virtual bool isDynamic() const { return false; }
|
||||
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const;
|
||||
virtual WordIndex tryForwardMap(WordIndex wIdx) const;
|
||||
|
||||
@ -87,6 +88,8 @@ private:
|
||||
|
||||
public:
|
||||
LSHShortlist(int k, int nbits, size_t lemmaSize, bool abortIfDynamic = false);
|
||||
|
||||
virtual bool isDynamic() const override { return true; }
|
||||
virtual WordIndex reverseMap(int beamIdx, int batchIdx, int idx) const override;
|
||||
|
||||
virtual void filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) override;
|
||||
|
@ -59,7 +59,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
/*
|
||||
std::cerr << "affineOrDot.x=" << x->shape() << std::endl;
|
||||
std::cerr << "affineOrDot.W=" << W->shape() << std::endl;
|
||||
std::cerr << "affineOrDot.b=" << b->shape() << std::endl;
|
||||
if (b) std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
|
||||
std::cerr << "affineOrDot.transA=" << transA << " transB=" << transB << std::endl;
|
||||
*/
|
||||
if(b)
|
||||
@ -68,18 +68,32 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
return dot(x, W, transA, transB);
|
||||
};
|
||||
|
||||
auto affineShortlist = [](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
/*
|
||||
auto affineShortlist = [this](Expr x, Expr W, Expr b, bool transA, bool transB) {
|
||||
/*
|
||||
std::cerr << "affineShortlist.x=" << x->shape() << std::endl;
|
||||
std::cerr << "affineShortlist.W=" << W->shape() << std::endl;
|
||||
std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
|
||||
if (b) std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
|
||||
std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl;
|
||||
*/
|
||||
ABORT_IF(!(!transA && transB), "affineShortlist. Must be transA==0 and transB==1");
|
||||
ABORT_IF(b, "affineShortlist not tested with bias");
|
||||
Expr ret = bdot(x, W, transA, transB);
|
||||
//std::cerr << "ret=" << ret->shape() << std::endl;
|
||||
//std::cerr << std::endl;
|
||||
|
||||
Expr ret;
|
||||
|
||||
if (b) {
|
||||
// original shortlist. W always has 1 for beam & batch
|
||||
ABORT_UNLESS(!shortlist_->isDynamic(), "affineShortlist. Bias not supported with LSH/dynamic shortlist"); // todo rename ABORT_UNLESS to ASSERT
|
||||
ret = affine(x, W, b, transA, transB);
|
||||
}
|
||||
else if (shortlist_->isDynamic()) {
|
||||
// LSH produces W entry for each beam and batch => need bdot()
|
||||
ABORT_IF(!(!transA && transB), "affineShortlist. Only tested with transA==0 and transB==1");
|
||||
ret = bdot(x, W, transA, transB);
|
||||
}
|
||||
else {
|
||||
// original shortlist. W always has 1 for beam & batch
|
||||
ret = dot(x, W, transA, transB);
|
||||
}
|
||||
|
||||
//std::cerr << "ret.x=" << ret->shape() << std::endl;
|
||||
return ret;
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user