mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
rename broadcast -> createCachedTensors
This commit is contained in:
parent
cc295938ce
commit
dffbb47eea
@ -48,7 +48,7 @@ void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Exp
|
||||
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
|
||||
|
||||
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
|
||||
broadcast(weights, isLegacyUntransposedW, b, lemmaEt, k);
|
||||
createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k);
|
||||
done_ = true;
|
||||
}
|
||||
|
||||
@ -58,7 +58,7 @@ Expr Shortlist::getIndicesExpr(int batchSize, int beamSize) const {
|
||||
return out;
|
||||
}
|
||||
|
||||
void Shortlist::broadcast(Expr weights,
|
||||
void Shortlist::createCachedTensors(Expr weights,
|
||||
bool isLegacyUntransposedW,
|
||||
Expr b,
|
||||
Expr lemmaEt,
|
||||
@ -182,7 +182,7 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
|
||||
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
|
||||
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
|
||||
|
||||
broadcast(weights, isLegacyUntransposedW, b, lemmaEt, k_);
|
||||
createCachedTensors(weights, isLegacyUntransposedW, b, lemmaEt, k_);
|
||||
|
||||
#else
|
||||
input; weights; isLegacyUntransposedW; b; lemmaEt;
|
||||
@ -190,12 +190,11 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
|
||||
#endif
|
||||
}
|
||||
|
||||
void LSHShortlist::broadcast(Expr weights,
|
||||
void LSHShortlist::createCachedTensors(Expr weights,
|
||||
bool isLegacyUntransposedW,
|
||||
Expr b,
|
||||
Expr lemmaEt,
|
||||
int k) {
|
||||
std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
|
||||
int currBeamSize = indicesExpr_->shape()[0];
|
||||
int batchSize = indicesExpr_->shape()[1];
|
||||
//int numHypos = batchSize * currBeamSize;
|
||||
@ -205,13 +204,9 @@ void LSHShortlist::broadcast(Expr weights,
|
||||
ABORT_IF(isLegacyUntransposedW, "Legacy untranspose W not yet tested");
|
||||
|
||||
Expr indicesExprFlatten = reshape(indicesExpr_, {indicesExpr_->shape().elements()});
|
||||
std::cerr << "indicesExprFlatten=" << indicesExprFlatten->shape() << std::endl;
|
||||
|
||||
std::cerr << "weights=" << weights->shape() << std::endl;
|
||||
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indicesExprFlatten);
|
||||
std::cerr << "cachedShortWt.1_=" << cachedShortWt_->shape() << std::endl;
|
||||
cachedShortWt_ = reshape(cachedShortWt_, {currBeamSize, batchSize, k, cachedShortWt_->shape()[1]});
|
||||
std::cerr << "cachedShortWt.2_=" << cachedShortWt_->shape() << std::endl;
|
||||
|
||||
if (b) {
|
||||
ABORT("Bias not yet tested");
|
||||
@ -219,14 +214,10 @@ void LSHShortlist::broadcast(Expr weights,
|
||||
cachedShortb_ = reshape(cachedShortb_, {currBeamSize, k, batchSize, cachedShortb_->shape()[1]}); // not tested
|
||||
}
|
||||
|
||||
std::cerr << "lemmaEt=" << lemmaEt->shape() << std::endl;
|
||||
int dim = lemmaEt->shape()[0];
|
||||
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExprFlatten);
|
||||
std::cerr << "cachedShortLemmaEt.1_=" << cachedShortLemmaEt_->shape() << std::endl;
|
||||
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {dim, currBeamSize, batchSize, k});
|
||||
std::cerr << "cachedShortLemmaEt.2_=" << cachedShortLemmaEt_->shape() << std::endl;
|
||||
cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {1, 2, 0, 3});
|
||||
std::cerr << "cachedShortLemmaEt.3_=" << cachedShortLemmaEt_->shape() << std::endl;
|
||||
}
|
||||
|
||||
LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits)
|
||||
|
@ -31,7 +31,7 @@ protected:
|
||||
Expr cachedShortLemmaEt_;
|
||||
bool done_;
|
||||
|
||||
void broadcast(Expr weights,
|
||||
void createCachedTensors(Expr weights,
|
||||
bool isLegacyUntransposedW,
|
||||
Expr b,
|
||||
Expr lemmaEt,
|
||||
@ -73,7 +73,7 @@ private:
|
||||
|
||||
static Ptr<faiss::IndexLSH> index_;
|
||||
|
||||
void broadcast(Expr weights,
|
||||
void createCachedTensors(Expr weights,
|
||||
bool isLegacyUntransposedW,
|
||||
Expr b,
|
||||
Expr lemmaEt,
|
||||
|
Loading…
Reference in New Issue
Block a user