commit f0251889f2
parent fef7202bc8
Author: Hieu Hoang
Date: 2021-06-11 00:57:02 -07:00
4 changed files with 1 addition and 46 deletions

View File

@@ -129,19 +129,9 @@ LSHShortlist::LSHShortlist(int k, int nbits)
//#define BLAS_FOUND 1
WordIndex LSHShortlist::reverseMap(int batchIdx, int beamIdx, int idx) const {
std::cerr << "\nbatchIdx=" << batchIdx
<< " beamIdx=" << beamIdx
<< " idx=" << idx
<< " k_=" << k_
<< std::endl;
std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
int currBeamSize = indicesExpr_->shape()[0];
//int currBeamSize = indicesExpr_->shape()[0];
int currBatchSize = indicesExpr_->shape()[1];
std::cerr << "currBatchSize=" << currBatchSize << " currBeamSize=" << currBeamSize << std::endl;
std::cerr << "indices_=" << indices_.size() << std::endl;
idx = (k_ * currBatchSize * beamIdx) + (k_ * batchIdx) + idx;
//idx = (k_ * currBeamSize * batchIdx) + (k_ * beamIdx) + idx;
std::cerr << "idx=" << idx << std::endl;
assert(idx < indices_.size());
return indices_[idx];
}
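For reference, a minimal standalone sketch of the index arithmetic kept in reverseMap above: indices_ is laid out beam-major as [currBeamSize, currBatchSize, k_], so a (batchIdx, beamIdx, idx) triple is flattened into one offset. Plain C++, not Marian code; the helper name and parameters are illustrative assumptions.

    #include <cassert>
    #include <cstddef>
    #include <vector>

    using WordIndex = unsigned int;

    // Sketch only: mirrors idx = (k_ * currBatchSize * beamIdx) + (k_ * batchIdx) + idx
    WordIndex reverseMapSketch(const std::vector<WordIndex>& indices, // flat [beam, batch, k]
                               int k, int currBatchSize,
                               int batchIdx, int beamIdx, int idx) {
      std::size_t flat = (std::size_t)k * currBatchSize * beamIdx
                       + (std::size_t)k * batchIdx
                       + (std::size_t)idx;
      assert(flat < indices.size());   // same bounds check as in the diff
      return indices[flat];
    }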
@@ -157,8 +147,6 @@ WordIndex LSHShortlist::tryForwardMap(int , int , WordIndex wIdx) const {
}
Expr LSHShortlist::getIndicesExpr(int batchSize, int currBeamSize) const {
std::cerr << "batchSize=" << batchSize << " currBeamSize=" << currBeamSize << std::endl;
std::cerr << "indicesExpr_=" << indicesExpr_->shape() << " " << indicesExpr_->val() << std::endl;
assert(indicesExpr_->shape()[0] == currBeamSize);
assert(indicesExpr_->shape()[1] == batchSize);
Expr ret = transpose(indicesExpr_, {1, 0, 2});
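A minimal sketch, in plain C++ rather than Marian's Expr graph, of the transpose that getIndicesExpr performs above: the cached shortlist indices are stored as [currBeamSize, batchSize, k] (see the asserts) and are handed back as [batchSize, currBeamSize, k]. The function and parameter names are illustrative assumptions.

    #include <cstddef>
    #include <vector>

    // Sketch only: the {1, 0, 2} axis permutation from the diff, done with explicit loops.
    std::vector<unsigned int> beamMajorToBatchMajor(const std::vector<unsigned int>& in,
                                                    int beamSize, int batchSize, int k) {
      std::vector<unsigned int> out(in.size());
      for (int b = 0; b < beamSize; ++b)
        for (int n = 0; n < batchSize; ++n)
          for (int i = 0; i < k; ++i)
            out[((std::size_t)n * beamSize + b) * k + i] =
                in[((std::size_t)b * batchSize + n) * k + i];
      return out;
    }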
@@ -169,8 +157,6 @@ Expr LSHShortlist::getIndicesExpr(int batchSize, int currBeamSize) const {
void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW, Expr b, Expr lemmaEt) {
#if BLAS_FOUND
static int c = 0;
std::cerr << "c=" << c++ << std::endl;
ABORT_IF(input->graph()->getDeviceId().type == DeviceType::gpu,
"LSH index (--output-approx-knn) currently not implemented for GPU");
@@ -194,7 +180,6 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
index_->add( vRows, values->val()->data<float>());
}
std::cerr << "query=" << query->shape() << std::endl;
int qRows = query->shape().elements() / dim;
std::vector<float> distances(qRows * k_);
std::vector<faiss::Index::idx_t> ids(qRows * k_);
@@ -217,7 +202,6 @@ void LSHShortlist::filter(Expr input, Expr weights, bool isLegacyUntransposedW,
};
Shape kShape({currBeamSize, batchSize, k_});
std::cerr << "kShape=" << kShape << std::endl;
indicesExpr_ = lambda({input, weights}, kShape, Type::uint32, forward);
//std::cerr << "indicesExpr_=" << indicesExpr_->shape() << std::endl;
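A minimal sketch of the FAISS query that filter() wraps above: an LSH index is built over the output-embedding rows and searched with the decoder-state queries, giving k candidate row ids per query. This is a generic FAISS usage sketch under assumed variable names, not the exact Marian call sequence.

    #include <cstddef>
    #include <vector>
    #include <faiss/IndexLSH.h>

    // Sketch only: nRows embedding rows of width dim are indexed, then the k
    // nearest rows are fetched for each of qRows query vectors.
    std::vector<faiss::Index::idx_t> lshCandidates(const float* rows, int nRows,
                                                   const float* queries, int qRows,
                                                   int dim, int nbits, int k) {
      faiss::IndexLSH index(dim, nbits);
      index.add(nRows, rows);
      std::vector<float> distances((std::size_t)qRows * k);
      std::vector<faiss::Index::idx_t> ids((std::size_t)qRows * k);
      index.search(qRows, queries, k, distances.data(), ids.data());
      return ids;                       // qRows * k candidate row ids
    }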
@@ -236,7 +220,6 @@ void LSHShortlist::broadcast(Expr weights,
Expr lemmaEt,
Expr indicesExprBC,
int k) {
std::cerr << "indicesExprBC.0=" << indicesExprBC->shape() << std::endl;
int currBeamSize = indicesExprBC->shape()[0];
int batchSize = indicesExprBC->shape()[1];
//int numHypos = batchSize * currBeamSize;
@@ -248,12 +231,8 @@ void LSHShortlist::broadcast(Expr weights,
indicesExprBC = reshape(indicesExprBC, {indicesExprBC->shape().elements()});
//std::cerr << "indicesExprBC.2=" << indicesExprBC->shape() << std::endl;
std::cerr << "currBeamSize=" << currBeamSize << " batchSize=" << batchSize << std::endl;
std::cerr << "weights=" << weights->shape() << std::endl;
cachedShortWt_ = index_select(weights, isLegacyUntransposedW ? -1 : 0, indicesExprBC);
std::cerr << "cachedShortWt_.1=" << cachedShortWt_->shape() << std::endl;
cachedShortWt_ = reshape(cachedShortWt_, {currBeamSize, batchSize, k, cachedShortWt_->shape()[1]});
std::cerr << "cachedShortWt_.2=" << cachedShortWt_->shape() << std::endl;
if (b) {
ABORT("Bias not yet tested");
@@ -262,11 +241,8 @@ void LSHShortlist::broadcast(Expr weights,
}
cachedShortLemmaEt_ = index_select(lemmaEt, -1, indicesExprBC);
//std::cerr << "cachedShortLemmaEt.1_=" << cachedShortLemmaEt_->shape() << std::endl;
cachedShortLemmaEt_ = reshape(cachedShortLemmaEt_, {cachedShortLemmaEt_->shape()[0], batchSize, currBeamSize, k});
//std::cerr << "cachedShortLemmaEt.2_=" << cachedShortLemmaEt_->shape() << std::endl;
cachedShortLemmaEt_ = transpose(cachedShortLemmaEt_, {2, 1, 0, 3});
//std::cerr << "cachedShortLemmaEt.3_=" << cachedShortLemmaEt_->shape() << std::endl;
}
LSHShortlistGenerator::LSHShortlistGenerator(int k, int nbits)
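A minimal sketch of the row gather behind broadcast above: cachedShortWt_ is built by selecting the shortlisted rows of the output weight matrix and viewing the result as {currBeamSize, batchSize, k, dim}. Plain C++ over flat arrays; the names and the row-major layout are illustrative assumptions, and the isLegacyUntransposedW column case is left out.

    #include <cstddef>
    #include <vector>

    // Sketch only: index_select over rows of W [vocab x dim], one row per shortlist entry.
    // The result holds flatIndices.size() * dim floats, logically {beam, batch, k, dim}.
    std::vector<float> gatherShortlistRows(const std::vector<float>& W, std::size_t dim,
                                           const std::vector<unsigned int>& flatIndices) {
      std::vector<float> out;
      out.reserve(flatIndices.size() * dim);
      for (unsigned int row : flatIndices)
        out.insert(out.end(), W.begin() + (std::size_t)row * dim,
                              W.begin() + ((std::size_t)row + 1) * dim);
      return out;
    }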

View File

@@ -95,14 +95,12 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
std::cerr << "sel.1=" << sel->shape() << std::endl;
// normalize for decoding:
// - all secondary factors: subtract their max
// - lemma: add all maxes of applicable factors
if(groupIndex > 0) {
sel = sel - max(sel, -1);
std::cerr << "sel.2=" << sel->shape() << std::endl;
} else {
auto numGroups = getNumFactorGroups();
for(size_t g = 1; g < numGroups; g++) {
@@ -113,7 +111,6 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
factorMasks = constant(getFactorMasks(g, std::vector<WordIndex>()));
}
else {
std::cerr << "sel.3=" << sel->shape() << std::endl;
auto forward = [this, g](Expr out, const std::vector<Expr>& inputs) {
Expr lastIndices = inputs[0];
std::vector<float> masks = getFactorMasksMultiDim(g, lastIndices);
@@ -123,27 +120,18 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
int currBeamSize = sel->shape()[0];
int batchSize = sel->shape()[2];
Expr lastIndices = shortlist->getIndicesExpr(batchSize, currBeamSize);
std::cerr << "lastIndices=" << lastIndices->shape() << std::endl;
factorMasks = lambda({lastIndices}, lastIndices->shape(), Type::float32, forward);
std::cerr << "factorMasks.1=" << factorMasks->shape() << std::endl;
factorMasks = transpose(factorMasks, {1, 0, 2});
std::cerr << "factorMasks.2=" << factorMasks->shape() << std::endl;
const Shape &s = factorMasks->shape();
factorMasks = reshape(factorMasks, {s[0], 1, s[1], s[2]});
std::cerr << "factorMasks.3=" << factorMasks->shape() << std::endl;
}
factorMaxima = cast(factorMaxima, sel->value_type());
std::cerr << "factorMaxima=" << factorMaxima->shape() << std::endl;
factorMasks = cast(factorMasks, sel->value_type());
std::cerr << "factorMasks.4=" << factorMasks->shape() << std::endl;
Expr tmp = factorMaxima * factorMasks;
std::cerr << "tmp=" << tmp->shape() << std::endl;
std::cerr << "sel.4=" << sel->shape() << std::endl;
sel = sel + tmp; // those lemmas that don't have a factor
// get multiplied with 0
std::cerr << "sel.5=" << sel->shape() << std::endl;
}
}
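A minimal sketch of the lemma normalization in getFactoredLogits above: for group 0, each secondary factor group contributes its per-hypothesis maximum to the lemma logits, gated by a 0/1 mask so that lemmas without that factor are left unchanged. Plain C++ over arrays assumed already broadcast to the same length; names are illustrative assumptions.

    #include <cstddef>
    #include <vector>

    // Sketch only: sel = sel + factorMaxima * factorMasks from the diff, element-wise.
    void addFactorMaxima(std::vector<float>& sel,                 // lemma logits
                         const std::vector<float>& factorMaxima,  // max over one factor group
                         const std::vector<float>& factorMasks) { // 1 if the lemma has that factor
      for (std::size_t i = 0; i < sel.size(); ++i)
        sel[i] += factorMaxima[i] * factorMasks[i];               // masked-out lemmas get += 0
    }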

View File

@@ -171,8 +171,6 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
// matrix
Expr factorLogits;
if(g == 0 && shortlist_) {
std::cerr << "affineShortlist.input1=" << input1->shape() << std::endl;
std::cerr << "affineShortlist.factorWt=" << factorWt->shape() << std::endl;
Expr tmp = transpose(input1, {0, 2, 1, 3});
//std::cerr << "x=" << x->shape() << std::endl;
//std::cerr << "W=" << W->shape() << std::endl;
@@ -182,13 +180,9 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
factorB,
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
std::cerr << "affineShortlist.factorLogits.1=" << factorLogits->shape() << std::endl;
factorLogits = transpose(factorLogits, {0, 2, 1, 3});
std::cerr << "affineShortlist.factorLogits.2=" << factorLogits->shape() << std::endl;
}
else {
std::cerr << "affineOrDot.input1=" << input1->shape() << std::endl;
std::cerr << "affineOrDot.factorWt.1=" << factorWt->shape() << std::endl;
//factorWt = transpose(factorWt, {1, 0, 2, 3});
//std::cerr << "affineOrDot.factorWt.2=" << factorWt->shape() << std::endl;
factorLogits = affineOrDot(
@@ -197,9 +191,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
factorB,
false,
/*transB=*/isLegacyUntransposedW ? false : true); // [B... x U] factor logits
std::cerr << "affineOrDot.factorLogits=" << factorLogits->shape() << std::endl;
}
std::cerr << std::endl;
// optionally add lemma-dependent bias
if(Plemma) { // [B... x U0]

View File

@@ -101,7 +101,6 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
// For factored decoding, the word is built over multiple decoding steps,
// starting with the lemma, then adding factors one by one.
if (factorGroup == 0) {
std::cerr << "currentBatchId=" << currentBatchIdx << " origBatchIdx=" << origBatchIdx << std::endl;
word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap((int) currentBatchIdx, (int) prevBeamHypIdx, wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
std::vector<size_t> factorIndices; factoredVocab->word2factors(word, factorIndices);
//LOG(info, "{} + {} ({}) -> {} -> {}",
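Finally, a minimal sketch of the call pattern in toHyps above: with a shortlist active, the k-best word index is local to the shortlist and must be mapped back to a full-vocabulary id before lemma2Word; without one it is already a full-vocabulary id. Plain C++ repeating the reverseMap flat-index arithmetic; all names are illustrative assumptions.

    #include <cstddef>
    #include <vector>

    // Sketch only: choose between a shortlist-local and a full-vocabulary word id.
    unsigned int toFullVocabId(const std::vector<unsigned int>* shortlistIndices, // null if no shortlist
                               int k, int currBatchSize,
                               int batchIdx, int beamHypIdx, unsigned int wordIdx) {
      if (!shortlistIndices)
        return wordIdx;                                   // no shortlist: id is already global
      std::size_t flat = (std::size_t)k * currBatchSize * beamHypIdx
                       + (std::size_t)k * batchIdx + wordIdx;
      return (*shortlistIndices)[flat];                   // same lookup as reverseMap above
    }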