mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
batch-beam -> beam-batch
This commit is contained in:
parent
5a93c67185
commit
fef7202bc8
@ -129,13 +129,18 @@ LSHShortlist::LSHShortlist(int k, int nbits)
|
||||
//#define BLAS_FOUND 1
|
||||
|
||||
WordIndex LSHShortlist::reverseMap(int batchIdx, int beamIdx, int idx) const {
|
||||
std::cerr << "\nbatchIdx=" << batchIdx << " beamIdx=" << beamIdx << " idx=" << idx << std::endl;
|
||||
std::cerr << "\nbatchIdx=" << batchIdx
|
||||
<< " beamIdx=" << beamIdx
|
||||
<< " idx=" << idx
|
||||
<< " k_=" << k_
|
||||
<< std::endl;
|
||||
std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
|
||||
int currBatchSize = indicesExpr_->shape()[0];
|
||||
int currBeamSize = indicesExpr_->shape()[1];
|
||||
int currBeamSize = indicesExpr_->shape()[0];
|
||||
int currBatchSize = indicesExpr_->shape()[1];
|
||||
std::cerr << "currBatchSize=" << currBatchSize << " currBeamSize=" << currBeamSize << std::endl;
|
||||
std::cerr << "indices_=" << indices_.size() << std::endl;
|
||||
idx = (k_ * currBeamSize) * batchIdx + k_ * beamIdx + idx;
|
||||
idx = (k_ * currBatchSize * beamIdx) + (k_ * batchIdx) + idx;
|
||||
//idx = (k_ * currBeamSize * batchIdx) + (k_ * beamIdx) + idx;
|
||||
std::cerr << "idx=" << idx << std::endl;
|
||||
assert(idx < indices_.size());
|
||||
return indices_[idx];
|
||||
@ -152,13 +157,16 @@ WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const {
|
||||
}
|
||||
|
||||
Expr LSHShortlist::getIndicesExpr(int batchSize, int currBeamSize) const {
|
||||
//std::cerr << "batchSize=" << batchSize << " currBeamSize=" << currBeamSize << std::endl;
|
||||
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << " " << indicesExpr_->val() << std::endl;
|
||||
assert(indicesExpr_->shape()[0] == batchSize);
|
||||
assert(indicesExpr_->shape()[1] == currBeamSize);
|
||||
return indicesExpr_;
|
||||
std::cerr << "batchSize=" << batchSize << " currBeamSize=" << currBeamSize << std::endl;
|
||||
std::cerr << "indicesExpr_=" << indicesExpr_->shape() << " " << indicesExpr_->val() << std::endl;
|
||||
assert(indicesExpr_->shape()[0] == currBeamSize);
|
||||
assert(indicesExpr_->shape()[1] == batchSize);
|
||||
Expr ret = transpose(indicesExpr_, {1, 0, 2});
|
||||
return ret;
|
||||
}
|
||||
|
||||
#define BLAS_FOUND 1
|
||||
|
||||
void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
|
||||
#if BLAS_FOUND
|
||||
static int c = 0;
|
||||
@ -186,6 +194,7 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
|
||||
index_->add( vRows, values->val()->data<float>());
|
||||
}
|
||||
|
||||
std::cerr << "query=" << query->shape() << std::endl;
|
||||
int qRows = query->shape().elements() / dim;
|
||||
std::vector<float> distances(qRows * k_);
|
||||
std::vector<faiss::Index::idx_t> ids(qRows * k_);
|
||||
@ -207,8 +216,8 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
|
||||
out->val()->set(indices_);
|
||||
};
|
||||
|
||||
Shape kShape({batchSize, currBeamSize, k_});
|
||||
//std::cerr << "kShape=" << kShape << std::endl;
|
||||
Shape kShape({currBeamSize, batchSize, k_});
|
||||
std::cerr << "kShape=" << kShape << std::endl;
|
||||
|
||||
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
|
||||
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
|
||||
@ -227,9 +236,9 @@ void LSHShortlist::broadcast(Expr weights,
|
||||
Expr lemmaEt,
|
||||
Expr indicesExprBC,
|
||||
int k) {
|
||||
//std::cerr << "indicesExprBC.0=" << indicesExprBC->shape() << std::endl;
|
||||
int batchSize = indicesExprBC->shape()[0];
|
||||
int currBeamSize = indicesExprBC->shape()[1];
|
||||
std::cerr << "indicesExprBC.0=" << indicesExprBC->shape() << std::endl;
|
||||
int currBeamSize = indicesExprBC->shape()[0];
|
||||
int batchSize = indicesExprBC->shape()[1];
|
||||
//int numHypos = batchSize * currBeamSize;
|
||||
//std::cerr << "batchSize=" << batchSize << std::endl;
|
||||
//std::cerr << "currBeamSize=" << currBeamSize << std::endl;
|
||||
@ -239,14 +248,12 @@ void LSHShortlist::broadcast(Expr weights,
|
||||
indicesExprBC = reshape(indicesExprBC, {indicesExprBC->shape().elements()});
|
||||
//std::cerr << "indicesExprBC.2=" << indicesExprBC->shape() << std::endl;
|
||||
|
||||
//std::cerr << "currBeamSize=" << currBeamSize << " batchSize=" << batchSize << std::endl;
|
||||
//std::cerr << "weights=" << weights->shape() << std::endl;
|
||||
std::cerr << "currBeamSize=" << currBeamSize << " batchSize=" << batchSize << std::endl;
|
||||
std::cerr << "weights=" << weights->shape() << std::endl;
|
||||
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indicesExprBC);
|
||||
//std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
|
||||
cachedShortWt_ = reshape(cachedShortWt_, {batchSize, currBeamSize, k, cachedShortWt_->shape()[1]});
|
||||
//std::cerr << "cachedShortWt_.2=" << cachedShortWt_->shape() << std::endl;
|
||||
cachedShortWt_ = transpose(cachedShortWt_, {1, 0, 2, 3});
|
||||
//std::cerr << "cachedShortWt_.3=" << cachedShortWt_->shape() << std::endl;
|
||||
std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
|
||||
cachedShortWt_ = reshape(cachedShortWt_, {currBeamSize, batchSize, k, cachedShortWt_->shape()[1]});
|
||||
std::cerr << "cachedShortWt_.2=" << cachedShortWt_->shape() << std::endl;
|
||||
|
||||
if (b) {
|
||||
ABORT("Bias not yet tested");
|
||||
|
@ -8,6 +8,15 @@ Logits::Logits(Expr logits)
|
||||
: Logits(New<RationalLoss>(logits, nullptr)) {
|
||||
} // single-output constructor from Expr only (RationalLoss has no count)
|
||||
|
||||
Logits::Logits(Ptr<RationalLoss> logits) { // single-output constructor
|
||||
logits_.push_back(logits);
|
||||
}
|
||||
|
||||
Logits::Logits(std::vector<Ptr<RationalLoss>>&& logits,
|
||||
Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
|
||||
: logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {
|
||||
}
|
||||
|
||||
Ptr<ExpressionGraph> Logits::graph() const {
|
||||
ABORT_IF(logits_.empty(), "Empty logits object??");
|
||||
return logits_.front()->loss()->graph();
|
||||
@ -53,6 +62,7 @@ Expr Logits::applyLossFunction(
|
||||
auto factorIndices = indices(maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
|
||||
auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
|
||||
auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
|
||||
std::cerr << "g=" << g << " factorLogits->loss()=" << factorLogits->loss()->shape() << std::endl;
|
||||
// For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
|
||||
auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
|
||||
// clang-format on
|
||||
@ -85,12 +95,14 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
|
||||
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
|
||||
|
||||
auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
|
||||
std::cerr << "sel.1=" << sel->shape() << std::endl;
|
||||
|
||||
// normalize for decoding:
|
||||
// - all secondary factors: subtract their max
|
||||
// - lemma: add all maxes of applicable factors
|
||||
if(groupIndex > 0) {
|
||||
sel = sel - max(sel, -1);
|
||||
std::cerr << "sel.2=" << sel->shape() << std::endl;
|
||||
} else {
|
||||
auto numGroups = getNumFactorGroups();
|
||||
for(size_t g = 1; g < numGroups; g++) {
|
||||
@ -101,7 +113,7 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
|
||||
factorMasks = constant(getFactorMasks(g, std::vector<WordIndex>()));
|
||||
}
|
||||
else {
|
||||
//std::cerr << "sel=" << sel->shape() << std::endl;
|
||||
std::cerr << "sel.3=" << sel->shape() << std::endl;
|
||||
auto forward = [this, g](Expr out, const std::vector<Expr>& inputs) {
|
||||
Expr lastIndices = inputs[0];
|
||||
std::vector<float> masks = getFactorMasksMultiDim(g, lastIndices);
|
||||
@ -111,20 +123,27 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
|
||||
int currBeamSize = sel->shape()[0];
|
||||
int batchSize = sel->shape()[2];
|
||||
Expr lastIndices = shortlist->getIndicesExpr(batchSize, currBeamSize);
|
||||
//std::cerr << "lastIndices=" << lastIndices->shape() << std::endl;
|
||||
std::cerr << "lastIndices=" << lastIndices->shape() << std::endl;
|
||||
factorMasks = lambda({lastIndices}, lastIndices->shape(), Type::float32, forward);
|
||||
//std::cerr << "factorMasks.1=" << factorMasks->shape() << std::endl;
|
||||
std::cerr << "factorMasks.1=" << factorMasks->shape() << std::endl;
|
||||
factorMasks = transpose(factorMasks, {1, 0, 2});
|
||||
//std::cerr << "factorMasks.2=" << factorMasks->shape() << std::endl;
|
||||
std::cerr << "factorMasks.2=" << factorMasks->shape() << std::endl;
|
||||
|
||||
const Shape &s = factorMasks->shape();
|
||||
factorMasks = reshape(factorMasks, {s[0], 1, s[1], s[2]});
|
||||
//std::cerr << "factorMasks.3=" << factorMasks->shape() << std::endl;
|
||||
std::cerr << "factorMasks.3=" << factorMasks->shape() << std::endl;
|
||||
}
|
||||
factorMaxima = cast(factorMaxima, sel->value_type());
|
||||
std::cerr << "factorMaxima=" << factorMaxima->shape() << std::endl;
|
||||
factorMasks = cast(factorMasks, sel->value_type());
|
||||
sel = sel + factorMaxima * factorMasks; // those lemmas that don't have a factor
|
||||
std::cerr << "factorMasks.4=" << factorMasks->shape() << std::endl;
|
||||
|
||||
Expr tmp = factorMaxima * factorMasks;
|
||||
std::cerr << "tmp=" << tmp->shape() << std::endl;
|
||||
std::cerr << "sel.4=" << sel->shape() << std::endl;
|
||||
sel = sel + tmp; // those lemmas that don't have a factor
|
||||
// get multiplied with 0
|
||||
std::cerr << "sel.5=" << sel->shape() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -17,14 +17,11 @@ class RationalLoss;
|
||||
class Logits {
|
||||
public:
|
||||
Logits() {}
|
||||
explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
|
||||
logits_.push_back(logits);
|
||||
}
|
||||
explicit Logits(
|
||||
Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
|
||||
explicit Logits(Ptr<RationalLoss> logits); // single-output constructor
|
||||
explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
|
||||
Logits(std::vector<Ptr<RationalLoss>>&& logits,
|
||||
Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
|
||||
: logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {}
|
||||
Ptr<FactoredVocab> embeddingFactorMapping); // factored-output constructor
|
||||
|
||||
Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
|
||||
Expr getFactoredLogits(
|
||||
size_t groupIndex,
|
||||
|
@ -63,9 +63,6 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
};
|
||||
|
||||
auto affineShortlist = [](Expr x, Expr W, Expr b, bool , bool ) {
|
||||
//std::cerr << "x=" << x->shape() << std::endl;
|
||||
//std::cerr << "W=" << W->shape() << std::endl;
|
||||
x = transpose(x, {0, 2, 1, 3});
|
||||
//std::cerr << "x=" << x->shape() << std::endl;
|
||||
//std::cerr << "W=" << W->shape() << std::endl;
|
||||
Expr ret = bdot(x, W, false, true);
|
||||
@ -174,29 +171,35 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
// matrix
|
||||
Expr factorLogits;
|
||||
if(g == 0 && shortlist_) {
|
||||
//std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl;
|
||||
//std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl;
|
||||
std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl;
|
||||
std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl;
|
||||
Expr tmp = transpose(input1, {0, 2, 1, 3});
|
||||
//std::cerr << "x=" << x->shape() << std::endl;
|
||||
//std::cerr << "W=" << W->shape() << std::endl;
|
||||
factorLogits = affineShortlist(
|
||||
input1,
|
||||
tmp,
|
||||
factorWt,
|
||||
factorB,
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||
//std::cerr << "affineShortlist.factorLogits.1=" << factorLogits->shape() << std::endl;
|
||||
std::cerr << "affineShortlist.factorLogits.1=" << factorLogits->shape() << std::endl;
|
||||
factorLogits = transpose(factorLogits, {0, 2, 1, 3});
|
||||
//std::cerr << "affineShortlist.factorLogits.2=" << factorLogits->shape() << std::endl;
|
||||
std::cerr << "affineShortlist.factorLogits.2=" << factorLogits->shape() << std::endl;
|
||||
}
|
||||
else {
|
||||
//std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl;
|
||||
//std::cerr << "affineOrDot.factorWt=" << factorWt->shape() << std::endl;
|
||||
std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl;
|
||||
std::cerr << "affineOrDot.factorWt.1=" << factorWt->shape() << std::endl;
|
||||
//factorWt = transpose(factorWt, {1, 0, 2, 3});
|
||||
//std::cerr << "affineOrDot.factorWt.2=" << factorWt->shape() << std::endl;
|
||||
factorLogits = affineOrDot(
|
||||
input1,
|
||||
factorWt,
|
||||
factorB,
|
||||
false,
|
||||
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
|
||||
//std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl;
|
||||
std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl;
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
|
||||
// optionally add lemma-dependent bias
|
||||
if(Plemma) { // [B... x U0]
|
||||
@ -210,6 +213,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
auto b = dot(Plemma, lemmaBt, false, true); // [B... x U]
|
||||
factorLogits = factorLogits + b;
|
||||
}
|
||||
//std::cerr << "factorLogits=" << factorLogits->shape() << std::endl;
|
||||
allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
|
||||
// optionally add a soft embedding of lemma back to create some lemma dependency
|
||||
// @TODO: if this works, move it into lazyConstruct
|
||||
|
Loading…
Reference in New Issue
Block a user