Mirror of https://github.com/marian-nmt/marian.git (synced 2024-10-05 19:17:10 +03:00)
Merged PR 26476: Sanitize guided-alignment with case-augmentation (still somewhat wonky)
This fixes the blow-ups when using case-augmentation with guided-alignment. However, using this particular combination is still not recommended; results will be unreliable.
This commit is contained in:
parent 4f145c450f, commit 9ad5203ca2
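Before the diffs, a minimal sketch of the guard this change introduces. The types below are simplified stand-ins for illustration, not marian's actual API; only the markAltered/isAltered flow mirrors the commit. The point: once case augmentation rewrites a line, the externally supplied alignment indices may no longer match the new tokens, so the tuple is flagged and its alignment replaced with an empty placeholder.

    #include <cstddef>
    #include <iostream>
    #include <utility>
    #include <vector>

    // Hypothetical stand-ins for marian's types, for illustration only.
    using WordAlignment = std::vector<std::pair<size_t, size_t>>; // (src, tgt) index pairs

    struct SentenceTuple {
      WordAlignment alignment;
      bool altered_ = false;
      void markAltered() { altered_ = true; }
      bool isAltered() const { return altered_; }
    };

    int main() {
      SentenceTuple tup;
      tup.alignment = {{0, 0}, {1, 1}};

      // Case augmentation may rewrite the line (e.g. "Hello world" -> "HELLO world"),
      // which can change tokenization and invalidate the alignment indices above.
      bool altered = true; // would be set by preprocessLine(..., /*out=*/altered)
      if(altered)
        tup.markAltered();

      std::vector<WordAlignment> aligns;
      if(tup.isAltered())
        aligns.push_back(WordAlignment()); // empty placeholder keeps per-sentence indexing intact
      else
        aligns.push_back(std::move(tup.alignment));

      std::cout << "alignment points kept: " << aligns[0].size() << "\n"; // prints 0 here
    }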
@@ -128,7 +128,7 @@ SentenceTuple Corpus::next() {
       size_t vocabId = i - shift;
       bool altered;
       preprocessLine(fields[i], vocabId, curId, /*out=*/altered);
-      if (altered)
+      if(altered)
         tup.markAltered();
       addWordsToSentenceTuple(fields[i], vocabId, tup);
     }
@@ -476,7 +476,10 @@ void CorpusBase::addAlignmentsToBatch(Ptr<CorpusBatch> batch,
     // If the batch vector is altered within marian by, for example, case augmentation,
     // the guided alignments we received for this tuple cease to be valid.
     // Hence skip setting alignments for that sentence tuple..
-    if (!batchVector[b].isAltered()) {
+    if (batchVector[b].isAltered()) {
+      LOG_ONCE(info, "Using guided-alignment with case-augmentation is not recommended and can result in unexpected behavior");
+      aligns.push_back(WordAlignment());
+    } else {
       aligns.push_back(std::move(batchVector[b].getAlignment()));
     }
   }
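Note the shape of this fix: previously an altered tuple was skipped outright, so aligns could end up with fewer entries than the batch has sentences; now an empty WordAlignment() placeholder is pushed instead, keeping one alignment entry per sentence so downstream indexing stays consistent. LOG_ONCE keeps the warning from being repeated for every affected sentence.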
@@ -56,12 +56,16 @@ public:
    * @brief Returns whether this Tuple was altered or augmented from what
    * was provided to Marian in input.
    */
-  bool isAltered() const { return altered_; }
+  bool isAltered() const {
+    return altered_;
+  }
 
   /**
    * @brief Mark that this Tuple was internally altered or augmented by Marian
    */
-  void markAltered() { altered_ = true; }
+  void markAltered() {
+    altered_ = true;
+  }
 
   /**
   * @brief Adds a new sentence at the end of the tuple.
@@ -64,6 +64,14 @@ Expr ExpressionGraph::add(Expr node) {
   }
 }
 
+/**
+ * Removes the node from the set of roots (will not be initialized during back propagation)
+ * @param node a pointer to a expression node
+ */
+void ExpressionGraph::removeAsRoot(Expr node) {
+  topNodes_.erase(node);
+}
+
 // Call on every checkpoint in backwards order
 void createSubtape(Expr node) {
   auto subtape = New<std::list<Expr>>();
@@ -676,6 +676,12 @@ public:
    * @param node a pointer to a expression node
    */
   Expr add(Expr node);
 
+  /**
+   * Removes the node from the set of roots (will not be initialized during back propagation)
+   * @param node a pointer to a expression node
+   */
+  void removeAsRoot(Expr node);
+
   /**
    * Allocate memory for the forward pass of the given node.
@@ -27,6 +27,11 @@ Expr checkpoint(Expr a) {
   return a;
 }
 
+Expr removeAsRoot(Expr a) {
+  a->graph()->removeAsRoot(a); // ugly, hence why hidden here
+  return a;
+}
+
 Expr lambda(const std::vector<Expr>& nodes, Shape shape, Type type,
             LambdaNodeFunctor fwd, size_t hash) {
   return Expression<LambdaNodeOp>(nodes, shape, type, fwd, hash);
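The free-function wrapper follows the pattern of checkpoint(Expr) directly above it: it forwards to the graph method and returns the node unchanged, which allows inline chaining such as removeAsRoot(stopGradient(attention)) in the guided-alignment loss further down. The underlying ExpressionGraph::removeAsRoot simply erases the node from topNodes_, so calling it on a node that is not a root is a harmless no-op, as the header comment below notes.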
@@ -16,6 +16,11 @@ Expr debug(Expr a, const std::string& message = "");
  */
 Expr checkpoint(Expr a);
 
+/**
+ * Removes the node from the set of root nodes, no-op otherwise
+ */
+Expr removeAsRoot(Expr node);
+
 typedef Expr(ActivationFunction)(Expr); ///< ActivationFunction has signature Expr(Expr)
 
 /**
@@ -26,7 +26,8 @@ guidedAlignmentToSparse(Ptr<data::CorpusBatch> batch) {
 
   std::sort(byIndex.begin(), byIndex.end(), [](const BiPoint& a, const BiPoint& b) { return std::get<0>(a) < std::get<0>(b); });
   std::vector<IndexType> indices; std::vector<float> valuesFwd;
-  indices.reserve(byIndex.size()); valuesFwd.reserve(byIndex.size());
+  indices.reserve(byIndex.size());
+  valuesFwd.reserve(byIndex.size());
   for(auto& p : byIndex) {
     indices.push_back((IndexType)std::get<0>(p));
     valuesFwd.push_back(std::get<1>(p));
@@ -40,28 +41,33 @@ static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> graph,
                                                Ptr<Options> options,
                                                Expr attention) { // [beam depth=1, max src length, batch size, tgt length]
   std::string guidedLossType = options->get<std::string>("guided-alignment-cost"); // @TODO: change "cost" to "loss"
 
+  // @TODO: It is ugly to check the multi-loss type here, but doing this right requires
+  // a substantial rewrite of the multi-loss architecture, which is planned anyways.
+  std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
+
   // We dropped support for other losses which are not possible to implement with sparse labels.
   // They were most likely not used anyway.
   ABORT_IF(guidedLossType != "ce", "Only alignment loss type 'ce' is supported");
 
   float guidedLossWeight = options->get<float>("guided-alignment-weight");
 
-  auto [indices, values] = guidedAlignmentToSparse(batch);
-  auto alignmentIndices = graph->indices(indices);
-  auto alignmentValues = graph->constant({(int)values.size()}, inits::fromVector(values));
-  auto attentionAtAligned = cols(flatten(attention), alignmentIndices);
-
-  float epsilon = 1e-6f;
-  Expr alignmentLoss = -sum(cast(alignmentValues * log(attentionAtAligned + epsilon), Type::float32));
-  size_t numLabels = alignmentIndices->shape().elements();
+  const auto& [indices, values] = guidedAlignmentToSparse(batch);
+
+  Expr alignmentLoss;
+  size_t numLabels = indices.size(); // can be zero
+  if(indices.empty()) {
+    removeAsRoot(stopGradient(attention)); // unused, hence make sure we don't polute the backwards operations
+    alignmentLoss = graph->zeros({1});
+    numLabels = multiLossType == "sum" ? 0 : 1;
+  } else {
+    float epsilon = 1e-6f;
+    auto alignmentIndices = graph->indices(indices);
+    auto alignmentValues = graph->constant({(int)values.size()}, inits::fromVector(values));
+    auto attentionAtAligned = cols(flatten(attention), alignmentIndices);
+    alignmentLoss = -sum(cast(alignmentValues * log(attentionAtAligned + epsilon), Type::float32));
+  }
   // Create label node, also weigh by scalar so labels and cost are in the same domain.
   // Fractional label counts are OK. But only if combined as "sum".
-  // @TODO: It is ugly to check the multi-loss type here, but doing this right requires
-  // a substantial rewrite of the multi-loss architecture, which is planned anyways.
-  std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
   if (multiLossType == "sum") // sum of sums
     return RationalLoss(guidedLossWeight * alignmentLoss, guidedLossWeight * numLabels);
   else
     return RationalLoss(guidedLossWeight * alignmentLoss, (float)numLabels);
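To see why the empty-indices branch sets numLabels the way it does: with multi-loss-type "sum" the per-loss label counts are added up, so a zero contributes nothing; with a mean-style combination the loss is normalized by its label count, and zero labels would mean dividing zero by zero. A minimal sketch of that arithmetic follows; RationalLoss here is a simplified stand-in, and the assumption that non-"sum" types normalize by the label count follows the ternary in the diff, not marian's actual multi-loss code.

    #include <cstdio>

    // Simplified stand-in pairing a loss value with the label count it is normalized by.
    struct RationalLoss { float loss; float labels; };

    float combine(const RationalLoss& l, bool isSum) {
      // "sum"-style multi-loss adds raw losses; other types normalize by labels (assumption).
      return isSum ? l.loss : l.loss / l.labels;
    }

    int main() {
      bool isSum = false; // e.g. --multi-loss-type mean

      // Mirrors: numLabels = multiLossType == "sum" ? 0 : 1;
      float numLabels = isSum ? 0.f : 1.f;
      RationalLoss align{0.f, numLabels}; // zero loss, as from graph->zeros({1})

      printf("combined: %f\n", combine(align, isSum)); // 0/1 = 0, not 0/0 = NaN
    }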