mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 22524: Optimize guided alignment training speed via sparse alignments - part 1
This replaces dense alignment storage and training with a sparse representation. Training speed with guided alignment matches now nearly normal training speed, regaining about 25% speed. This is no. 1 of 2 PRs. The next one will introduce a new guided-alignment training scheme with better alignment accuracy.
This commit is contained in:
parent
3b21ff39c5
commit
4b51dcbd06
@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
### Fixed
|
||||
|
||||
### Changed
|
||||
|
||||
- Make guided-alignment faster via sparse memory layout, add alignment points for EOS, remove losses other than ce.
|
||||
- Changed minimal C++ standard to C++-17
|
||||
- Faster LSH top-k search on CPU
|
||||
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit 0716f4e012d1e3f7543bffa8aecc97ce9c903e17
|
||||
Subproject commit d59f7ad85ecfdf4a788c095ac9fc1c447094e39e
|
@ -510,7 +510,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
||||
"none");
|
||||
cli.add<std::string>("--guided-alignment-cost",
|
||||
"Cost type for guided alignment: ce (cross-entropy), mse (mean square error), mult (multiplication)",
|
||||
"mse");
|
||||
"ce");
|
||||
cli.add<double>("--guided-alignment-weight",
|
||||
"Weight for guided alignment cost",
|
||||
0.1);
|
||||
|
@ -2,6 +2,8 @@
|
||||
#include "common/utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <set>
|
||||
|
||||
namespace marian {
|
||||
namespace data {
|
||||
@ -10,10 +12,11 @@ WordAlignment::WordAlignment() {}
|
||||
|
||||
WordAlignment::WordAlignment(const std::vector<Point>& align) : data_(align) {}
|
||||
|
||||
WordAlignment::WordAlignment(const std::string& line) {
|
||||
WordAlignment::WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos) {
|
||||
std::vector<std::string> atok = utils::splitAny(line, " -");
|
||||
for(size_t i = 0; i < atok.size(); i += 2)
|
||||
data_.emplace_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f });
|
||||
data_.push_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f });
|
||||
data_.push_back(Point{ srcEosPos, tgtEosPos, 1.f }); // add alignment point for both EOS symbols
|
||||
}
|
||||
|
||||
void WordAlignment::sort() {
|
||||
@ -22,6 +25,35 @@ void WordAlignment::sort() {
|
||||
});
|
||||
}
|
||||
|
||||
void WordAlignment::normalize(bool reverse/*=false*/) {
|
||||
std::vector<size_t> counts;
|
||||
counts.reserve(data_.size());
|
||||
|
||||
// reverse==false : normalize target word prob by number of source words
|
||||
// reverse==true : normalize source word prob by number of target words
|
||||
auto srcOrTgt = [](const Point& p, bool reverse) {
|
||||
return reverse ? p.srcPos : p.tgtPos;
|
||||
};
|
||||
|
||||
for(const auto& a : data_) {
|
||||
size_t pos = srcOrTgt(a, reverse);
|
||||
if(counts.size() <= pos)
|
||||
counts.resize(pos + 1, 0);
|
||||
counts[pos]++;
|
||||
}
|
||||
|
||||
// a.prob at this point is either 1 or normalized to a different value,
|
||||
// but we just set it to 1 / count, so multiple calls result in re-normalization
|
||||
// regardless of forward or reverse direction. We also set the remaining values to 1.
|
||||
for(auto& a : data_) {
|
||||
size_t pos = srcOrTgt(a, reverse);
|
||||
if(counts[pos] > 1)
|
||||
a.prob = 1.f / counts[pos];
|
||||
else
|
||||
a.prob = 1.f;
|
||||
}
|
||||
}
|
||||
|
||||
std::string WordAlignment::toString() const {
|
||||
std::stringstream str;
|
||||
for(auto p = begin(); p != end(); ++p) {
|
||||
@ -32,7 +64,7 @@ std::string WordAlignment::toString() const {
|
||||
return str.str();
|
||||
}
|
||||
|
||||
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
||||
WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft,
|
||||
float threshold /*= 1.f*/) {
|
||||
WordAlignment align;
|
||||
// Alignments by maximum value
|
||||
@ -58,7 +90,6 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort alignment pairs in ascending order
|
||||
align.sort();
|
||||
|
||||
|
@ -1,20 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
namespace marian {
|
||||
namespace data {
|
||||
|
||||
class WordAlignment {
|
||||
struct Point
|
||||
{
|
||||
public:
|
||||
struct Point {
|
||||
size_t srcPos;
|
||||
size_t tgtPos;
|
||||
float prob;
|
||||
};
|
||||
private:
|
||||
std::vector<Point> data_;
|
||||
|
||||
public:
|
||||
WordAlignment();
|
||||
|
||||
@ -28,11 +30,14 @@ private:
|
||||
public:
|
||||
|
||||
/**
|
||||
* @brief Constructs word alignments from textual representation.
|
||||
* @brief Constructs word alignments from textual representation. Adds alignment point for externally
|
||||
* supplied EOS positions in source and target string.
|
||||
*
|
||||
* @param line String in the form of "0-0 1-1 1-2", etc.
|
||||
*/
|
||||
WordAlignment(const std::string& line);
|
||||
WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos);
|
||||
|
||||
Point& operator[](size_t i) { return data_[i]; }
|
||||
|
||||
auto begin() const -> decltype(data_.begin()) { return data_.begin(); }
|
||||
auto end() const -> decltype(data_.end()) { return data_.end(); }
|
||||
@ -46,6 +51,12 @@ public:
|
||||
*/
|
||||
void sort();
|
||||
|
||||
/**
|
||||
* @brief Normalizes alignment probabilities of target words to sum to 1 over source words alignments.
|
||||
* This is needed for correct cost computation for guided alignment training with CE cost criterion.
|
||||
*/
|
||||
void normalize(bool reverse=false);
|
||||
|
||||
/**
|
||||
* @brief Returns textual representation.
|
||||
*/
|
||||
@ -56,7 +67,7 @@ public:
|
||||
// Also used on QuickSAND boundary where beam and batch size is 1. Then it is simply [t][s] -> P(s|t)
|
||||
typedef std::vector<std::vector<float>> SoftAlignment; // [trg pos][beam depth * max src length * batch size]
|
||||
|
||||
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
||||
WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft,
|
||||
float threshold = 1.f);
|
||||
|
||||
std::string SoftAlignToString(SoftAlignment align);
|
||||
|
@ -24,7 +24,7 @@ public:
|
||||
const std::vector<size_t>& getSentenceIds() const { return sentenceIds_; }
|
||||
void setSentenceIds(const std::vector<size_t>& ids) { sentenceIds_ = ids; }
|
||||
|
||||
virtual void setGuidedAlignment(std::vector<float>&&) = 0;
|
||||
virtual void setGuidedAlignment(std::vector<WordAlignment>&&) = 0;
|
||||
virtual void setDataWeights(const std::vector<float>&) = 0;
|
||||
virtual ~Batch() {};
|
||||
protected:
|
||||
|
@ -132,14 +132,13 @@ SentenceTuple Corpus::next() {
|
||||
tup.markAltered();
|
||||
addWordsToSentenceTuple(fields[i], vocabId, tup);
|
||||
}
|
||||
|
||||
}
|
||||
// weights are added last to the sentence tuple, because this runs a validation that needs
|
||||
// length of the target sequence
|
||||
if(alignFileIdx_ > -1)
|
||||
addAlignmentToSentenceTuple(fields[alignFileIdx_], tup);
|
||||
if(weightFileIdx_ > -1)
|
||||
addWeightsToSentenceTuple(fields[weightFileIdx_], tup);
|
||||
}
|
||||
|
||||
// check if all streams are valid, that is, non-empty and no longer than maximum allowed length
|
||||
if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) {
|
||||
|
@ -429,11 +429,13 @@ void CorpusBase::addWordsToSentenceTuple(const std::string& line,
|
||||
|
||||
void CorpusBase::addAlignmentToSentenceTuple(const std::string& line,
|
||||
SentenceTupleImpl& tup) const {
|
||||
ABORT_IF(rightLeft_,
|
||||
"Guided alignment and right-left model cannot be used "
|
||||
"together at the moment");
|
||||
ABORT_IF(rightLeft_, "Guided alignment and right-left model cannot be used together at the moment");
|
||||
ABORT_IF(tup.size() != 2, "Using alignment between source and target, but sentence tuple has {} elements??", tup.size());
|
||||
|
||||
auto align = WordAlignment(line);
|
||||
size_t srcEosPos = tup[0].size() - 1;
|
||||
size_t tgtEosPos = tup[1].size() - 1;
|
||||
|
||||
auto align = WordAlignment(line, srcEosPos, tgtEosPos);
|
||||
tup.setAlignment(align);
|
||||
}
|
||||
|
||||
@ -457,22 +459,17 @@ void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTupl
|
||||
|
||||
void CorpusBase::addAlignmentsToBatch(Ptr<CorpusBatch> batch,
|
||||
const std::vector<Sample>& batchVector) {
|
||||
int srcWords = (int)batch->front()->batchWidth();
|
||||
int trgWords = (int)batch->back()->batchWidth();
|
||||
int dimBatch = (int)batch->getSentenceIds().size();
|
||||
std::vector<WordAlignment> aligns;
|
||||
|
||||
std::vector<float> aligns(srcWords * dimBatch * trgWords, 0.f);
|
||||
int dimBatch = (int)batch->getSentenceIds().size();
|
||||
aligns.reserve(dimBatch);
|
||||
|
||||
for(int b = 0; b < dimBatch; ++b) {
|
||||
|
||||
// If the batch vector is altered within marian by, for example, case augmentation,
|
||||
// the guided alignments we received for this tuple cease to be valid.
|
||||
// Hence skip setting alignments for that sentence tuple..
|
||||
if (!batchVector[b].isAltered()) {
|
||||
for(auto p : batchVector[b].getAlignment()) {
|
||||
size_t idx = p.srcPos * dimBatch * trgWords + b * trgWords + p.tgtPos;
|
||||
aligns[idx] = 1.f;
|
||||
}
|
||||
aligns.push_back(std::move(batchVector[b].getAlignment()));
|
||||
}
|
||||
}
|
||||
batch->setGuidedAlignment(std::move(aligns));
|
||||
|
@ -338,7 +338,7 @@ public:
|
||||
class CorpusBatch : public Batch {
|
||||
protected:
|
||||
std::vector<Ptr<SubBatch>> subBatches_;
|
||||
std::vector<float> guidedAlignment_; // [max source len, batch size, max target len] flattened
|
||||
std::vector<WordAlignment> guidedAlignment_; // [max source len, batch size, max target len] flattened
|
||||
std::vector<float> dataWeights_;
|
||||
|
||||
public:
|
||||
@ -444,8 +444,17 @@ public:
|
||||
|
||||
if(options->get("guided-alignment", std::string("none")) != "none") {
|
||||
// @TODO: if > 1 encoder, verify that all encoders have the same sentence lengths
|
||||
std::vector<float> alignment(batchSize * lengths.front() * lengths.back(),
|
||||
0.f);
|
||||
|
||||
std::vector<data::WordAlignment> alignment;
|
||||
for(size_t k = 0; k < batchSize; ++k) {
|
||||
data::WordAlignment perSentence;
|
||||
// fill with random alignment points, add more twice the number of words to be safe.
|
||||
for(size_t j = 0; j < lengths.back() * 2; ++j) {
|
||||
size_t i = rand() % lengths.back();
|
||||
perSentence.push_back(i, j, 1.0f);
|
||||
}
|
||||
alignment.push_back(std::move(perSentence));
|
||||
}
|
||||
batch->setGuidedAlignment(std::move(alignment));
|
||||
}
|
||||
|
||||
@ -501,29 +510,14 @@ public:
|
||||
}
|
||||
|
||||
if(!guidedAlignment_.empty()) {
|
||||
size_t oldTrgWords = back()->batchWidth();
|
||||
size_t oldSize = size();
|
||||
|
||||
pos = 0;
|
||||
for(auto split : splits) {
|
||||
auto cb = std::static_pointer_cast<CorpusBatch>(split);
|
||||
size_t srcWords = cb->front()->batchWidth();
|
||||
size_t trgWords = cb->back()->batchWidth();
|
||||
size_t dimBatch = cb->size();
|
||||
|
||||
std::vector<float> aligns(srcWords * dimBatch * trgWords, 0.f);
|
||||
|
||||
for(size_t i = 0; i < dimBatch; ++i) {
|
||||
size_t bi = i + pos;
|
||||
for(size_t sid = 0; sid < srcWords; ++sid) {
|
||||
for(size_t tid = 0; tid < trgWords; ++tid) {
|
||||
size_t bidx = sid * oldSize * oldTrgWords + bi * oldTrgWords + tid; // [sid, bi, tid]
|
||||
size_t idx = sid * dimBatch * trgWords + i * trgWords + tid;
|
||||
aligns[idx] = guidedAlignment_[bidx];
|
||||
}
|
||||
}
|
||||
}
|
||||
cb->setGuidedAlignment(std::move(aligns));
|
||||
std::vector<WordAlignment> batchAlignment;
|
||||
for(size_t i = 0; i < dimBatch; ++i)
|
||||
batchAlignment.push_back(std::move(guidedAlignment_[i + pos]));
|
||||
cb->setGuidedAlignment(std::move(batchAlignment));
|
||||
pos += dimBatch;
|
||||
}
|
||||
}
|
||||
@ -556,15 +550,11 @@ public:
|
||||
return splits;
|
||||
}
|
||||
|
||||
const std::vector<float>& getGuidedAlignment() const { return guidedAlignment_; } // [dimSrcWords, dimBatch, dimTrgWords] flattened
|
||||
void setGuidedAlignment(std::vector<float>&& aln) override {
|
||||
const std::vector<WordAlignment>& getGuidedAlignment() const { return guidedAlignment_; } // [dimSrcWords, dimBatch, dimTrgWords] flattened
|
||||
void setGuidedAlignment(std::vector<WordAlignment>&& aln) override {
|
||||
guidedAlignment_ = std::move(aln);
|
||||
}
|
||||
|
||||
size_t locateInGuidedAlignments(size_t b, size_t s, size_t t) {
|
||||
return ((s * size()) + b) * widthTrg() + t;
|
||||
}
|
||||
|
||||
std::vector<float>& getDataWeights() { return dataWeights_; }
|
||||
void setDataWeights(const std::vector<float>& weights) override {
|
||||
dataWeights_ = weights;
|
||||
|
@ -77,7 +77,7 @@ public:
|
||||
|
||||
size_t size() const override { return inputs_.front().shape()[0]; }
|
||||
|
||||
void setGuidedAlignment(std::vector<float>&&) override {
|
||||
void setGuidedAlignment(std::vector<WordAlignment>&&) override {
|
||||
ABORT("Guided alignment in DataBatch is not implemented");
|
||||
}
|
||||
void setDataWeights(const std::vector<float>&) override {
|
||||
|
@ -286,6 +286,8 @@ Expr operator/(float a, Expr b) {
|
||||
/*********************************************************/
|
||||
|
||||
Expr concatenate(const std::vector<Expr>& concats, int ax) {
|
||||
if(concats.size() == 1)
|
||||
return concats[0];
|
||||
return Expression<ConcatenateNodeOp>(concats, ax);
|
||||
}
|
||||
|
||||
|
@ -5,61 +5,56 @@
|
||||
|
||||
namespace marian {
|
||||
|
||||
static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> /*graph*/,
|
||||
static inline const std::tuple<std::vector<IndexType>, std::vector<float>>
|
||||
guidedAlignmentToSparse(Ptr<data::CorpusBatch> batch) {
|
||||
int trgWords = (int)batch->back()->batchWidth();
|
||||
int dimBatch = (int)batch->size();
|
||||
|
||||
typedef std::tuple<size_t, float> BiPoint;
|
||||
std::vector<BiPoint> byIndex;
|
||||
byIndex.reserve(dimBatch * trgWords);
|
||||
|
||||
for(size_t b = 0; b < dimBatch; ++b) {
|
||||
auto guidedAlignmentFwd = batch->getGuidedAlignment()[b]; // this copies
|
||||
guidedAlignmentFwd.normalize(/*reverse=*/false); // normalize forward
|
||||
for(size_t i = 0; i < guidedAlignmentFwd.size(); ++i) {
|
||||
auto pFwd = guidedAlignmentFwd[i];
|
||||
IndexType idx = (IndexType)(pFwd.srcPos * dimBatch * trgWords + b * trgWords + pFwd.tgtPos);
|
||||
byIndex.push_back({idx, pFwd.prob});
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(byIndex.begin(), byIndex.end(), [](const BiPoint& a, const BiPoint& b) { return std::get<0>(a) < std::get<0>(b); });
|
||||
std::vector<IndexType> indices; std::vector<float> valuesFwd;
|
||||
indices.reserve(byIndex.size()); valuesFwd.reserve(byIndex.size());
|
||||
for(auto& p : byIndex) {
|
||||
indices.push_back((IndexType)std::get<0>(p));
|
||||
valuesFwd.push_back(std::get<1>(p));
|
||||
}
|
||||
|
||||
return {indices, valuesFwd};
|
||||
}
|
||||
|
||||
static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> graph,
|
||||
Ptr<data::CorpusBatch> batch,
|
||||
Ptr<Options> options,
|
||||
Expr attention) { // [beam depth=1, max src length, batch size, tgt length]
|
||||
|
||||
std::string guidedLossType = options->get<std::string>("guided-alignment-cost"); // @TODO: change "cost" to "loss"
|
||||
|
||||
// We dropped support for other losses which are not possible to implement with sparse labels.
|
||||
// They were most likely not used anyway.
|
||||
ABORT_IF(guidedLossType != "ce", "Only alignment loss type 'ce' is supported");
|
||||
|
||||
float guidedLossWeight = options->get<float>("guided-alignment-weight");
|
||||
|
||||
const auto& shape = attention->shape(); // [beam depth=1, max src length, batch size, tgt length]
|
||||
auto [indices, values] = guidedAlignmentToSparse(batch);
|
||||
auto alignmentIndices = graph->indices(indices);
|
||||
auto alignmentValues = graph->constant({(int)values.size()}, inits::fromVector(values));
|
||||
auto attentionAtAligned = cols(flatten(attention), alignmentIndices);
|
||||
|
||||
float epsilon = 1e-6f;
|
||||
Expr alignmentLoss; // sum up loss over all attention/alignment positions
|
||||
size_t numLabels;
|
||||
if(guidedLossType == "ce") {
|
||||
// normalizedAlignment is multi-hot, but ce requires normalized probabilities, so need to normalize to P(s|t)
|
||||
auto dimBatch = shape[-2];
|
||||
auto dimTrgWords = shape[-1];
|
||||
auto dimSrcWords = shape[-3];
|
||||
ABORT_IF(shape[-4] != 1, "Guided alignments with beam??");
|
||||
auto normalizedAlignment = batch->getGuidedAlignment(); // [dimSrcWords, dimBatch, dimTrgWords] flattened, matches shape of 'attention'
|
||||
auto srcBatch = batch->front();
|
||||
const auto& srcMask = srcBatch->mask();
|
||||
ABORT_IF(shape.elements() != normalizedAlignment.size(), "Attention-matrix and alignment shapes differ??");
|
||||
ABORT_IF(dimBatch != batch->size() || dimTrgWords != batch->widthTrg() || dimSrcWords != batch->width(), "Attention-matrix and batch shapes differ??");
|
||||
auto locate = [=](size_t s, size_t b, size_t t) { return ((s * dimBatch) + b) * dimTrgWords + t; };
|
||||
for (size_t b = 0; b < dimBatch; b++) {
|
||||
for (size_t t = 0; t < dimTrgWords; t++) {
|
||||
for (size_t s = 0; s < dimSrcWords; s++)
|
||||
ABORT_IF(locate(s, b, t) != batch->locateInGuidedAlignments(b, s, t), "locate() and locateInGuidedAlignments() differ??");
|
||||
// renormalize the alignment such that it sums up to 1
|
||||
float sum = 0;
|
||||
for (size_t s = 0; s < dimSrcWords; s++)
|
||||
sum += srcMask[srcBatch->locate(b, s)] * normalizedAlignment[locate(s, b, t)]; // these values are 0 or 1
|
||||
if (sum != 0 && sum != 1)
|
||||
for (size_t s = 0; s < dimSrcWords; s++)
|
||||
normalizedAlignment[locate(s, b, t)] /= sum;
|
||||
}
|
||||
}
|
||||
auto alignment = constant_like(attention, std::move(normalizedAlignment));
|
||||
alignmentLoss = -sum(flatten(alignment * log(attention + epsilon)));
|
||||
numLabels = batch->back()->batchWords();
|
||||
ABORT_IF(numLabels > shape.elements() / shape[-3], "Num labels of guided alignment cost is off??");
|
||||
} else {
|
||||
auto alignment = constant_like(attention, batch->getGuidedAlignment());
|
||||
if(guidedLossType == "mse")
|
||||
alignmentLoss = sum(flatten(square(attention - alignment))) / 2.f;
|
||||
else if(guidedLossType == "mult") // @TODO: I don't know what this criterion is for. Can we remove it?
|
||||
alignmentLoss = -log(sum(flatten(attention * alignment)) + epsilon);
|
||||
else
|
||||
ABORT("Unknown alignment cost type: {}", guidedLossType);
|
||||
// every position is a label as they should all agree
|
||||
// @TODO: there should be positional masking here ... on the other hand, positions that are not
|
||||
// in a sentence should always agree (both being 0). Lack of masking affects label count only which is
|
||||
// probably negligible?
|
||||
numLabels = shape.elements();
|
||||
}
|
||||
Expr alignmentLoss = -sum(alignmentValues * log(attentionAtAligned + epsilon));
|
||||
size_t numLabels = alignmentIndices->shape().elements();
|
||||
|
||||
// Create label node, also weigh by scalar so labels and cost are in the same domain.
|
||||
// Fractional label counts are OK. But only if combined as "sum".
|
||||
|
Loading…
Reference in New Issue
Block a user