factor mask

This commit is contained in:
Hieu Hoang 2021-04-29 00:44:30 -07:00
parent f41acb1aa8
commit 6b2b7d1188
2 changed files with 44 additions and 1 deletions

View File

@ -101,7 +101,28 @@ Expr Logits::getFactoredLogits(size_t groupIndex,
factorMasks = constant(getFactorMasks(g, std::vector<WordIndex>()));
}
else {
factorMasks = constant(getFactorMasks(g, shortlist->indices()));
//std::cerr << "sel=" << sel->shape() << std::endl;
int currBeamSize = sel->shape()[0];
int batchSize = sel->shape()[2];
auto forward = [this, g, currBeamSize, batchSize](Expr out, const std::vector<Expr>& inputs) {
std::vector<WordIndex> indices;
Expr lastIndices = inputs[0];
lastIndices->val()->get(indices);
std::vector<float> masks = getFactorMasks2(batchSize, currBeamSize, g, indices);
out->val()->set(masks);
};
Expr lastIndices = shortlist->getIndicesExpr(batchSize, currBeamSize);
//std::cerr << "lastIndices=" << lastIndices->shape() << std::endl;
factorMasks = lambda({lastIndices}, lastIndices->shape(), Type::float32, forward);
//std::cerr << "factorMasks.1=" << factorMasks->shape() << std::endl;
factorMasks = transpose(factorMasks, {1, 0, 2});
//std::cerr << "factorMasks.2=" << factorMasks->shape() << std::endl;
const Shape &s = factorMasks->shape();
factorMasks = reshape(factorMasks, {s[0], 1, s[1], s[2]});
//std::cerr << "factorMasks.3=" << factorMasks->shape() << std::endl;
}
factorMaxima = cast(factorMaxima, sel->value_type());
factorMasks = cast(factorMasks, sel->value_type());
@ -219,6 +240,27 @@ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<
return res;
}
std::vector<float> Logits::getFactorMasks2(int batchSize, int currBeamSize, size_t factorGroup, const std::vector<WordIndex>& indices)
const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
size_t n
= indices.empty()
? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first)
: indices.size() / currBeamSize;
std::vector<float> res;
res.reserve(currBeamSize * n);
// @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this
// into FactoredVocab
for (size_t currBeam = 0; currBeam < currBeamSize; ++currBeam) {
for(size_t i = 0; i < n; i++) {
size_t idx = currBeam * n + i;
size_t lemma = indices.empty() ? i : (indices[idx] - factoredVocab_->getGroupRange(0).first);
res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
}
}
return res;
}
Logits Logits::applyUnaryFunction(
const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
std::vector<Ptr<RationalLoss>> newLogits;

View File

@ -80,6 +80,7 @@ private:
} // actually the same as constant(data) for this data type
std::vector<float> getFactorMasks(size_t factorGroup,
const std::vector<WordIndex>& indices) const;
std::vector<float> getFactorMasks2(int batchSize, int currBeamSize, size_t factorGroup, const std::vector<WordIndex>& indices) const;
private:
// members