diff --git a/src/layers/logits.cpp b/src/layers/logits.cpp
new file mode 100644
index 00000000..cd2203e4
--- /dev/null
+++ b/src/layers/logits.cpp
@@ -0,0 +1,212 @@
+#include "logits.h"
+#include "loss.h"
+#include "data/factored_vocab.h"
+#include "rnn/types.h" // for State::select()
+
+namespace marian {
+  Logits::Logits(Expr logits) : Logits(New<RationalLoss>(logits, nullptr)) {} // single-output constructor from Expr only (RationalLoss has no count)
+
+  Ptr<ExpressionGraph> Logits::graph() const {
+    ABORT_IF(logits_.empty(), "Empty logits object??");
+    return logits_.front()->loss()->graph();
+  }
+
+  // This function assumes that the object holds one or more factor logits.
+  // It applies the supplied loss function to each, and then returns the aggregate loss over all factors.
+  Expr Logits::applyLossFunction(const Words& labels, const std::function<Expr(Expr /*logits*/, Expr /*indices*/)>& lossFn) const {
+    LOG_ONCE(info, "[logits] Applying loss function for {} factor(s)", logits_.size());
+    ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+
+    auto firstLogits = logits_.front()->loss();
+    ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
+             "Labels not matching logits shape ({} != {}, {})??",
+             labels.size() * firstLogits->shape()[-1],
+             firstLogits->shape().elements(),
+             firstLogits->shape());
+
+    // base case (no factors)
+    if (!factoredVocab_) {
+      ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
+      return lossFn(firstLogits, indices(toWordIndexVector(labels)));
+    }
+
+    auto numGroups = factoredVocab_->getNumGroups();
+
+    // split labels into individual factor labels
+    auto allMaskedFactoredLabels = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened]
+
+    //Expr indices = this->indices(toWordIndexVector(labels));
+    // accumulate all CEs for all words that have the factor
+    // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors.
+    Expr loss;
+    for (size_t g = 0; g < numGroups; g++) {
+      if (!logits_[g])
+        continue; // empty factor --@TODO: use an array of indices of non-empty logits_[]
+      const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask)
+      auto factorIndices = indices(maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
+      auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
+      auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
+      // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
+      auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
+      if (loss)
+        factorLoss = cast(factorLoss, loss->value_type());
+      factorLoss = factorLoss * cast(reshape(factorMask, factorLoss->shape()), factorLoss->value_type()); // mask out factor for words that do not have that factor
+      loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1]
+    }
+    return loss;
+  }
+
+  // This function assumes this object holds a single factor that represents a rational loss (with count).
+  //Ptr<RationalLoss> Logits::getRationalLoss() const {
+  //  ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on multi-factor outputs");
+  //  ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational loss without count");
+  //  return logits_.front();
+  //}
+
+  // get logits for one factor group
+  // For groupIndex == 0, the function also requires the shortlist if there is one.
+  Expr Logits::getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist /*= nullptr*/, const std::vector<IndexType>& hypIndices /*= {}*/, size_t beamSize /*= 0*/) const {
+    ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+
+    auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
+
+    // normalize for decoding:
+    //  - all secondary factors: subtract their max
+    //  - lemma: add all maxes of applicable factors
+    if (groupIndex > 0) {
+      sel = sel - max(sel, -1);
+    }
+    else {
+      auto numGroups = getNumFactorGroups();
+      for (size_t g = 1; g < numGroups; g++) {
+        auto factorMaxima = max(logits_[g]->loss(), -1); // we cast since loss is likely ce-loss which has type float32
+        auto factorMasks = constant(getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>()));
+        sel = sel + cast(factorMaxima, sel->value_type()) * cast(factorMasks, sel->value_type()); // those lemmas that don't have a factor get multiplied with 0
+      }
+    }
+
+    // if selIdx are given, then we must reshuffle accordingly
+    if (!hypIndices.empty()) // use the same function that shuffles decoder state
+      sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false);
+
+    return sel;
+  }
+
+  // used for breakDown() only
+  // Index is flattened
+  Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const {
+    ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+    return logits_[groupIndex]->loss()->val();
+  }
+
+  // This function assumes that the object holds one or more factor logits, which are summed up
+  // into output-vocab logits according to the factored model (with correct normalization of factors).
+  // This is infeasible for realistic factor sets, and therefore only implemented for 1 factor.
+  // @TODO: remove altogether
+  Expr Logits::getLogits() const {
+    ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+    if (!factoredVocab_) {
+      ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
+      return getFactoredLogits(0);
+    }
+
+#ifdef FACTOR_FULL_EXPANSION
+    // compute normalized factor log probs
+    std::vector<Expr> logProbs(logits_.size());
+    for (size_t g = 0; g < logits_.size(); g++)
+      logProbs[g] = logsoftmax(logits_[g]->loss());
+    auto y = concatenate(logProbs, /*axis=*/ -1);
+
+    // sum up the unit logits across factors for each target word
+    auto graph = y->graph();
+    auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U]
+    y = dot_csr(
+        y, // [B x U]
+        factorMatrix.shape,
+        graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights)),
+        graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32),
+        graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32),
+        /*transB=*/ true); // -> [B x V]
+
+    // mask out gaps
+    auto gapLogMask = factoredVocab_->getGapLogMask(); // [V]
+    y = y + graph->constant({(int)gapLogMask.size()}, inits::fromVector(gapLogMask));
+
+    return y;
+#else
+    ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible
+#endif
+  }
+
+  void Logits::MaskedFactorIndices::push_back(size_t factorIndex) {
+    bool isValid = FactoredVocab::isFactorValid(factorIndex);
+    indices.push_back(isValid ? (WordIndex)factorIndex : 0);
+    masks.push_back((float)isValid);
+  }
+
+  std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words) const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices
+    if (!factoredVocab_) {
+      ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
+      return {MaskedFactorIndices(words)};
+    }
+    auto numGroups = factoredVocab_->getNumGroups();
+    std::vector<MaskedFactorIndices> res(numGroups);
+    for (size_t g = 0; g < numGroups; g++) {
+      auto& resg = res[g];
+      resg.reserve(words.size());
+      for (const auto& word : words)
+        resg.push_back(factoredVocab_->getFactor(word, g));
+    }
+    return res;
+  }
+
+  //// use first factor of each word to determine whether it has a specific factor
+  //std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 for words that do have this factor; else 0
+  //  std::vector<float> res;
+  //  res.reserve(words.size());
+  //  for (const auto& word : words) {
+  //    auto lemma = factoredVocab_->getFactor(word, 0);
+  //    res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
+  //  }
+  //  return res;
+  //}
+
+  // return a vector of 1 or 0 indicating for each lemma whether it has a specific factor
+  // If 'indices' is given, then return the masks for the indices; otherwise for all lemmas
+  std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
+    size_t n = indices.empty() ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) : indices.size();
+    std::vector<float> res;
+    res.reserve(n);
+    // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i]] of float; then move this into FactoredVocab
+    for (size_t i = 0; i < n; i++) {
+      auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first);
+      res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
+    }
+    return res;
+  }
+
+  Logits Logits::applyUnaryFunction(const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
+    std::vector<Ptr<RationalLoss>> newLogits;
+    for (const auto& l : logits_)
+      newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count()));
+    return Logits(std::move(newLogits), factoredVocab_);
+  }
+
+  Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const {
+    std::vector<Ptr<RationalLoss>> newLogits;
+    bool first = true;
+    for (const auto& l : logits_) {
+      newLogits.emplace_back(New<RationalLoss>((first ? f1 : fother)(l->loss()), l->count())); // f1 for first, fother for all others
+      first = false;
+    }
+    return Logits(std::move(newLogits), factoredVocab_);
+  }
+
+  // @TODO: code dup with above; we can merge it into applyToRationalLoss()
+  Logits Logits::withCounts(const Expr& count) const { // create new Logits with 'count' implanted into all logits_
+    std::vector<Ptr<RationalLoss>> newLogits;
+    for (const auto& l : logits_)
+      newLogits.emplace_back(New<RationalLoss>(l->loss(), count));
+    return Logits(std::move(newLogits), factoredVocab_);
+  }
+}
\ No newline at end of file
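
Review note: for readers new to factored vocabularies, the core of `applyLossFunction` is a masked sum of per-factor cross-entropies. Below is a minimal standalone sketch of that aggregation in plain C++ (no marian types; `MaskedFactorLabels`, `aggregateFactorLoss`, and the toy CE values are invented for illustration). A word that lacks a factor group looks up index 0 in that group and is then zeroed out by the mask, exactly as the `factorMask` multiplication does in the diff.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Per-factor labels for a flattened batch: an index into the group's factor
// vocabulary plus a 0/1 mask saying whether the word has this factor at all.
struct MaskedFactorLabels {
  std::vector<size_t> indices; // factor-label index, or 0 if factor does not apply
  std::vector<float> masks;    // 1.0 if the word has this factor, else 0.0
};

// Sum per-word loss over all factor groups: look up the CE of the labeled
// factor, zero it where the factor does not apply, accumulate across groups.
std::vector<float> aggregateFactorLoss(
    const std::vector<std::vector<std::vector<float>>>& ce, // [group][word][factor] CE values
    const std::vector<MaskedFactorLabels>& labels) {        // [group]
  const size_t numWords = labels.front().indices.size();
  std::vector<float> loss(numWords, 0.0f);
  for (size_t g = 0; g < labels.size(); g++)
    for (size_t i = 0; i < numWords; i++)
      loss[i] += labels[g].masks[i] * ce[g][i][labels[g].indices[i]];
  return loss;
}

int main() {
  // two factor groups, two words; word 1 does not carry the second factor
  std::vector<std::vector<std::vector<float>>> ce = {
      {{0.1f, 2.0f}, {1.5f, 0.3f}},  // lemma CE per candidate
      {{0.7f, 0.2f}, {9.9f, 9.9f}}}; // secondary-factor CE per candidate
  std::vector<MaskedFactorLabels> labels = {
      {{0, 1}, {1.0f, 1.0f}},  // lemma applies to both words
      {{1, 0}, {1.0f, 0.0f}}}; // second factor masked out for word 1
  for (float l : aggregateFactorLoss(ce, labels))
    std::cout << l << "\n"; // word 0: 0.1 + 0.2 = 0.3; word 1: 0.3 + 0 = 0.3
}
```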
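
Similarly, the decoding-time normalization in `getFactoredLogits` can be summarized without the graph machinery: every secondary group is shifted by its max, and the lemma group absorbs those maxes for exactly the lemmas that carry the group, so adding a shifted factor score back onto a lemma score never double-counts the constant. A sketch of the lemma-side half under the same caveats (`normalizeLemmaLogits` is an invented name; the real code operates on batched expressions and an optional shortlist):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Add each secondary group's max to the lemmas that actually have that group
// (mask 1), leaving other lemmas untouched -- the groupIndex == 0 branch above.
std::vector<float> normalizeLemmaLogits(
    std::vector<float> lemmaLogits,                       // [numLemmas]
    const std::vector<std::vector<float>>& factorLogits,  // per secondary group
    const std::vector<std::vector<float>>& factorMasks) { // [group][lemma] in {0,1}
  for (size_t g = 0; g < factorLogits.size(); g++) {
    float m = *std::max_element(factorLogits[g].begin(), factorLogits[g].end());
    for (size_t l = 0; l < lemmaLogits.size(); l++)
      lemmaLogits[l] += m * factorMasks[g][l]; // lemmas without the factor add 0
  }
  return lemmaLogits; // the secondary groups themselves return (logits - max)
}
```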