Mirror of https://github.com/marian-nmt/marian.git
factor mask

Commit 6b2b7d1188, parent f41acb1aa8
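Summary, as read from the diff below: in Logits::getFactoredLogits, when a shortlist is active the factor masks for group g are no longer only baked in as a constant; a new lambda graph node recomputes them during the forward pass from the shortlist's per-beam indices (shortlist->getIndicesExpr), using a new helper Logits::getFactorMasks2, which is also declared in the header.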
@@ -101,7 +101,28 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
      factorMasks = constant(getFactorMasks(g, std::vector<WordIndex>()));
    }
    else {
      factorMasks = constant(getFactorMasks(g, shortlist->indices()));
      //std::cerr << "sel=" << sel->shape() << std::endl;
      int currBeamSize = sel->shape()[0];
      int batchSize = sel->shape()[2];

      auto forward = [this, g, currBeamSize, batchSize](Expr out, const std::vector<Expr>& inputs) {
        std::vector<WordIndex> indices;
        Expr lastIndices = inputs[0];
        lastIndices->val()->get(indices);
        std::vector<float> masks = getFactorMasks2(batchSize, currBeamSize, g, indices);
        out->val()->set(masks);
      };

      Expr lastIndices = shortlist->getIndicesExpr(batchSize, currBeamSize);
      //std::cerr << "lastIndices=" << lastIndices->shape() << std::endl;
      factorMasks = lambda({lastIndices}, lastIndices->shape(), Type::float32, forward);
      //std::cerr << "factorMasks.1=" << factorMasks->shape() << std::endl;
      factorMasks = transpose(factorMasks, {1, 0, 2});
      //std::cerr << "factorMasks.2=" << factorMasks->shape() << std::endl;

      const Shape &s = factorMasks->shape();
      factorMasks = reshape(factorMasks, {s[0], 1, s[1], s[2]});
      //std::cerr << "factorMasks.3=" << factorMasks->shape() << std::endl;
    }
    factorMaxima = cast(factorMaxima, sel->value_type());
    factorMasks = cast(factorMasks, sel->value_type());
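In the shortlist branch above, the mask values depend on indices that are only materialized when the graph runs, so the commit routes them through a lambda node: its forward callback copies the shortlist indices out of inputs[0] and writes one 0/1 value per (beam, shortlisted word) into the output tensor, which is then transposed and reshaped with a singleton axis to line up with sel. A plain-C++ sketch of what that callback computes, outside Marian's Expr/Tensor API (hasFactorGroup here is a stand-in for factoredVocab_->lemmaHasFactorGroup):

    // Illustrative sketch only; Marian's lambda(...) operates on graph Exprs
    // and Tensors, not std::vector. What is known at graph-build time (beam
    // size, shortlist width) is captured; the indices arrive only at forward
    // time, as the callback's input.
    #include <cstddef>
    #include <functional>
    #include <vector>

    using WordIndex = unsigned int;
    using Forward = std::function<void(std::vector<float>& /*out*/,
                                       const std::vector<WordIndex>& /*indices*/)>;

    Forward makeMaskForward(size_t currBeamSize, size_t n, // n = words per beam
                            std::function<bool(WordIndex)> hasFactorGroup) {
      return [currBeamSize, n, hasFactorGroup](std::vector<float>& out,
                                               const std::vector<WordIndex>& indices) {
        out.clear();
        out.reserve(currBeamSize * n);
        for (size_t beam = 0; beam < currBeamSize; ++beam) // beam-major layout,
          for (size_t i = 0; i < n; ++i)                   // as in getFactorMasks2
            out.push_back(hasFactorGroup(indices[beam * n + i]) ? 1.f : 0.f);
      };
    }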
@@ -219,6 +240,27 @@ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<
  return res;
}

std::vector<float> Logits::getFactorMasks2(int batchSize, int currBeamSize, size_t factorGroup, const std::vector<WordIndex>& indices)
const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
  size_t n = indices.empty()
      ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first)
      : indices.size() / currBeamSize;
  std::vector<float> res;
  res.reserve(currBeamSize * n);

  // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this
  // into FactoredVocab
  for (size_t currBeam = 0; currBeam < currBeamSize; ++currBeam) {
    for (size_t i = 0; i < n; i++) {
      size_t idx = currBeam * n + i;
      size_t lemma = indices.empty() ? i : (indices[idx] - factoredVocab_->getGroupRange(0).first);
      res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
    }
  }
  return res;
}

Logits Logits::applyUnaryFunction(
    const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
  std::vector<Ptr<RationalLoss>> newLogits;
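A small worked example of the layout getFactorMasks2 returns (hypothetical indices; it is a private helper, so the call is shown as it would look from inside Logits): with currBeamSize = 2 and six shortlist entries, n = indices.size() / currBeamSize = 3, and the result is beam-major, entries 0..2 for beam 0 and 3..5 for beam 1. Each raw word index is mapped back to a lemma by subtracting getGroupRange(0).first; note that batchSize is accepted but not used in the loop.

    // Hypothetical values, for illustration only:
    // beam 0 shortlists word indices {7, 12, 40}, beam 1 shortlists {7, 9, 13}.
    std::vector<WordIndex> indices = {7, 12, 40, 7, 9, 13};
    std::vector<float> masks = getFactorMasks2(/*batchSize=*/1, /*currBeamSize=*/2,
                                               /*factorGroup=*/1, indices);
    // masks.size() == 6; masks[currBeam * 3 + i] is 1.0 iff the lemma behind
    // indices[currBeam * 3 + i] carries factor group 1.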
@@ -80,6 +80,7 @@ private:
  } // actually the same as constant(data) for this data type
  std::vector<float> getFactorMasks(size_t factorGroup,
                                    const std::vector<WordIndex>& indices) const;
  std::vector<float> getFactorMasks2(int batchSize, int currBeamSize, size_t factorGroup, const std::vector<WordIndex>& indices) const;

private:
  // members