mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
separate broadcast
This commit is contained in:
parent
5d1946ebd3
commit
0bc9b22b15
@ -238,6 +238,47 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
 * Gathers the shortlisted slices of the output-layer tensors and stores them in
 * the cached members (cachedShortWt_, cachedShortb_, cachedShortLemmaEt_).
 *
 * @param weights               Output-layer weights; entries are gathered along axis 0.
 * @param isLegacyUntransposedW Must be false: the legacy layout aborts ("not yet tested").
 * @param b                     Optional bias; a non-null bias currently aborts ("not yet tested").
 * @param lemmaEt               Lemma tensor; entries are gathered along its last axis.
 * @param indicesExprBC         Shortlist indices. Dimension 0 is read as the batch size and
 *                              dimension 1 as the current beam size; the tensor is flattened
 *                              before being used as the gather index.
 * @param k                     Size of the shortlist dimension used when reshaping the gathered
 *                              tensors (presumably shortlist entries per hypothesis - confirm
 *                              against the caller).
 */
void LSHShortlist::broadcast(Expr weights,
                             bool isLegacyUntransposedW,
                             Expr b,
                             Expr lemmaEt,
                             Expr indicesExprBC,
                             int k) {
  const int batchSize    = indicesExprBC->shape()[0];
  const int currBeamSize = indicesExprBC->shape()[1];

  ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");

  // index_select takes a flat index vector, so collapse all index dimensions into one.
  Expr flatIndices = reshape(indicesExprBC, {indicesExprBC->shape().elements()});

  // Weights: gather -> [batch, beam, k, dim] -> swap the first two axes to [beam, batch, k, dim].
  // The axis expression is kept as in the original so the legacy path can be enabled later;
  // with the ABORT_IF above it always evaluates to 0 here.
  Expr shortWt = index_select(weights, isLegacyUntransposedW ? -1 : 0, flatIndices);
  shortWt = reshape(shortWt, {batchSize, currBeamSize, k, shortWt->shape()[1]});
  cachedShortWt_ = transpose(shortWt, {1, 0, 2, 3});

  if (b) {
    ABORT("Bias not yet tested");
    // NOTE(review): everything below is effectively dead while the ABORT above is in place.
    // The reshape layout {beam, k, batch, ...} also differs from the weight path
    // ({batch, beam, k, ...} followed by a transpose) - verify before enabling.
    cachedShortb_ = index_select(b, -1, flatIndices);
    cachedShortb_ = reshape(cachedShortb_, {currBeamSize, k, batchSize, cachedShortb_->shape()[1]}); // not tested
  }

  // Lemma tensor: gather along the last axis -> [dim0, batch, beam, k] -> reorder to [beam, batch, dim0, k].
  Expr shortLemmaEt = index_select(lemmaEt, -1, flatIndices);
  shortLemmaEt = reshape(shortLemmaEt, {shortLemmaEt->shape()[0], batchSize, currBeamSize, k});
  cachedShortLemmaEt_ = transpose(shortLemmaEt, {2, 1, 0, 3});
}
|
||||
|
||||
LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits)
|
||||
: k_(k), nbits_(nbits) {
|
||||
//std::cerr << "LSHShortlistGenerator" << std::endl;
|
||||
|
@ -74,6 +74,13 @@ private:
|
||||
|
||||
static Ptr<faiss::IndexLSH> index_;
|
||||
|
||||
/// Caches the shortlist-selected slices of the output layer for the given indices.
/// Overrides the base-class hook of the same name.
/// @param weights               output-layer weights to gather from
/// @param isLegacyUntransposedW whether the weights use the legacy untransposed layout
/// @param b                     optional bias (may be null)
/// @param lemmaEt               lemma tensor to gather from
/// @param indicesExprBC         shortlist indices used as the gather index
/// @param k                     shortlist size used when reshaping the gathered tensors
virtual void broadcast(Expr weights, bool isLegacyUntransposedW,
                       Expr b, Expr lemmaEt,
                       Expr indicesExprBC, int k) override;
|
||||
|
||||
public:
|
||||
LSHShortlist(int k, int nbits);
|
||||
virtual WordIndex reverseMap(int batchIdx, int beamIdx, int idx) const override;
|
||||
|
Loading…
Reference in New Issue
Block a user