mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
filter & broadcast every word. SL works
This commit is contained in:
parent
e07e0368c9
commit
5d1946ebd3
@ -35,9 +35,9 @@ WordIndex Shortlist::tryForwardMap(int , int , WordIndex wIdx) const {
|
||||
}
|
||||
|
||||
void Shortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
|
||||
if (done_) {
|
||||
return;
|
||||
}
|
||||
//if (done_) {
|
||||
// return;
|
||||
//}
|
||||
|
||||
//if (indicesExpr_) return;
|
||||
int currBeamSize = input->shape()[0];
|
||||
@ -109,12 +109,14 @@ void Shortlist::broadcast(Expr weights,
|
||||
indicesExprBC = reshape(indicesExprBC, {indicesExprBC->shape().elements()});
|
||||
//std::cerr << "indicesExprBC.2=" << indicesExprBC->shape() << std::endl;
|
||||
|
||||
std::cerr << "currBeamSize=" << currBeamSize << " batchSize=" << batchSize << std::endl;
|
||||
std::cerr << "weights=" << weights->shape() << std::endl;
|
||||
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indicesExprBC);
|
||||
//std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
|
||||
std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
|
||||
cachedShortWt_ = reshape(cachedShortWt_, {batchSize, currBeamSize, k, cachedShortWt_->shape()[1]});
|
||||
//std::cerr << "cachedShortWt_.2=" << cachedShortWt_->shape() << std::endl;
|
||||
std::cerr << "cachedShortWt_.2=" << cachedShortWt_->shape() << std::endl;
|
||||
cachedShortWt_ = transpose(cachedShortWt_, {1, 0, 2, 3});
|
||||
//std::cerr << "cachedShortWt_.3=" << cachedShortWt_->shape() << std::endl;
|
||||
std::cerr << "cachedShortWt_.3=" << cachedShortWt_->shape() << std::endl;
|
||||
|
||||
if (b) {
|
||||
ABORT("Bias not yet tested");
|
||||
|
Loading…
Reference in New Issue
Block a user