Merged PR 22524: Optimize guided alignment training speed via sparse alignments - part 1

This replaces dense alignment storage and training with a sparse representation. Training speed with guided alignment now nearly matches normal training speed, regaining about 25% speed.

This is the first of two PRs. The next one will introduce a new guided-alignment training scheme with better alignment accuracy.
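To illustrate the layout change, here is a minimal sketch with made-up names (not Marian's actual API): the dense representation materializes a full [source length x batch size x target length] float tensor that is almost entirely zeros, while the sparse representation stores only the alignment points each sentence actually has.

#include <cstddef>
#include <vector>

// Hypothetical illustration of the storage change; names do not match Marian's.
struct AlignPoint { std::size_t srcPos, tgtPos; float prob; };

// Dense: one float per (source, batch, target) cell, almost all zero.
std::vector<float> denseAlignment(std::size_t srcWords, std::size_t dimBatch, std::size_t trgWords) {
  return std::vector<float>(srcWords * dimBatch * trgWords, 0.f); // O(S * B * T) memory and traversal
}

// Sparse: only the handful of alignment points per sentence.
std::vector<std::vector<AlignPoint>> sparseAlignment(std::size_t dimBatch) {
  return std::vector<std::vector<AlignPoint>>(dimBatch); // O(number of points)
}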
Marcin Junczys-Dowmunt 2022-02-11 13:50:47 +00:00
parent 3b21ff39c5
commit 4b51dcbd06
13 changed files with 138 additions and 113 deletions

View File

@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
### Fixed
### Changed
- Make guided-alignment faster via sparse memory layout, add alignment points for EOS, remove losses other than ce.
- Changed minimal C++ standard to C++17
- Faster LSH top-k search on CPU

View File

@ -1 +1 @@
v1.11.2
v1.11.3

@ -1 +1 @@
Subproject commit 0716f4e012d1e3f7543bffa8aecc97ce9c903e17
Subproject commit d59f7ad85ecfdf4a788c095ac9fc1c447094e39e

View File

@ -510,7 +510,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"none");
cli.add<std::string>("--guided-alignment-cost",
"Cost type for guided alignment: ce (cross-entropy), mse (mean square error), mult (multiplication)",
"mse");
"ce");
cli.add<double>("--guided-alignment-weight",
"Weight for guided alignment cost",
0.1);

View File

@ -2,6 +2,8 @@
#include "common/utils.h"
#include <algorithm>
#include <cmath>
#include <set>
namespace marian {
namespace data {
@ -10,10 +12,11 @@ WordAlignment::WordAlignment() {}
WordAlignment::WordAlignment(const std::vector<Point>& align) : data_(align) {}
WordAlignment::WordAlignment(const std::string& line) {
WordAlignment::WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos) {
std::vector<std::string> atok = utils::splitAny(line, " -");
for(size_t i = 0; i < atok.size(); i += 2)
data_.emplace_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f });
data_.push_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f });
data_.push_back(Point{ srcEosPos, tgtEosPos, 1.f }); // add alignment point for both EOS symbols
}
void WordAlignment::sort() {
@ -22,6 +25,35 @@ void WordAlignment::sort() {
});
}
void WordAlignment::normalize(bool reverse/*=false*/) {
std::vector<size_t> counts;
counts.reserve(data_.size());
// reverse==false : normalize target word prob by number of source words
// reverse==true : normalize source word prob by number of target words
auto srcOrTgt = [](const Point& p, bool reverse) {
return reverse ? p.srcPos : p.tgtPos;
};
for(const auto& a : data_) {
size_t pos = srcOrTgt(a, reverse);
if(counts.size() <= pos)
counts.resize(pos + 1, 0);
counts[pos]++;
}
// a.prob at this point is either 1 or a previously normalized value,
// but we overwrite it with 1 / count, so multiple calls simply re-normalize,
// regardless of forward or reverse direction. Counts of 1 reset the value to 1.
for(auto& a : data_) {
size_t pos = srcOrTgt(a, reverse);
if(counts[pos] > 1)
a.prob = 1.f / counts[pos];
else
a.prob = 1.f;
}
}
std::string WordAlignment::toString() const {
std::stringstream str;
for(auto p = begin(); p != end(); ++p) {
@ -32,7 +64,7 @@ std::string WordAlignment::toString() const {
return str.str();
}
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft,
float threshold /*= 1.f*/) {
WordAlignment align;
// Alignments by maximum value
@ -58,7 +90,6 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
}
}
}
// Sort alignment pairs in ascending order
align.sort();
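A hypothetical usage sketch of the new constructor and normalize() (assuming the marian headers are on the include path): the parsed points keep probability 1, the EOS pair is appended, and forward normalization rescales so that each target word's probabilities sum to 1 over its aligned source words.

#include "data/alignment.h" // marian::data::WordAlignment

int main() {
  using marian::data::WordAlignment;
  // "0-0 1-0 1-1" with srcEosPos=2, tgtEosPos=2 yields the points
  // (0,0), (1,0), (1,1) plus the appended EOS point (2,2), all with prob 1.
  WordAlignment align("0-0 1-0 1-1", /*srcEosPos=*/2, /*tgtEosPos=*/2);
  // Forward normalization: target position 0 is aligned to two source words,
  // so those two points get prob 0.5 each; the others keep prob 1.
  align.normalize(/*reverse=*/false);
}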

View File

@ -1,20 +1,22 @@
#pragma once
#include <sstream>
#include <tuple>
#include <vector>
namespace marian {
namespace data {
class WordAlignment {
struct Point
{
public:
struct Point {
size_t srcPos;
size_t tgtPos;
float prob;
};
private:
std::vector<Point> data_;
public:
WordAlignment();
@ -28,11 +30,14 @@ private:
public:
/**
* @brief Constructs word alignments from textual representation.
* @brief Constructs word alignments from textual representation. Adds an alignment point for the
* externally supplied EOS positions in the source and target strings.
*
* @param line String in the form of "0-0 1-1 1-2", etc.
*/
WordAlignment(const std::string& line);
WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos);
Point& operator[](size_t i) { return data_[i]; }
auto begin() const -> decltype(data_.begin()) { return data_.begin(); }
auto end() const -> decltype(data_.end()) { return data_.end(); }
@ -46,6 +51,12 @@ public:
*/
void sort();
/**
* @brief Normalizes alignment probabilities so that, for each target word, they sum to 1
* over the aligned source words. This is needed for correct cost computation in
* guided-alignment training with the CE cost criterion.
*/
void normalize(bool reverse=false);
/**
* @brief Returns textual representation.
*/
@ -56,7 +67,7 @@ public:
// Also used on QuickSAND boundary where beam and batch size is 1. Then it is simply [t][s] -> P(s|t)
typedef std::vector<std::vector<float>> SoftAlignment; // [trg pos][beam depth * max src length * batch size]
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft,
float threshold = 1.f);
std::string SoftAlignToString(SoftAlignment align);

View File

@ -24,7 +24,7 @@ public:
const std::vector<size_t>& getSentenceIds() const { return sentenceIds_; }
void setSentenceIds(const std::vector<size_t>& ids) { sentenceIds_ = ids; }
virtual void setGuidedAlignment(std::vector<float>&&) = 0;
virtual void setGuidedAlignment(std::vector<WordAlignment>&&) = 0;
virtual void setDataWeights(const std::vector<float>&) = 0;
virtual ~Batch() {};
protected:

View File

@ -132,14 +132,13 @@ SentenceTuple Corpus::next() {
tup.markAltered();
addWordsToSentenceTuple(fields[i], vocabId, tup);
}
// weights are added last to the sentence tuple, because this runs a validation that needs
// length of the target sequence
if(alignFileIdx_ > -1)
addAlignmentToSentenceTuple(fields[alignFileIdx_], tup);
if(weightFileIdx_ > -1)
addWeightsToSentenceTuple(fields[weightFileIdx_], tup);
}
// weights are added last to the sentence tuple, because this runs a validation that needs
// length of the target sequence
if(alignFileIdx_ > -1)
addAlignmentToSentenceTuple(fields[alignFileIdx_], tup);
if(weightFileIdx_ > -1)
addWeightsToSentenceTuple(fields[weightFileIdx_], tup);
// check if all streams are valid, that is, non-empty and no longer than maximum allowed length
if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) {

View File

@ -429,11 +429,13 @@ void CorpusBase::addWordsToSentenceTuple(const std::string& line,
void CorpusBase::addAlignmentToSentenceTuple(const std::string& line,
SentenceTupleImpl& tup) const {
ABORT_IF(rightLeft_,
"Guided alignment and right-left model cannot be used "
"together at the moment");
ABORT_IF(rightLeft_, "Guided alignment and right-left model cannot be used together at the moment");
ABORT_IF(tup.size() != 2, "Using alignment between source and target, but sentence tuple has {} elements??", tup.size());
auto align = WordAlignment(line);
size_t srcEosPos = tup[0].size() - 1;
size_t tgtEosPos = tup[1].size() - 1;
auto align = WordAlignment(line, srcEosPos, tgtEosPos);
tup.setAlignment(align);
}
@ -457,22 +459,17 @@ void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTupl
void CorpusBase::addAlignmentsToBatch(Ptr<CorpusBatch> batch,
const std::vector<Sample>& batchVector) {
int srcWords = (int)batch->front()->batchWidth();
int trgWords = (int)batch->back()->batchWidth();
std::vector<WordAlignment> aligns;
int dimBatch = (int)batch->getSentenceIds().size();
std::vector<float> aligns(srcWords * dimBatch * trgWords, 0.f);
aligns.reserve(dimBatch);
for(int b = 0; b < dimBatch; ++b) {
// If the batch vector is altered within marian by, for example, case augmentation,
// the guided alignments we received for this tuple cease to be valid.
// Hence skip setting alignments for that sentence tuple.
if (!batchVector[b].isAltered()) {
for(auto p : batchVector[b].getAlignment()) {
size_t idx = p.srcPos * dimBatch * trgWords + b * trgWords + p.tgtPos;
aligns[idx] = 1.f;
}
aligns.push_back(std::move(batchVector[b].getAlignment()));
}
}
batch->setGuidedAlignment(std::move(aligns));
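The removed dense code flattened each point into a [srcWords, dimBatch, trgWords] tensor with the index formula below; the sparse path keeps the per-sentence WordAlignment objects and defers that computation to the loss (see guidedAlignmentToSparse further down). For reference, a sketch of the row-major layout:

#include <cstddef>

// Row-major offset into a [srcWords, dimBatch, trgWords] tensor; equivalent to
// the removed expression srcPos * dimBatch * trgWords + b * trgWords + tgtPos.
inline std::size_t flatIndex(std::size_t s, std::size_t b, std::size_t t,
                             std::size_t dimBatch, std::size_t trgWords) {
  return (s * dimBatch + b) * trgWords + t;
}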

View File

@ -338,7 +338,7 @@ public:
class CorpusBatch : public Batch {
protected:
std::vector<Ptr<SubBatch>> subBatches_;
std::vector<float> guidedAlignment_; // [max source len, batch size, max target len] flattened
std::vector<WordAlignment> guidedAlignment_; // sparse, one WordAlignment per sentence (was: [max source len, batch size, max target len] flattened)
std::vector<float> dataWeights_;
public:
@ -444,8 +444,17 @@ public:
if(options->get("guided-alignment", std::string("none")) != "none") {
// @TODO: if > 1 encoder, verify that all encoders have the same sentence lengths
std::vector<float> alignment(batchSize * lengths.front() * lengths.back(),
0.f);
std::vector<data::WordAlignment> alignment;
for(size_t k = 0; k < batchSize; ++k) {
data::WordAlignment perSentence;
// fill with random alignment points; add twice the number of target words to be safe.
for(size_t j = 0; j < lengths.back() * 2; ++j) {
size_t i = rand() % lengths.back();
perSentence.push_back(i, j, 1.0f);
}
alignment.push_back(std::move(perSentence));
}
batch->setGuidedAlignment(std::move(alignment));
}
@ -501,29 +510,14 @@ public:
}
if(!guidedAlignment_.empty()) {
size_t oldTrgWords = back()->batchWidth();
size_t oldSize = size();
pos = 0;
for(auto split : splits) {
auto cb = std::static_pointer_cast<CorpusBatch>(split);
size_t srcWords = cb->front()->batchWidth();
size_t trgWords = cb->back()->batchWidth();
size_t dimBatch = cb->size();
std::vector<float> aligns(srcWords * dimBatch * trgWords, 0.f);
for(size_t i = 0; i < dimBatch; ++i) {
size_t bi = i + pos;
for(size_t sid = 0; sid < srcWords; ++sid) {
for(size_t tid = 0; tid < trgWords; ++tid) {
size_t bidx = sid * oldSize * oldTrgWords + bi * oldTrgWords + tid; // [sid, bi, tid]
size_t idx = sid * dimBatch * trgWords + i * trgWords + tid;
aligns[idx] = guidedAlignment_[bidx];
}
}
}
cb->setGuidedAlignment(std::move(aligns));
std::vector<WordAlignment> batchAlignment;
for(size_t i = 0; i < dimBatch; ++i)
batchAlignment.push_back(std::move(guidedAlignment_[i + pos]));
cb->setGuidedAlignment(std::move(batchAlignment));
pos += dimBatch;
}
}
@ -556,15 +550,11 @@ public:
return splits;
}
const std::vector<float>& getGuidedAlignment() const { return guidedAlignment_; } // [dimSrcWords, dimBatch, dimTrgWords] flattened
void setGuidedAlignment(std::vector<float>&& aln) override {
const std::vector<WordAlignment>& getGuidedAlignment() const { return guidedAlignment_; } // sparse, one WordAlignment per sentence
void setGuidedAlignment(std::vector<WordAlignment>&& aln) override {
guidedAlignment_ = std::move(aln);
}
size_t locateInGuidedAlignments(size_t b, size_t s, size_t t) {
return ((s * size()) + b) * widthTrg() + t;
}
std::vector<float>& getDataWeights() { return dataWeights_; }
void setDataWeights(const std::vector<float>& weights) override {
dataWeights_ = weights;

View File

@ -77,7 +77,7 @@ public:
size_t size() const override { return inputs_.front().shape()[0]; }
void setGuidedAlignment(std::vector<float>&&) override {
void setGuidedAlignment(std::vector<WordAlignment>&&) override {
ABORT("Guided alignment in DataBatch is not implemented");
}
void setDataWeights(const std::vector<float>&) override {

View File

@ -286,6 +286,8 @@ Expr operator/(float a, Expr b) {
/*********************************************************/
Expr concatenate(const std::vector<Expr>& concats, int ax) {
if(concats.size() == 1)
return concats[0];
return Expression<ConcatenateNodeOp>(concats, ax);
}

View File

@ -5,62 +5,57 @@
namespace marian {
static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> /*graph*/,
static inline const std::tuple<std::vector<IndexType>, std::vector<float>>
guidedAlignmentToSparse(Ptr<data::CorpusBatch> batch) {
int trgWords = (int)batch->back()->batchWidth();
int dimBatch = (int)batch->size();
typedef std::tuple<size_t, float> BiPoint;
std::vector<BiPoint> byIndex;
byIndex.reserve(dimBatch * trgWords);
for(size_t b = 0; b < dimBatch; ++b) {
auto guidedAlignmentFwd = batch->getGuidedAlignment()[b]; // this copies
guidedAlignmentFwd.normalize(/*reverse=*/false); // normalize forward
for(size_t i = 0; i < guidedAlignmentFwd.size(); ++i) {
auto pFwd = guidedAlignmentFwd[i];
IndexType idx = (IndexType)(pFwd.srcPos * dimBatch * trgWords + b * trgWords + pFwd.tgtPos);
byIndex.push_back({idx, pFwd.prob});
}
}
std::sort(byIndex.begin(), byIndex.end(), [](const BiPoint& a, const BiPoint& b) { return std::get<0>(a) < std::get<0>(b); });
std::vector<IndexType> indices; std::vector<float> valuesFwd;
indices.reserve(byIndex.size()); valuesFwd.reserve(byIndex.size());
for(auto& p : byIndex) {
indices.push_back((IndexType)std::get<0>(p));
valuesFwd.push_back(std::get<1>(p));
}
return {indices, valuesFwd};
}
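A standalone toy walk-through of the helper above, with made-up values: for dimBatch = 2 and trgWords = 2, the point (src 0, tgt 0) in sentence 0 flattens to index 0 and (src 1, tgt 1) in sentence 1 flattens to index 7, sorted ascending with their probabilities as values.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <tuple>
#include <utility>
#include <vector>

int main() {
  using Point = std::tuple<std::size_t, std::size_t, float>; // srcPos, tgtPos, prob
  std::vector<std::vector<Point>> batch = {{{0, 0, 1.f}},  // sentence 0: src 0 -> tgt 0
                                           {{1, 1, 1.f}}}; // sentence 1: src 1 -> tgt 1
  std::size_t dimBatch = 2, trgWords = 2;
  std::vector<std::pair<std::size_t, float>> byIndex;
  for(std::size_t b = 0; b < batch.size(); ++b)
    for(auto [s, t, p] : batch[b])
      byIndex.push_back({s * dimBatch * trgWords + b * trgWords + t, p});
  std::sort(byIndex.begin(), byIndex.end()); // sort by flattened index, as above
  for(auto [idx, p] : byIndex)
    std::cout << idx << " -> " << p << "\n"; // prints "0 -> 1" and "7 -> 1"
}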
static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
Ptr<Options> options,
Expr attention) { // [beam depth=1, max src length, batch size, tgt length]
std::string guidedLossType = options->get<std::string>("guided-alignment-cost"); // @TODO: change "cost" to "loss"
// We dropped support for the other losses, which cannot be implemented with sparse labels.
// They were most likely not used anyway.
ABORT_IF(guidedLossType != "ce", "Only alignment loss type 'ce' is supported");
float guidedLossWeight = options->get<float>("guided-alignment-weight");
const auto& shape = attention->shape(); // [beam depth=1, max src length, batch size, tgt length]
float epsilon = 1e-6f;
Expr alignmentLoss; // sum up loss over all attention/alignment positions
size_t numLabels;
if(guidedLossType == "ce") {
// normalizedAlignment is multi-hot, but ce requires normalized probabilities, so need to normalize to P(s|t)
auto dimBatch = shape[-2];
auto dimTrgWords = shape[-1];
auto dimSrcWords = shape[-3];
ABORT_IF(shape[-4] != 1, "Guided alignments with beam??");
auto normalizedAlignment = batch->getGuidedAlignment(); // [dimSrcWords, dimBatch, dimTrgWords] flattened, matches shape of 'attention'
auto srcBatch = batch->front();
const auto& srcMask = srcBatch->mask();
ABORT_IF(shape.elements() != normalizedAlignment.size(), "Attention-matrix and alignment shapes differ??");
ABORT_IF(dimBatch != batch->size() || dimTrgWords != batch->widthTrg() || dimSrcWords != batch->width(), "Attention-matrix and batch shapes differ??");
auto locate = [=](size_t s, size_t b, size_t t) { return ((s * dimBatch) + b) * dimTrgWords + t; };
for (size_t b = 0; b < dimBatch; b++) {
for (size_t t = 0; t < dimTrgWords; t++) {
for (size_t s = 0; s < dimSrcWords; s++)
ABORT_IF(locate(s, b, t) != batch->locateInGuidedAlignments(b, s, t), "locate() and locateInGuidedAlignments() differ??");
// renormalize the alignment such that it sums up to 1
float sum = 0;
for (size_t s = 0; s < dimSrcWords; s++)
sum += srcMask[srcBatch->locate(b, s)] * normalizedAlignment[locate(s, b, t)]; // these values are 0 or 1
if (sum != 0 && sum != 1)
for (size_t s = 0; s < dimSrcWords; s++)
normalizedAlignment[locate(s, b, t)] /= sum;
}
}
auto alignment = constant_like(attention, std::move(normalizedAlignment));
alignmentLoss = -sum(flatten(alignment * log(attention + epsilon)));
numLabels = batch->back()->batchWords();
ABORT_IF(numLabels > shape.elements() / shape[-3], "Num labels of guided alignment cost is off??");
} else {
auto alignment = constant_like(attention, batch->getGuidedAlignment());
if(guidedLossType == "mse")
alignmentLoss = sum(flatten(square(attention - alignment))) / 2.f;
else if(guidedLossType == "mult") // @TODO: I don't know what this criterion is for. Can we remove it?
alignmentLoss = -log(sum(flatten(attention * alignment)) + epsilon);
else
ABORT("Unknown alignment cost type: {}", guidedLossType);
// every position is a label as they should all agree
// @TODO: there should be positional masking here ... on the other hand, positions that are not
// in a sentence should always agree (both being 0). Lack of masking affects label count only which is
// probably negligible?
numLabels = shape.elements();
}
auto [indices, values] = guidedAlignmentToSparse(batch);
auto alignmentIndices = graph->indices(indices);
auto alignmentValues = graph->constant({(int)values.size()}, inits::fromVector(values));
auto attentionAtAligned = cols(flatten(attention), alignmentIndices);
float epsilon = 1e-6f;
Expr alignmentLoss = -sum(alignmentValues * log(attentionAtAligned + epsilon));
size_t numLabels = alignmentIndices->shape().elements();
// Create label node, also weigh by scalar so labels and cost are in the same domain.
// Fractional label counts are OK. But only if combined as "sum".
// @TODO: It is ugly to check the multi-loss type here, but doing this right requires
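In effect, the new loss gathers attention only at the sparse alignment indices. A scalar sketch of what the graph ops above compute (a plain C++ stand-in for the cols/flatten/log graph nodes, for illustration only):

#include <cmath>
#include <cstddef>
#include <vector>

// L = -sum_i values[i] * log(att[indices[i]] + eps), with numLabels = indices.size().
float sparseAlignmentCE(const std::vector<float>& flatAttention,
                        const std::vector<std::size_t>& indices,
                        const std::vector<float>& values,
                        float eps = 1e-6f) {
  float loss = 0.f;
  for(std::size_t i = 0; i < indices.size(); ++i)
    loss -= values[i] * std::log(flatAttention[indices[i]] + eps);
  return loss;
}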