marcin's review changes

Hieu Hoang 2021-07-02 12:06:03 -07:00
parent ff8af52624
commit bd1f1ee9cb
7 changed files with 10 additions and 17 deletions

View File

@@ -320,7 +320,8 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
                                                  size_t srcIdx,
                                                  size_t trgIdx,
                                                  bool shared) {
-  if (lshOpts.size() == 2) {
+  if (lshOpts.size()) {
+    assert(lshOpts.size() == 2);
     size_t lemmaSize = trgVocab->lemmaSize();
     return New<LSHShortlistGenerator>(lshOpts[0], lshOpts[1], lemmaSize);
   }
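
The hunk above replaces a silent size check with an assert: an empty --output-approx-knn still skips the LSH path, but a wrong number of values now fails loudly in debug builds (release builds are covered by the ABORT_IF added in the last file of this commit). A minimal standalone sketch of that guard pattern; the helper name and sample values are illustrative, not marian's:

#include <cassert>
#include <cstdio>
#include <vector>

// Accept either zero LSH parameters (feature off) or exactly two.
// Note: assert() compiles away under NDEBUG, which is why the caller
// side also validates with a runtime ABORT_IF (last hunk of this commit).
void createFromLshOpts(const std::vector<int>& lshOpts) {
  if (lshOpts.size()) {
    assert(lshOpts.size() == 2); // debug builds stop here on malformed input
    std::printf("LSH shortlist with params %d and %d\n", lshOpts[0], lshOpts[1]);
    return;
  }
  std::printf("no LSH shortlist requested\n");
}

int main() {
  createFromLshOpts({});          // feature off
  createFromLshOpts({1024, 100}); // two parameters, as the assert expects
}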

View File

@@ -66,8 +66,6 @@ public:
 };
 ///////////////////////////////////////////////////////////////////////////////////
-// implements SLIDE for faster inference.
-// https://arxiv.org/pdf/1903.03129.pdf
 class LSHShortlist: public Shortlist {
 private:
   int k_; // number of candidates returned from each input

View File

@@ -62,7 +62,7 @@ Expr Logits::applyLossFunction(
     auto factorIndices = indices(maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
     auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
     auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
-    std::cerr << "g=" << g << " factorLogits->loss()=" << factorLogits->loss()->shape() << std::endl;
+    //std::cerr << "g=" << g << " factorLogits->loss()=" << factorLogits->loss()->shape() << std::endl;
     // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
     auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
     // clang-format on
@@ -113,7 +113,7 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
   else {
     auto forward = [this, g](Expr out, const std::vector<Expr>& inputs) {
       Expr lastIndices = inputs[0];
-      std::vector<float> masks = getFactorMasksMultiDim(g, lastIndices);
+      std::vector<float> masks = getFactorMasks(g, lastIndices);
       out->val()->set(masks);
     };
@@ -245,7 +245,7 @@ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<
   return res;
 }
-std::vector<float> Logits::getFactorMasksMultiDim(size_t factorGroup, Expr indicesExpr)
+std::vector<float> Logits::getFactorMasks(size_t factorGroup, Expr indicesExpr)
     const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
   int batchSize = indicesExpr->shape()[0];
   int currBeamSize = indicesExpr->shape()[1];

View File

@@ -77,7 +77,7 @@ private:
   } // actually the same as constant(data) for this data type
   std::vector<float> getFactorMasks(size_t factorGroup,
                                     const std::vector<WordIndex>& indices) const;
-  std::vector<float> getFactorMasksMultiDim(size_t factorGroup, Expr indicesExpr) const;
+  std::vector<float> getFactorMasks(size_t factorGroup, Expr indicesExpr) const; // same as above but separate indices for each batch and beam
 private:
   // members
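
The two hunks above (together with the matching ones in the previous file) rename getFactorMasksMultiDim to getFactorMasks, turning the pair into plain C++ overloads distinguished only by argument type: a flat std::vector<WordIndex>, or an Expr carrying separate indices per batch entry and beam hypothesis. A minimal sketch of that overload pattern, with stand-in types for marian's WordIndex and Expr:

#include <cstdio>
#include <vector>

using WordIndex = unsigned int; // stand-in for marian's WordIndex
struct ExprStub {};             // stand-in for marian's Expr handle

// Flat variant: one index list shared by the whole batch.
std::vector<float> getFactorMasks(const std::vector<WordIndex>& indices) {
  std::printf("vector overload: %zu indices\n", indices.size());
  return std::vector<float>(indices.size(), 1.0f);
}

// Expression variant: indices differ per batch entry and beam hypothesis.
std::vector<float> getFactorMasks(const ExprStub& /*indicesExpr*/) {
  std::printf("Expr overload\n");
  return {};
}

int main() {
  getFactorMasks(std::vector<WordIndex>{1, 2, 3}); // resolves to the vector overload
  getFactorMasks(ExprStub{});                      // resolves to the Expr overload
}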

View File

@@ -75,7 +75,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
       std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
       std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl;
       */
-      ABORT_IF(!(!transA && transB), "Must be transA==0 and transB==1");
+      ABORT_IF(!(!transA && transB), "affineShortlist. Must be transA==0 and transB==1");
       ABORT_IF(b, "affineShortlist not tested with bias");
       Expr ret = bdot(x, W, transA, transB);
       //std::cerr << "ret=" << ret->shape() << std::endl;
@@ -83,8 +83,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
     return ret;
   };
-  if(shortlist_) { // shortlisted versions of parameters are cached within one
-                   // batch, then clear()ed
+  if(shortlist_) {
     shortlist_->filter(input, Wt_, isLegacyUntransposedW, b_, lemmaEt_);
   }
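
The first hunk above only prefixes the abort message with its call site ("affineShortlist.") so a failure is easier to trace; the condition itself is unchanged and rejects every transpose combination except transA==false with transB==true, i.e. bdot is only exercised as x * W^T. A self-contained sketch of that guard, using a simplified stand-in for marian's ABORT_IF macro:

#include <cstdio>
#include <cstdlib>

// Simplified stand-in for marian's ABORT_IF: print and abort when cond holds.
#define ABORT_IF(cond, msg) \
  do { if (cond) { std::fprintf(stderr, "%s\n", msg); std::abort(); } } while (0)

void checkTranspose(bool transA, bool transB) {
  // Fires unless transA == false and transB == true.
  ABORT_IF(!(!transA && transB), "affineShortlist. Must be transA==0 and transB==1");
  std::printf("ok: transA=%d transB=%d\n", transA, transB);
}

int main() {
  checkTranspose(false, true);   // the only accepted configuration
  // checkTranspose(true, true); // would abort with the prefixed message
}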

View File

@@ -20,7 +20,6 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
                          const std::vector<bool>& dropBatchEntries, // [origDimBatch] - empty source batch entries are marked with true, should be cleared after first use.
                          const std::vector<IndexType>& batchIdxMap) const { // [origBatchIdx -> currentBatchIdx]
   std::vector<float> align; // collects alignment information from the last executed time step
-  //utils::Debug(batchIdxMap, "batchIdxMap");
   if(options_->hasAndNotEmpty("alignment") && factorGroup == 0)
     align = scorers_[0]->getAlignment(); // [beam depth * max src length * current batch size] -> P(s|t); use alignments from the first scorer, even if ensemble,
@@ -86,12 +85,6 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
       // map wordIdx to word
       auto prevBeamHypIdx = beamHypIdx; // back pointer
-      /*std::cerr << "currentBatchIdx=" << currentBatchIdx
-                << " origBatchIdx=" << origBatchIdx
-                << " beamHypIdx=" << beamHypIdx
-                << " prevBeamHypIdx=" << prevBeamHypIdx
-                << std::endl;*/
       auto prevHyp = beam[prevBeamHypIdx];
       Word word;
       // If short list has been set, then wordIdx is an index into the short-listed word set,

View File

@@ -63,6 +63,8 @@ public:
     auto srcVocab = corpus_->getVocabs()[0];
     std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn");
+    ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters");
     if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) {
      shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabs.front() == vocabs.back());
     }
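
This last hunk validates --output-approx-knn before the generator is chosen: no values means the flag is unset, exactly two values enables the LSH shortlist, and any other count now aborts with a clear message. A minimal sketch of the resulting selection logic; the sample values are assumptions, not taken from this commit:

#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
  std::vector<int> lshOpts = {1024, 100}; // as if parsed from --output-approx-knn
  bool hasShortlist = false;              // as if from --shortlist

  // Reject anything other than zero or two values, mirroring the ABORT_IF above.
  if (lshOpts.size() != 0 && lshOpts.size() != 2) {
    std::fprintf(stderr, "--output-approx-knn takes 2 parameters\n");
    return EXIT_FAILURE;
  }
  // LSH parameters or a conventional shortlist both trigger generator creation.
  if (lshOpts.size() == 2 || hasShortlist) {
    std::printf("creating shortlist generator (LSH: %s)\n",
                lshOpts.size() == 2 ? "yes" : "no");
  } else {
    std::printf("no shortlist generator\n");
  }
}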