mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
marcin's review changes
This commit is contained in:
parent
ff8af52624
commit
bd1f1ee9cb
@ -320,7 +320,8 @@ Ptr<ShortlistGenerator> createShortlistGenerator(Ptr<Options> options,
|
||||
size_t srcIdx,
|
||||
size_t trgIdx,
|
||||
bool shared) {
|
||||
if (lshOpts.size() == 2) {
|
||||
if (lshOpts.size()) {
|
||||
assert(lshOpts.size() == 2);
|
||||
size_t lemmaSize = trgVocab->lemmaSize();
|
||||
return New<LSHShortlistGenerator>(lshOpts[0], lshOpts[1], lemmaSize);
|
||||
}
|
||||
|
@ -66,8 +66,6 @@ public:
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////
|
||||
// implements SLIDE for faster inference.
|
||||
// https://arxiv.org/pdf/1903.03129.pdf
|
||||
class LSHShortlist: public Shortlist {
|
||||
private:
|
||||
int k_; // number of candidates returned from each input
|
||||
|
@ -62,7 +62,7 @@ Expr Logits::applyLossFunction(
|
||||
auto factorIndices = indices(maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
|
||||
auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
|
||||
auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
|
||||
std::cerr << "g=" << g << " factorLogits->loss()=" << factorLogits->loss()->shape() << std::endl;
|
||||
//std::cerr << "g=" << g << " factorLogits->loss()=" << factorLogits->loss()->shape() << std::endl;
|
||||
// For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
|
||||
auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
|
||||
// clang-format on
|
||||
@ -113,7 +113,7 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
|
||||
else {
|
||||
auto forward = [this, g](Expr out, const std::vector<Expr>& inputs) {
|
||||
Expr lastIndices = inputs[0];
|
||||
std::vector<float> masks = getFactorMasksMultiDim(g, lastIndices);
|
||||
std::vector<float> masks = getFactorMasks(g, lastIndices);
|
||||
out->val()->set(masks);
|
||||
};
|
||||
|
||||
@ -245,7 +245,7 @@ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<float> Logits::getFactorMasksMultiDim(size_t factorGroup, Expr indicesExpr)
|
||||
std::vector<float> Logits::getFactorMasks(size_t factorGroup, Expr indicesExpr)
|
||||
const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
|
||||
int batchSize = indicesExpr->shape()[0];
|
||||
int currBeamSize = indicesExpr->shape()[1];
|
||||
|
@ -77,7 +77,7 @@ private:
|
||||
} // actually the same as constant(data) for this data type
|
||||
std::vector<float> getFactorMasks(size_t factorGroup,
|
||||
const std::vector<WordIndex>& indices) const;
|
||||
std::vector<float> getFactorMasksMultiDim(size_t factorGroup, Expr indicesExpr) const;
|
||||
std::vector<float> getFactorMasks(size_t factorGroup, Expr indicesExpr) const; // same as above but separate indices for each batch and beam
|
||||
|
||||
private:
|
||||
// members
|
||||
|
@ -75,7 +75,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
std::cerr << "affineShortlist.b=" << b->shape() << std::endl;
|
||||
std::cerr << "affineShortlist.transA=" << transA << " transB=" << transB << std::endl;
|
||||
*/
|
||||
ABORT_IF(!(!transA && transB), "Must be transA==0 and transB==1");
|
||||
ABORT_IF(!(!transA && transB), "affineShortlist. Must be transA==0 and transB==1");
|
||||
ABORT_IF(b, "affineShortlist not tested with bias");
|
||||
Expr ret = bdot(x, W, transA, transB);
|
||||
//std::cerr << "ret=" << ret->shape() << std::endl;
|
||||
@ -83,8 +83,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
return ret;
|
||||
};
|
||||
|
||||
if(shortlist_) { // shortlisted versions of parameters are cached within one
|
||||
// batch, then clear()ed
|
||||
if(shortlist_) {
|
||||
shortlist_->filter(input, Wt_, isLegacyUntransposedW, b_, lemmaEt_);
|
||||
}
|
||||
|
||||
|
@ -20,7 +20,6 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
|
||||
const std::vector<bool>& dropBatchEntries, // [origDimBatch] - empty source batch entries are marked with true, should be cleared after first use.
|
||||
const std::vector<IndexType>& batchIdxMap) const { // [origBatchIdx -> currentBatchIdx]
|
||||
std::vector<float> align; // collects alignment information from the last executed time step
|
||||
//utils::Debug(batchIdxMap, "batchIdxMap");
|
||||
if(options_->hasAndNotEmpty("alignment") && factorGroup == 0)
|
||||
align = scorers_[0]->getAlignment(); // [beam depth * max src length * current batch size] -> P(s|t); use alignments from the first scorer, even if ensemble,
|
||||
|
||||
@ -86,12 +85,6 @@ Beams BeamSearch::toHyps(const std::vector<unsigned int>& nBestKeys, // [current
|
||||
|
||||
// map wordIdx to word
|
||||
auto prevBeamHypIdx = beamHypIdx; // back pointer
|
||||
/*std::cerr << "currentBatchIdx=" << currentBatchIdx
|
||||
<< " origBatchIdx=" << origBatchIdx
|
||||
<< " beamHypIdx=" << beamHypIdx
|
||||
<< " prevBeamHypIdx=" << prevBeamHypIdx
|
||||
<< std::endl;*/
|
||||
|
||||
auto prevHyp = beam[prevBeamHypIdx];
|
||||
Word word;
|
||||
// If short list has been set, then wordIdx is an index into the short-listed word set,
|
||||
|
@ -63,6 +63,8 @@ public:
|
||||
auto srcVocab = corpus_->getVocabs()[0];
|
||||
|
||||
std::vector<int> lshOpts = options_->get<std::vector<int>>("output-approx-knn");
|
||||
ABORT_IF(lshOpts.size() != 0 && lshOpts.size() != 2, "--output-approx-knn takes 2 parameters");
|
||||
|
||||
if (lshOpts.size() == 2 || options_->hasAndNotEmpty("shortlist")) {
|
||||
shortlistGenerator_ = data::createShortlistGenerator(options_, srcVocab, trgVocab_, lshOpts, 0, 1, vocabs.front() == vocabs.back());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user