From 4b51dcbd066b927e29e4007512c1f887ee1a350f Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Fri, 11 Feb 2022 13:50:47 +0000 Subject: [PATCH] Merged PR 22524: Optimize guided alignment training speed via sparse alignments - part 1 This replaces dense alignment storage and training with a sparse representation. Training speed with guided alignment now nearly matches normal training speed, regaining about 25% speed. This is no. 1 of 2 PRs. The next one will introduce a new guided-alignment training scheme with better alignment accuracy. --- CHANGELOG.md | 2 +- VERSION | 2 +- regression-tests | 2 +- src/common/config_parser.cpp | 2 +- src/data/alignment.cpp | 39 +++++++++++-- src/data/alignment.h | 21 +++++-- src/data/batch.h | 2 +- src/data/corpus.cpp | 13 ++--- src/data/corpus_base.cpp | 25 ++++---- src/data/corpus_base.h | 46 ++++++--------- src/examples/mnist/dataset.h | 2 +- src/graph/expression_operators.cpp | 2 + src/layers/guided_alignment.h | 93 ++++++++++++++---------------- 13 files changed, 138 insertions(+), 113 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a724d254..ad4642f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Fixed ### Changed - +- Make guided-alignment faster via sparse memory layout, add alignment points for EOS, remove losses other than ce. 
- Changed minimal C++ standard to C++-17 - Faster LSH top-k search on CPU diff --git a/VERSION b/VERSION index 07fb54b5..3d461ead 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v1.11.2 +v1.11.3 diff --git a/regression-tests b/regression-tests index 0716f4e0..d59f7ad8 160000 --- a/regression-tests +++ b/regression-tests @@ -1 +1 @@ -Subproject commit 0716f4e012d1e3f7543bffa8aecc97ce9c903e17 +Subproject commit d59f7ad85ecfdf4a788c095ac9fc1c447094e39e diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp index 837bee53..0d956495 100644 --- a/src/common/config_parser.cpp +++ b/src/common/config_parser.cpp @@ -510,7 +510,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) { "none"); cli.add("--guided-alignment-cost", "Cost type for guided alignment: ce (cross-entropy), mse (mean square error), mult (multiplication)", - "mse"); + "ce"); cli.add("--guided-alignment-weight", "Weight for guided alignment cost", 0.1); diff --git a/src/data/alignment.cpp b/src/data/alignment.cpp index 928beb21..3b7e0d66 100644 --- a/src/data/alignment.cpp +++ b/src/data/alignment.cpp @@ -2,6 +2,8 @@ #include "common/utils.h" #include +#include +#include namespace marian { namespace data { @@ -10,10 +12,11 @@ WordAlignment::WordAlignment() {} WordAlignment::WordAlignment(const std::vector& align) : data_(align) {} -WordAlignment::WordAlignment(const std::string& line) { +WordAlignment::WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos) { std::vector atok = utils::splitAny(line, " -"); for(size_t i = 0; i < atok.size(); i += 2) - data_.emplace_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f }); + data_.push_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f }); + data_.push_back(Point{ srcEosPos, tgtEosPos, 1.f }); // add alignment point for both EOS symbols } void WordAlignment::sort() { @@ -22,6 +25,35 @@ void WordAlignment::sort() { }); } +void WordAlignment::normalize(bool 
reverse/*=false*/) { + std::vector counts; + counts.reserve(data_.size()); + + // reverse==false : normalize target word prob by number of source words + // reverse==true : normalize source word prob by number of target words + auto srcOrTgt = [](const Point& p, bool reverse) { + return reverse ? p.srcPos : p.tgtPos; + }; + + for(const auto& a : data_) { + size_t pos = srcOrTgt(a, reverse); + if(counts.size() <= pos) + counts.resize(pos + 1, 0); + counts[pos]++; + } + + // a.prob at this point is either 1 or normalized to a different value, + // but we just set it to 1 / count, so multiple calls result in re-normalization + // regardless of forward or reverse direction. We also set the remaining values to 1. + for(auto& a : data_) { + size_t pos = srcOrTgt(a, reverse); + if(counts[pos] > 1) + a.prob = 1.f / counts[pos]; + else + a.prob = 1.f; + } +} + std::string WordAlignment::toString() const { std::stringstream str; for(auto p = begin(); p != end(); ++p) { @@ -32,7 +64,7 @@ std::string WordAlignment::toString() const { return str.str(); } -WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft, +WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft, float threshold /*= 1.f*/) { WordAlignment align; // Alignments by maximum value @@ -58,7 +90,6 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft, } } } - // Sort alignment pairs in ascending order align.sort(); diff --git a/src/data/alignment.h b/src/data/alignment.h index 1c68bb39..f27bea38 100644 --- a/src/data/alignment.h +++ b/src/data/alignment.h @@ -1,20 +1,22 @@ #pragma once #include +#include #include namespace marian { namespace data { class WordAlignment { - struct Point - { +public: + struct Point { size_t srcPos; size_t tgtPos; float prob; }; private: std::vector data_; + public: WordAlignment(); @@ -28,11 +30,14 @@ private: public: /** - * @brief Constructs word alignments from textual representation. 
+ * @brief Constructs word alignments from textual representation. Adds an alignment point for the externally + * supplied EOS positions of the source and target sentence. * * @param line String in the form of "0-0 1-1 1-2", etc. */ - WordAlignment(const std::string& line); + WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos); + + Point& operator[](size_t i) { return data_[i]; } auto begin() const -> decltype(data_.begin()) { return data_.begin(); } auto end() const -> decltype(data_.end()) { return data_.end(); } @@ -46,6 +51,12 @@ public: */ void sort(); + /** + * @brief Normalizes alignment probabilities of target words to sum to 1 over their aligned source words + * (source words over target words if reverse is true). Needed for correct cost computation for guided alignment training with CE cost criterion. + */ + void normalize(bool reverse=false); + /** * @brief Returns textual representation. */ @@ -56,7 +67,7 @@ public: // Also used on QuickSAND boundary where beam and batch size is 1. Then it is simply [t][s] -> P(s|t) typedef std::vector> SoftAlignment; // [trg pos][beam depth * max src length * batch size] -WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft, +WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft, float threshold = 1.f); std::string SoftAlignToString(SoftAlignment align); diff --git a/src/data/batch.h b/src/data/batch.h index 3c592b31..761f46a4 100644 --- a/src/data/batch.h +++ b/src/data/batch.h @@ -24,7 +24,7 @@ public: const std::vector& getSentenceIds() const { return sentenceIds_; } void setSentenceIds(const std::vector& ids) { sentenceIds_ = ids; } - virtual void setGuidedAlignment(std::vector&&) = 0; + virtual void setGuidedAlignment(std::vector&&) = 0; virtual void setDataWeights(const std::vector&) = 0; virtual ~Batch() {}; protected: diff --git a/src/data/corpus.cpp b/src/data/corpus.cpp index 643a7de9..2fbe4982 100644 --- a/src/data/corpus.cpp +++ b/src/data/corpus.cpp @@ -132,14 +132,13 @@ SentenceTuple Corpus::next() { 
tup.markAltered(); addWordsToSentenceTuple(fields[i], vocabId, tup); } - - // weights are added last to the sentence tuple, because this runs a validation that needs - // length of the target sequence - if(alignFileIdx_ > -1) - addAlignmentToSentenceTuple(fields[alignFileIdx_], tup); - if(weightFileIdx_ > -1) - addWeightsToSentenceTuple(fields[weightFileIdx_], tup); } + // weights are added last to the sentence tuple, because this runs a validation that needs + // length of the target sequence + if(alignFileIdx_ > -1) + addAlignmentToSentenceTuple(fields[alignFileIdx_], tup); + if(weightFileIdx_ > -1) + addWeightsToSentenceTuple(fields[weightFileIdx_], tup); // check if all streams are valid, that is, non-empty and no longer than maximum allowed length if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) { diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp index 636752c9..71c9f990 100644 --- a/src/data/corpus_base.cpp +++ b/src/data/corpus_base.cpp @@ -429,11 +429,13 @@ void CorpusBase::addWordsToSentenceTuple(const std::string& line, void CorpusBase::addAlignmentToSentenceTuple(const std::string& line, SentenceTupleImpl& tup) const { - ABORT_IF(rightLeft_, - "Guided alignment and right-left model cannot be used " - "together at the moment"); + ABORT_IF(rightLeft_, "Guided alignment and right-left model cannot be used together at the moment"); + ABORT_IF(tup.size() != 2, "Using alignment between source and target, but sentence tuple has {} elements??", tup.size()); - auto align = WordAlignment(line); + size_t srcEosPos = tup[0].size() - 1; + size_t tgtEosPos = tup[1].size() - 1; + + auto align = WordAlignment(line, srcEosPos, tgtEosPos); tup.setAlignment(align); } @@ -457,22 +459,17 @@ void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTupl void CorpusBase::addAlignmentsToBatch(Ptr batch, const std::vector& batchVector) { - int srcWords = (int)batch->front()->batchWidth(); - int trgWords = 
(int)batch->back()->batchWidth(); + std::vector aligns; + int dimBatch = (int)batch->getSentenceIds().size(); - - std::vector aligns(srcWords * dimBatch * trgWords, 0.f); - + aligns.reserve(dimBatch); + for(int b = 0; b < dimBatch; ++b) { - // If the batch vector is altered within marian by, for example, case augmentation, // the guided alignments we received for this tuple cease to be valid. // Hence skip setting alignments for that sentence tuple.. if (!batchVector[b].isAltered()) { - for(auto p : batchVector[b].getAlignment()) { - size_t idx = p.srcPos * dimBatch * trgWords + b * trgWords + p.tgtPos; - aligns[idx] = 1.f; - } + aligns.push_back(std::move(batchVector[b].getAlignment())); } } batch->setGuidedAlignment(std::move(aligns)); diff --git a/src/data/corpus_base.h b/src/data/corpus_base.h index a54c20f8..4e6d923e 100644 --- a/src/data/corpus_base.h +++ b/src/data/corpus_base.h @@ -338,7 +338,7 @@ public: class CorpusBatch : public Batch { protected: std::vector> subBatches_; - std::vector guidedAlignment_; // [max source len, batch size, max target len] flattened + std::vector guidedAlignment_; // sparse alignment points, stored per sentence (no longer a flattened dense matrix) std::vector dataWeights_; public: @@ -444,8 +444,17 @@ public: if(options->get("guided-alignment", std::string("none")) != "none") { // @TODO: if > 1 encoder, verify that all encoders have the same sentence lengths - std::vector alignment(batchSize * lengths.front() * lengths.back(), - 0.f); + + std::vector alignment; + for(size_t k = 0; k < batchSize; ++k) { + data::WordAlignment perSentence; + // fill with random alignment points; add twice as many points as there are words to be safe. 
+ for(size_t j = 0; j < lengths.back() * 2; ++j) { + size_t i = rand() % lengths.back(); + perSentence.push_back(i, j, 1.0f); + } + alignment.push_back(std::move(perSentence)); + } batch->setGuidedAlignment(std::move(alignment)); } @@ -501,29 +510,14 @@ public: } if(!guidedAlignment_.empty()) { - size_t oldTrgWords = back()->batchWidth(); - size_t oldSize = size(); - pos = 0; for(auto split : splits) { auto cb = std::static_pointer_cast(split); - size_t srcWords = cb->front()->batchWidth(); - size_t trgWords = cb->back()->batchWidth(); size_t dimBatch = cb->size(); - - std::vector aligns(srcWords * dimBatch * trgWords, 0.f); - - for(size_t i = 0; i < dimBatch; ++i) { - size_t bi = i + pos; - for(size_t sid = 0; sid < srcWords; ++sid) { - for(size_t tid = 0; tid < trgWords; ++tid) { - size_t bidx = sid * oldSize * oldTrgWords + bi * oldTrgWords + tid; // [sid, bi, tid] - size_t idx = sid * dimBatch * trgWords + i * trgWords + tid; - aligns[idx] = guidedAlignment_[bidx]; - } - } - } - cb->setGuidedAlignment(std::move(aligns)); + std::vector batchAlignment; + for(size_t i = 0; i < dimBatch; ++i) + batchAlignment.push_back(std::move(guidedAlignment_[i + pos])); + cb->setGuidedAlignment(std::move(batchAlignment)); pos += dimBatch; } } @@ -556,15 +550,11 @@ public: return splits; } - const std::vector& getGuidedAlignment() const { return guidedAlignment_; } // [dimSrcWords, dimBatch, dimTrgWords] flattened - void setGuidedAlignment(std::vector&& aln) override { + const std::vector& getGuidedAlignment() const { return guidedAlignment_; } // sparse alignment points, stored per sentence + void setGuidedAlignment(std::vector&& aln) override { guidedAlignment_ = std::move(aln); } - size_t locateInGuidedAlignments(size_t b, size_t s, size_t t) { - return ((s * size()) + b) * widthTrg() + t; - } - std::vector& getDataWeights() { return dataWeights_; } void setDataWeights(const std::vector& weights) override { dataWeights_ = weights; diff --git a/src/examples/mnist/dataset.h 
b/src/examples/mnist/dataset.h index b0492b85..c665fa65 100644 --- a/src/examples/mnist/dataset.h +++ b/src/examples/mnist/dataset.h @@ -77,7 +77,7 @@ public: size_t size() const override { return inputs_.front().shape()[0]; } - void setGuidedAlignment(std::vector&&) override { + void setGuidedAlignment(std::vector&&) override { ABORT("Guided alignment in DataBatch is not implemented"); } void setDataWeights(const std::vector&) override { diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index 5294fca3..ca5e6805 100644 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -286,6 +286,8 @@ Expr operator/(float a, Expr b) { /*********************************************************/ Expr concatenate(const std::vector& concats, int ax) { + if(concats.size() == 1) + return concats[0]; return Expression(concats, ax); } diff --git a/src/layers/guided_alignment.h b/src/layers/guided_alignment.h index f08d3f09..d2171c50 100644 --- a/src/layers/guided_alignment.h +++ b/src/layers/guided_alignment.h @@ -5,62 +5,57 @@ namespace marian { -static inline RationalLoss guidedAlignmentCost(Ptr /*graph*/, +static inline const std::tuple, std::vector> +guidedAlignmentToSparse(Ptr batch) { + int trgWords = (int)batch->back()->batchWidth(); + int dimBatch = (int)batch->size(); + + typedef std::tuple BiPoint; + std::vector byIndex; + byIndex.reserve(dimBatch * trgWords); + + for(size_t b = 0; b < dimBatch; ++b) { + auto guidedAlignmentFwd = batch->getGuidedAlignment()[b]; // this copies + guidedAlignmentFwd.normalize(/*reverse=*/false); // normalize forward + for(size_t i = 0; i < guidedAlignmentFwd.size(); ++i) { + auto pFwd = guidedAlignmentFwd[i]; + IndexType idx = (IndexType)(pFwd.srcPos * dimBatch * trgWords + b * trgWords + pFwd.tgtPos); + byIndex.push_back({idx, pFwd.prob}); + } + } + + std::sort(byIndex.begin(), byIndex.end(), [](const BiPoint& a, const BiPoint& b) { return std::get<0>(a) < std::get<0>(b); }); + 
std::vector indices; std::vector valuesFwd; + indices.reserve(byIndex.size()); valuesFwd.reserve(byIndex.size()); + for(auto& p : byIndex) { + indices.push_back((IndexType)std::get<0>(p)); + valuesFwd.push_back(std::get<1>(p)); + } + + return {indices, valuesFwd}; +} + +static inline RationalLoss guidedAlignmentCost(Ptr graph, Ptr batch, Ptr options, Expr attention) { // [beam depth=1, max src length, batch size, tgt length] - std::string guidedLossType = options->get("guided-alignment-cost"); // @TODO: change "cost" to "loss" + + // We dropped support for other losses which are not possible to implement with sparse labels. + // They were most likely not used anyway. + ABORT_IF(guidedLossType != "ce", "Only alignment loss type 'ce' is supported"); + float guidedLossWeight = options->get("guided-alignment-weight"); - const auto& shape = attention->shape(); // [beam depth=1, max src length, batch size, tgt length] - float epsilon = 1e-6f; - Expr alignmentLoss; // sum up loss over all attention/alignment positions - size_t numLabels; - if(guidedLossType == "ce") { - // normalizedAlignment is multi-hot, but ce requires normalized probabilities, so need to normalize to P(s|t) - auto dimBatch = shape[-2]; - auto dimTrgWords = shape[-1]; - auto dimSrcWords = shape[-3]; - ABORT_IF(shape[-4] != 1, "Guided alignments with beam??"); - auto normalizedAlignment = batch->getGuidedAlignment(); // [dimSrcWords, dimBatch, dimTrgWords] flattened, matches shape of 'attention' - auto srcBatch = batch->front(); - const auto& srcMask = srcBatch->mask(); - ABORT_IF(shape.elements() != normalizedAlignment.size(), "Attention-matrix and alignment shapes differ??"); - ABORT_IF(dimBatch != batch->size() || dimTrgWords != batch->widthTrg() || dimSrcWords != batch->width(), "Attention-matrix and batch shapes differ??"); - auto locate = [=](size_t s, size_t b, size_t t) { return ((s * dimBatch) + b) * dimTrgWords + t; }; - for (size_t b = 0; b < dimBatch; b++) { - for (size_t t = 0; t < 
dimTrgWords; t++) { - for (size_t s = 0; s < dimSrcWords; s++) - ABORT_IF(locate(s, b, t) != batch->locateInGuidedAlignments(b, s, t), "locate() and locateInGuidedAlignments() differ??"); - // renormalize the alignment such that it sums up to 1 - float sum = 0; - for (size_t s = 0; s < dimSrcWords; s++) - sum += srcMask[srcBatch->locate(b, s)] * normalizedAlignment[locate(s, b, t)]; // these values are 0 or 1 - if (sum != 0 && sum != 1) - for (size_t s = 0; s < dimSrcWords; s++) - normalizedAlignment[locate(s, b, t)] /= sum; - } - } - auto alignment = constant_like(attention, std::move(normalizedAlignment)); - alignmentLoss = -sum(flatten(alignment * log(attention + epsilon))); - numLabels = batch->back()->batchWords(); - ABORT_IF(numLabels > shape.elements() / shape[-3], "Num labels of guided alignment cost is off??"); - } else { - auto alignment = constant_like(attention, batch->getGuidedAlignment()); - if(guidedLossType == "mse") - alignmentLoss = sum(flatten(square(attention - alignment))) / 2.f; - else if(guidedLossType == "mult") // @TODO: I don't know what this criterion is for. Can we remove it? - alignmentLoss = -log(sum(flatten(attention * alignment)) + epsilon); - else - ABORT("Unknown alignment cost type: {}", guidedLossType); - // every position is a label as they should all agree - // @TODO: there should be positional masking here ... on the other hand, positions that are not - // in a sentence should always agree (both being 0). Lack of masking affects label count only which is - // probably negligible? 
- numLabels = shape.elements(); - } + auto [indices, values] = guidedAlignmentToSparse(batch); + auto alignmentIndices = graph->indices(indices); + auto alignmentValues = graph->constant({(int)values.size()}, inits::fromVector(values)); + auto attentionAtAligned = cols(flatten(attention), alignmentIndices); + float epsilon = 1e-6f; + Expr alignmentLoss = -sum(alignmentValues * log(attentionAtAligned + epsilon)); + size_t numLabels = alignmentIndices->shape().elements(); + // Create label node, also weigh by scalar so labels and cost are in the same domain. // Fractional label counts are OK. But only if combined as "sum". // @TODO: It is ugly to check the multi-loss type here, but doing this right requires