mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 22524: Optimize guided alignment training speed via sparse alignments - part 1
This replaces dense alignment storage and training with a sparse representation. Training speed with guided alignment matches now nearly normal training speed, regaining about 25% speed. This is no. 1 of 2 PRs. The next one will introduce a new guided-alignment training scheme with better alignment accuracy.
This commit is contained in:
parent
3b21ff39c5
commit
4b51dcbd06
@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
- Make guided-alignment faster via sparse memory layout, add alignment points for EOS, remove losses other than ce.
|
||||||
- Changed minimal C++ standard to C++-17
|
- Changed minimal C++ standard to C++-17
|
||||||
- Faster LSH top-k search on CPU
|
- Faster LSH top-k search on CPU
|
||||||
|
|
||||||
|
@ -1 +1 @@
|
|||||||
Subproject commit 0716f4e012d1e3f7543bffa8aecc97ce9c903e17
|
Subproject commit d59f7ad85ecfdf4a788c095ac9fc1c447094e39e
|
@ -510,7 +510,7 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
|
|||||||
"none");
|
"none");
|
||||||
cli.add<std::string>("--guided-alignment-cost",
|
cli.add<std::string>("--guided-alignment-cost",
|
||||||
"Cost type for guided alignment: ce (cross-entropy), mse (mean square error), mult (multiplication)",
|
"Cost type for guided alignment: ce (cross-entropy), mse (mean square error), mult (multiplication)",
|
||||||
"mse");
|
"ce");
|
||||||
cli.add<double>("--guided-alignment-weight",
|
cli.add<double>("--guided-alignment-weight",
|
||||||
"Weight for guided alignment cost",
|
"Weight for guided alignment cost",
|
||||||
0.1);
|
0.1);
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
#include "common/utils.h"
|
#include "common/utils.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <set>
|
||||||
|
|
||||||
namespace marian {
|
namespace marian {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -10,10 +12,11 @@ WordAlignment::WordAlignment() {}
|
|||||||
|
|
||||||
WordAlignment::WordAlignment(const std::vector<Point>& align) : data_(align) {}
|
WordAlignment::WordAlignment(const std::vector<Point>& align) : data_(align) {}
|
||||||
|
|
||||||
WordAlignment::WordAlignment(const std::string& line) {
|
WordAlignment::WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos) {
|
||||||
std::vector<std::string> atok = utils::splitAny(line, " -");
|
std::vector<std::string> atok = utils::splitAny(line, " -");
|
||||||
for(size_t i = 0; i < atok.size(); i += 2)
|
for(size_t i = 0; i < atok.size(); i += 2)
|
||||||
data_.emplace_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f });
|
data_.push_back(Point{ (size_t)std::stoi(atok[i]), (size_t)std::stoi(atok[i + 1]), 1.f });
|
||||||
|
data_.push_back(Point{ srcEosPos, tgtEosPos, 1.f }); // add alignment point for both EOS symbols
|
||||||
}
|
}
|
||||||
|
|
||||||
void WordAlignment::sort() {
|
void WordAlignment::sort() {
|
||||||
@ -22,6 +25,35 @@ void WordAlignment::sort() {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void WordAlignment::normalize(bool reverse/*=false*/) {
|
||||||
|
std::vector<size_t> counts;
|
||||||
|
counts.reserve(data_.size());
|
||||||
|
|
||||||
|
// reverse==false : normalize target word prob by number of source words
|
||||||
|
// reverse==true : normalize source word prob by number of target words
|
||||||
|
auto srcOrTgt = [](const Point& p, bool reverse) {
|
||||||
|
return reverse ? p.srcPos : p.tgtPos;
|
||||||
|
};
|
||||||
|
|
||||||
|
for(const auto& a : data_) {
|
||||||
|
size_t pos = srcOrTgt(a, reverse);
|
||||||
|
if(counts.size() <= pos)
|
||||||
|
counts.resize(pos + 1, 0);
|
||||||
|
counts[pos]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// a.prob at this point is either 1 or normalized to a different value,
|
||||||
|
// but we just set it to 1 / count, so multiple calls result in re-normalization
|
||||||
|
// regardless of forward or reverse direction. We also set the remaining values to 1.
|
||||||
|
for(auto& a : data_) {
|
||||||
|
size_t pos = srcOrTgt(a, reverse);
|
||||||
|
if(counts[pos] > 1)
|
||||||
|
a.prob = 1.f / counts[pos];
|
||||||
|
else
|
||||||
|
a.prob = 1.f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string WordAlignment::toString() const {
|
std::string WordAlignment::toString() const {
|
||||||
std::stringstream str;
|
std::stringstream str;
|
||||||
for(auto p = begin(); p != end(); ++p) {
|
for(auto p = begin(); p != end(); ++p) {
|
||||||
@ -32,7 +64,7 @@ std::string WordAlignment::toString() const {
|
|||||||
return str.str();
|
return str.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft,
|
||||||
float threshold /*= 1.f*/) {
|
float threshold /*= 1.f*/) {
|
||||||
WordAlignment align;
|
WordAlignment align;
|
||||||
// Alignments by maximum value
|
// Alignments by maximum value
|
||||||
@ -58,7 +90,6 @@ WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort alignment pairs in ascending order
|
// Sort alignment pairs in ascending order
|
||||||
align.sort();
|
align.sort();
|
||||||
|
|
||||||
|
@ -1,20 +1,22 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <tuple>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace marian {
|
namespace marian {
|
||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
class WordAlignment {
|
class WordAlignment {
|
||||||
struct Point
|
public:
|
||||||
{
|
struct Point {
|
||||||
size_t srcPos;
|
size_t srcPos;
|
||||||
size_t tgtPos;
|
size_t tgtPos;
|
||||||
float prob;
|
float prob;
|
||||||
};
|
};
|
||||||
private:
|
private:
|
||||||
std::vector<Point> data_;
|
std::vector<Point> data_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
WordAlignment();
|
WordAlignment();
|
||||||
|
|
||||||
@ -28,11 +30,14 @@ private:
|
|||||||
public:
|
public:
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Constructs word alignments from textual representation.
|
* @brief Constructs word alignments from textual representation. Adds alignment point for externally
|
||||||
|
* supplied EOS positions in source and target string.
|
||||||
*
|
*
|
||||||
* @param line String in the form of "0-0 1-1 1-2", etc.
|
* @param line String in the form of "0-0 1-1 1-2", etc.
|
||||||
*/
|
*/
|
||||||
WordAlignment(const std::string& line);
|
WordAlignment(const std::string& line, size_t srcEosPos, size_t tgtEosPos);
|
||||||
|
|
||||||
|
Point& operator[](size_t i) { return data_[i]; }
|
||||||
|
|
||||||
auto begin() const -> decltype(data_.begin()) { return data_.begin(); }
|
auto begin() const -> decltype(data_.begin()) { return data_.begin(); }
|
||||||
auto end() const -> decltype(data_.end()) { return data_.end(); }
|
auto end() const -> decltype(data_.end()) { return data_.end(); }
|
||||||
@ -46,6 +51,12 @@ public:
|
|||||||
*/
|
*/
|
||||||
void sort();
|
void sort();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Normalizes alignment probabilities of target words to sum to 1 over source words alignments.
|
||||||
|
* This is needed for correct cost computation for guided alignment training with CE cost criterion.
|
||||||
|
*/
|
||||||
|
void normalize(bool reverse=false);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Returns textual representation.
|
* @brief Returns textual representation.
|
||||||
*/
|
*/
|
||||||
@ -56,7 +67,7 @@ public:
|
|||||||
// Also used on QuickSAND boundary where beam and batch size is 1. Then it is simply [t][s] -> P(s|t)
|
// Also used on QuickSAND boundary where beam and batch size is 1. Then it is simply [t][s] -> P(s|t)
|
||||||
typedef std::vector<std::vector<float>> SoftAlignment; // [trg pos][beam depth * max src length * batch size]
|
typedef std::vector<std::vector<float>> SoftAlignment; // [trg pos][beam depth * max src length * batch size]
|
||||||
|
|
||||||
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
|
WordAlignment ConvertSoftAlignToHardAlign(const SoftAlignment& alignSoft,
|
||||||
float threshold = 1.f);
|
float threshold = 1.f);
|
||||||
|
|
||||||
std::string SoftAlignToString(SoftAlignment align);
|
std::string SoftAlignToString(SoftAlignment align);
|
||||||
|
@ -24,7 +24,7 @@ public:
|
|||||||
const std::vector<size_t>& getSentenceIds() const { return sentenceIds_; }
|
const std::vector<size_t>& getSentenceIds() const { return sentenceIds_; }
|
||||||
void setSentenceIds(const std::vector<size_t>& ids) { sentenceIds_ = ids; }
|
void setSentenceIds(const std::vector<size_t>& ids) { sentenceIds_ = ids; }
|
||||||
|
|
||||||
virtual void setGuidedAlignment(std::vector<float>&&) = 0;
|
virtual void setGuidedAlignment(std::vector<WordAlignment>&&) = 0;
|
||||||
virtual void setDataWeights(const std::vector<float>&) = 0;
|
virtual void setDataWeights(const std::vector<float>&) = 0;
|
||||||
virtual ~Batch() {};
|
virtual ~Batch() {};
|
||||||
protected:
|
protected:
|
||||||
|
@ -132,14 +132,13 @@ SentenceTuple Corpus::next() {
|
|||||||
tup.markAltered();
|
tup.markAltered();
|
||||||
addWordsToSentenceTuple(fields[i], vocabId, tup);
|
addWordsToSentenceTuple(fields[i], vocabId, tup);
|
||||||
}
|
}
|
||||||
|
|
||||||
// weights are added last to the sentence tuple, because this runs a validation that needs
|
|
||||||
// length of the target sequence
|
|
||||||
if(alignFileIdx_ > -1)
|
|
||||||
addAlignmentToSentenceTuple(fields[alignFileIdx_], tup);
|
|
||||||
if(weightFileIdx_ > -1)
|
|
||||||
addWeightsToSentenceTuple(fields[weightFileIdx_], tup);
|
|
||||||
}
|
}
|
||||||
|
// weights are added last to the sentence tuple, because this runs a validation that needs
|
||||||
|
// length of the target sequence
|
||||||
|
if(alignFileIdx_ > -1)
|
||||||
|
addAlignmentToSentenceTuple(fields[alignFileIdx_], tup);
|
||||||
|
if(weightFileIdx_ > -1)
|
||||||
|
addWeightsToSentenceTuple(fields[weightFileIdx_], tup);
|
||||||
|
|
||||||
// check if all streams are valid, that is, non-empty and no longer than maximum allowed length
|
// check if all streams are valid, that is, non-empty and no longer than maximum allowed length
|
||||||
if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) {
|
if(std::all_of(tup.begin(), tup.end(), [=](const Words& words) {
|
||||||
|
@ -429,11 +429,13 @@ void CorpusBase::addWordsToSentenceTuple(const std::string& line,
|
|||||||
|
|
||||||
void CorpusBase::addAlignmentToSentenceTuple(const std::string& line,
|
void CorpusBase::addAlignmentToSentenceTuple(const std::string& line,
|
||||||
SentenceTupleImpl& tup) const {
|
SentenceTupleImpl& tup) const {
|
||||||
ABORT_IF(rightLeft_,
|
ABORT_IF(rightLeft_, "Guided alignment and right-left model cannot be used together at the moment");
|
||||||
"Guided alignment and right-left model cannot be used "
|
ABORT_IF(tup.size() != 2, "Using alignment between source and target, but sentence tuple has {} elements??", tup.size());
|
||||||
"together at the moment");
|
|
||||||
|
|
||||||
auto align = WordAlignment(line);
|
size_t srcEosPos = tup[0].size() - 1;
|
||||||
|
size_t tgtEosPos = tup[1].size() - 1;
|
||||||
|
|
||||||
|
auto align = WordAlignment(line, srcEosPos, tgtEosPos);
|
||||||
tup.setAlignment(align);
|
tup.setAlignment(align);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -457,22 +459,17 @@ void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTupl
|
|||||||
|
|
||||||
void CorpusBase::addAlignmentsToBatch(Ptr<CorpusBatch> batch,
|
void CorpusBase::addAlignmentsToBatch(Ptr<CorpusBatch> batch,
|
||||||
const std::vector<Sample>& batchVector) {
|
const std::vector<Sample>& batchVector) {
|
||||||
int srcWords = (int)batch->front()->batchWidth();
|
std::vector<WordAlignment> aligns;
|
||||||
int trgWords = (int)batch->back()->batchWidth();
|
|
||||||
int dimBatch = (int)batch->getSentenceIds().size();
|
int dimBatch = (int)batch->getSentenceIds().size();
|
||||||
|
aligns.reserve(dimBatch);
|
||||||
std::vector<float> aligns(srcWords * dimBatch * trgWords, 0.f);
|
|
||||||
|
|
||||||
for(int b = 0; b < dimBatch; ++b) {
|
for(int b = 0; b < dimBatch; ++b) {
|
||||||
|
|
||||||
// If the batch vector is altered within marian by, for example, case augmentation,
|
// If the batch vector is altered within marian by, for example, case augmentation,
|
||||||
// the guided alignments we received for this tuple cease to be valid.
|
// the guided alignments we received for this tuple cease to be valid.
|
||||||
// Hence skip setting alignments for that sentence tuple..
|
// Hence skip setting alignments for that sentence tuple..
|
||||||
if (!batchVector[b].isAltered()) {
|
if (!batchVector[b].isAltered()) {
|
||||||
for(auto p : batchVector[b].getAlignment()) {
|
aligns.push_back(std::move(batchVector[b].getAlignment()));
|
||||||
size_t idx = p.srcPos * dimBatch * trgWords + b * trgWords + p.tgtPos;
|
|
||||||
aligns[idx] = 1.f;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
batch->setGuidedAlignment(std::move(aligns));
|
batch->setGuidedAlignment(std::move(aligns));
|
||||||
|
@ -338,7 +338,7 @@ public:
|
|||||||
class CorpusBatch : public Batch {
|
class CorpusBatch : public Batch {
|
||||||
protected:
|
protected:
|
||||||
std::vector<Ptr<SubBatch>> subBatches_;
|
std::vector<Ptr<SubBatch>> subBatches_;
|
||||||
std::vector<float> guidedAlignment_; // [max source len, batch size, max target len] flattened
|
std::vector<WordAlignment> guidedAlignment_; // [max source len, batch size, max target len] flattened
|
||||||
std::vector<float> dataWeights_;
|
std::vector<float> dataWeights_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -444,8 +444,17 @@ public:
|
|||||||
|
|
||||||
if(options->get("guided-alignment", std::string("none")) != "none") {
|
if(options->get("guided-alignment", std::string("none")) != "none") {
|
||||||
// @TODO: if > 1 encoder, verify that all encoders have the same sentence lengths
|
// @TODO: if > 1 encoder, verify that all encoders have the same sentence lengths
|
||||||
std::vector<float> alignment(batchSize * lengths.front() * lengths.back(),
|
|
||||||
0.f);
|
std::vector<data::WordAlignment> alignment;
|
||||||
|
for(size_t k = 0; k < batchSize; ++k) {
|
||||||
|
data::WordAlignment perSentence;
|
||||||
|
// fill with random alignment points, add more twice the number of words to be safe.
|
||||||
|
for(size_t j = 0; j < lengths.back() * 2; ++j) {
|
||||||
|
size_t i = rand() % lengths.back();
|
||||||
|
perSentence.push_back(i, j, 1.0f);
|
||||||
|
}
|
||||||
|
alignment.push_back(std::move(perSentence));
|
||||||
|
}
|
||||||
batch->setGuidedAlignment(std::move(alignment));
|
batch->setGuidedAlignment(std::move(alignment));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -501,29 +510,14 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if(!guidedAlignment_.empty()) {
|
if(!guidedAlignment_.empty()) {
|
||||||
size_t oldTrgWords = back()->batchWidth();
|
|
||||||
size_t oldSize = size();
|
|
||||||
|
|
||||||
pos = 0;
|
pos = 0;
|
||||||
for(auto split : splits) {
|
for(auto split : splits) {
|
||||||
auto cb = std::static_pointer_cast<CorpusBatch>(split);
|
auto cb = std::static_pointer_cast<CorpusBatch>(split);
|
||||||
size_t srcWords = cb->front()->batchWidth();
|
|
||||||
size_t trgWords = cb->back()->batchWidth();
|
|
||||||
size_t dimBatch = cb->size();
|
size_t dimBatch = cb->size();
|
||||||
|
std::vector<WordAlignment> batchAlignment;
|
||||||
std::vector<float> aligns(srcWords * dimBatch * trgWords, 0.f);
|
for(size_t i = 0; i < dimBatch; ++i)
|
||||||
|
batchAlignment.push_back(std::move(guidedAlignment_[i + pos]));
|
||||||
for(size_t i = 0; i < dimBatch; ++i) {
|
cb->setGuidedAlignment(std::move(batchAlignment));
|
||||||
size_t bi = i + pos;
|
|
||||||
for(size_t sid = 0; sid < srcWords; ++sid) {
|
|
||||||
for(size_t tid = 0; tid < trgWords; ++tid) {
|
|
||||||
size_t bidx = sid * oldSize * oldTrgWords + bi * oldTrgWords + tid; // [sid, bi, tid]
|
|
||||||
size_t idx = sid * dimBatch * trgWords + i * trgWords + tid;
|
|
||||||
aligns[idx] = guidedAlignment_[bidx];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cb->setGuidedAlignment(std::move(aligns));
|
|
||||||
pos += dimBatch;
|
pos += dimBatch;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -556,15 +550,11 @@ public:
|
|||||||
return splits;
|
return splits;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<float>& getGuidedAlignment() const { return guidedAlignment_; } // [dimSrcWords, dimBatch, dimTrgWords] flattened
|
const std::vector<WordAlignment>& getGuidedAlignment() const { return guidedAlignment_; } // [dimSrcWords, dimBatch, dimTrgWords] flattened
|
||||||
void setGuidedAlignment(std::vector<float>&& aln) override {
|
void setGuidedAlignment(std::vector<WordAlignment>&& aln) override {
|
||||||
guidedAlignment_ = std::move(aln);
|
guidedAlignment_ = std::move(aln);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t locateInGuidedAlignments(size_t b, size_t s, size_t t) {
|
|
||||||
return ((s * size()) + b) * widthTrg() + t;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<float>& getDataWeights() { return dataWeights_; }
|
std::vector<float>& getDataWeights() { return dataWeights_; }
|
||||||
void setDataWeights(const std::vector<float>& weights) override {
|
void setDataWeights(const std::vector<float>& weights) override {
|
||||||
dataWeights_ = weights;
|
dataWeights_ = weights;
|
||||||
|
@ -77,7 +77,7 @@ public:
|
|||||||
|
|
||||||
size_t size() const override { return inputs_.front().shape()[0]; }
|
size_t size() const override { return inputs_.front().shape()[0]; }
|
||||||
|
|
||||||
void setGuidedAlignment(std::vector<float>&&) override {
|
void setGuidedAlignment(std::vector<WordAlignment>&&) override {
|
||||||
ABORT("Guided alignment in DataBatch is not implemented");
|
ABORT("Guided alignment in DataBatch is not implemented");
|
||||||
}
|
}
|
||||||
void setDataWeights(const std::vector<float>&) override {
|
void setDataWeights(const std::vector<float>&) override {
|
||||||
|
@ -286,6 +286,8 @@ Expr operator/(float a, Expr b) {
|
|||||||
/*********************************************************/
|
/*********************************************************/
|
||||||
|
|
||||||
Expr concatenate(const std::vector<Expr>& concats, int ax) {
|
Expr concatenate(const std::vector<Expr>& concats, int ax) {
|
||||||
|
if(concats.size() == 1)
|
||||||
|
return concats[0];
|
||||||
return Expression<ConcatenateNodeOp>(concats, ax);
|
return Expression<ConcatenateNodeOp>(concats, ax);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,62 +5,57 @@
|
|||||||
|
|
||||||
namespace marian {
|
namespace marian {
|
||||||
|
|
||||||
static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> /*graph*/,
|
static inline const std::tuple<std::vector<IndexType>, std::vector<float>>
|
||||||
|
guidedAlignmentToSparse(Ptr<data::CorpusBatch> batch) {
|
||||||
|
int trgWords = (int)batch->back()->batchWidth();
|
||||||
|
int dimBatch = (int)batch->size();
|
||||||
|
|
||||||
|
typedef std::tuple<size_t, float> BiPoint;
|
||||||
|
std::vector<BiPoint> byIndex;
|
||||||
|
byIndex.reserve(dimBatch * trgWords);
|
||||||
|
|
||||||
|
for(size_t b = 0; b < dimBatch; ++b) {
|
||||||
|
auto guidedAlignmentFwd = batch->getGuidedAlignment()[b]; // this copies
|
||||||
|
guidedAlignmentFwd.normalize(/*reverse=*/false); // normalize forward
|
||||||
|
for(size_t i = 0; i < guidedAlignmentFwd.size(); ++i) {
|
||||||
|
auto pFwd = guidedAlignmentFwd[i];
|
||||||
|
IndexType idx = (IndexType)(pFwd.srcPos * dimBatch * trgWords + b * trgWords + pFwd.tgtPos);
|
||||||
|
byIndex.push_back({idx, pFwd.prob});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(byIndex.begin(), byIndex.end(), [](const BiPoint& a, const BiPoint& b) { return std::get<0>(a) < std::get<0>(b); });
|
||||||
|
std::vector<IndexType> indices; std::vector<float> valuesFwd;
|
||||||
|
indices.reserve(byIndex.size()); valuesFwd.reserve(byIndex.size());
|
||||||
|
for(auto& p : byIndex) {
|
||||||
|
indices.push_back((IndexType)std::get<0>(p));
|
||||||
|
valuesFwd.push_back(std::get<1>(p));
|
||||||
|
}
|
||||||
|
|
||||||
|
return {indices, valuesFwd};
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> graph,
|
||||||
Ptr<data::CorpusBatch> batch,
|
Ptr<data::CorpusBatch> batch,
|
||||||
Ptr<Options> options,
|
Ptr<Options> options,
|
||||||
Expr attention) { // [beam depth=1, max src length, batch size, tgt length]
|
Expr attention) { // [beam depth=1, max src length, batch size, tgt length]
|
||||||
|
|
||||||
std::string guidedLossType = options->get<std::string>("guided-alignment-cost"); // @TODO: change "cost" to "loss"
|
std::string guidedLossType = options->get<std::string>("guided-alignment-cost"); // @TODO: change "cost" to "loss"
|
||||||
|
|
||||||
|
// We dropped support for other losses which are not possible to implement with sparse labels.
|
||||||
|
// They were most likely not used anyway.
|
||||||
|
ABORT_IF(guidedLossType != "ce", "Only alignment loss type 'ce' is supported");
|
||||||
|
|
||||||
float guidedLossWeight = options->get<float>("guided-alignment-weight");
|
float guidedLossWeight = options->get<float>("guided-alignment-weight");
|
||||||
|
|
||||||
const auto& shape = attention->shape(); // [beam depth=1, max src length, batch size, tgt length]
|
auto [indices, values] = guidedAlignmentToSparse(batch);
|
||||||
float epsilon = 1e-6f;
|
auto alignmentIndices = graph->indices(indices);
|
||||||
Expr alignmentLoss; // sum up loss over all attention/alignment positions
|
auto alignmentValues = graph->constant({(int)values.size()}, inits::fromVector(values));
|
||||||
size_t numLabels;
|
auto attentionAtAligned = cols(flatten(attention), alignmentIndices);
|
||||||
if(guidedLossType == "ce") {
|
|
||||||
// normalizedAlignment is multi-hot, but ce requires normalized probabilities, so need to normalize to P(s|t)
|
|
||||||
auto dimBatch = shape[-2];
|
|
||||||
auto dimTrgWords = shape[-1];
|
|
||||||
auto dimSrcWords = shape[-3];
|
|
||||||
ABORT_IF(shape[-4] != 1, "Guided alignments with beam??");
|
|
||||||
auto normalizedAlignment = batch->getGuidedAlignment(); // [dimSrcWords, dimBatch, dimTrgWords] flattened, matches shape of 'attention'
|
|
||||||
auto srcBatch = batch->front();
|
|
||||||
const auto& srcMask = srcBatch->mask();
|
|
||||||
ABORT_IF(shape.elements() != normalizedAlignment.size(), "Attention-matrix and alignment shapes differ??");
|
|
||||||
ABORT_IF(dimBatch != batch->size() || dimTrgWords != batch->widthTrg() || dimSrcWords != batch->width(), "Attention-matrix and batch shapes differ??");
|
|
||||||
auto locate = [=](size_t s, size_t b, size_t t) { return ((s * dimBatch) + b) * dimTrgWords + t; };
|
|
||||||
for (size_t b = 0; b < dimBatch; b++) {
|
|
||||||
for (size_t t = 0; t < dimTrgWords; t++) {
|
|
||||||
for (size_t s = 0; s < dimSrcWords; s++)
|
|
||||||
ABORT_IF(locate(s, b, t) != batch->locateInGuidedAlignments(b, s, t), "locate() and locateInGuidedAlignments() differ??");
|
|
||||||
// renormalize the alignment such that it sums up to 1
|
|
||||||
float sum = 0;
|
|
||||||
for (size_t s = 0; s < dimSrcWords; s++)
|
|
||||||
sum += srcMask[srcBatch->locate(b, s)] * normalizedAlignment[locate(s, b, t)]; // these values are 0 or 1
|
|
||||||
if (sum != 0 && sum != 1)
|
|
||||||
for (size_t s = 0; s < dimSrcWords; s++)
|
|
||||||
normalizedAlignment[locate(s, b, t)] /= sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto alignment = constant_like(attention, std::move(normalizedAlignment));
|
|
||||||
alignmentLoss = -sum(flatten(alignment * log(attention + epsilon)));
|
|
||||||
numLabels = batch->back()->batchWords();
|
|
||||||
ABORT_IF(numLabels > shape.elements() / shape[-3], "Num labels of guided alignment cost is off??");
|
|
||||||
} else {
|
|
||||||
auto alignment = constant_like(attention, batch->getGuidedAlignment());
|
|
||||||
if(guidedLossType == "mse")
|
|
||||||
alignmentLoss = sum(flatten(square(attention - alignment))) / 2.f;
|
|
||||||
else if(guidedLossType == "mult") // @TODO: I don't know what this criterion is for. Can we remove it?
|
|
||||||
alignmentLoss = -log(sum(flatten(attention * alignment)) + epsilon);
|
|
||||||
else
|
|
||||||
ABORT("Unknown alignment cost type: {}", guidedLossType);
|
|
||||||
// every position is a label as they should all agree
|
|
||||||
// @TODO: there should be positional masking here ... on the other hand, positions that are not
|
|
||||||
// in a sentence should always agree (both being 0). Lack of masking affects label count only which is
|
|
||||||
// probably negligible?
|
|
||||||
numLabels = shape.elements();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
float epsilon = 1e-6f;
|
||||||
|
Expr alignmentLoss = -sum(alignmentValues * log(attentionAtAligned + epsilon));
|
||||||
|
size_t numLabels = alignmentIndices->shape().elements();
|
||||||
|
|
||||||
// Create label node, also weigh by scalar so labels and cost are in the same domain.
|
// Create label node, also weigh by scalar so labels and cost are in the same domain.
|
||||||
// Fractional label counts are OK. But only if combined as "sum".
|
// Fractional label counts are OK. But only if combined as "sum".
|
||||||
// @TODO: It is ugly to check the multi-loss type here, but doing this right requires
|
// @TODO: It is ugly to check the multi-loss type here, but doing this right requires
|
||||||
|
Loading…
Reference in New Issue
Block a user