fix after merge

Frank Seide 2019-02-04 17:32:53 -08:00
commit de69efea79
11 changed files with 46 additions and 42 deletions

View File

@@ -15,19 +15,19 @@ namespace data {
 class Shortlist {
 private:
   std::vector<WordIndex> indices_;
-  std::vector<WordIndex> mappedIndices_;
+  Words mappedIndices_;
   std::vector<WordIndex> reverseMap_;

 public:
   Shortlist(const std::vector<WordIndex>& indices,
-            const std::vector<WordIndex>& mappedIndices,
+            const Words& mappedIndices,
             const std::vector<WordIndex>& reverseMap)
       : indices_(indices),
         mappedIndices_(mappedIndices),
         reverseMap_(reverseMap) {}

-  std::vector<WordIndex>& indices() { return indices_; }
-  std::vector<WordIndex>& mappedIndices() { return mappedIndices_; }
+  const std::vector<WordIndex>& indices() const { return indices_; }
+  const Words& mappedIndices() const { return mappedIndices_; }
   WordIndex reverseMap(WordIndex idx) { return reverseMap_[idx]; }
 };
@@ -103,10 +103,10 @@ public:
       reverseMap.push_back(idx[i]);
     }
-    std::vector<WordIndex> mapped;
+    Words mapped;
     for(auto i : trgBatch->data()) {
       // mapped positions for cross-entropy
-      mapped.push_back(pos[i.toWordIndex()]);
+      mapped.push_back(Word::fromWordIndex(pos[i.toWordIndex()]));
     }
     return New<Shortlist>(idx, mapped, reverseMap);
@@ -263,7 +263,7 @@ public:
       reverseMap.push_back(idx[i]);
     }
-    std::vector<WordIndex> mapped;
+    Words mapped;
     // for(auto i : trgBatch->data()) {
     // mapped positions for cross-entropy
     // mapped.push_back(pos[i]);
@@ -289,7 +289,7 @@ public:
   }

   Ptr<Shortlist> generate(Ptr<data::CorpusBatch> /*batch*/) override {
-    std::vector<WordIndex> tmp;
+    Words tmp;
     return New<Shortlist>(idx_, tmp, reverseIdx_);
   }
 };
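
Editor's note: the hunks above replace the untyped std::vector<WordIndex> with the typed Words container. As a reading aid, here is a minimal self-contained sketch of what the WordIndex/Word/Words trio and the toWordIndexVector() helper look like; the shapes are assumed from their usage in this diff (the real definitions live in data/types.h), not copied from it.

    #include <cstdint>
    #include <vector>

    // Sketch only: shapes assumed from usage in this diff, not marian's code.
    typedef uint32_t WordIndex;      // untyped raw vocabulary index

    class Word {                     // typed vocabulary token
      WordIndex wordId_;
      explicit Word(WordIndex wordId) : wordId_(wordId) {}
    public:
      static Word fromWordIndex(WordIndex wordId) { return Word(wordId); }
      WordIndex toWordIndex() const { return wordId_; }
    };

    typedef std::vector<Word> Words; // a sequence of typed tokens

    // Unwrap typed Words into raw indices, e.g. for graph->indices(...).
    static std::vector<WordIndex> toWordIndexVector(const Words& words) {
      std::vector<WordIndex> indices;
      indices.reserve(words.size());
      for (const Word& word : words)
        indices.push_back(word.toWordIndex());
      return indices;
    }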

View File

@@ -180,7 +180,11 @@ namespace marian {
   // This function assumes that the object holds one or more factor logits.
   // It applies the supplied loss function to each, and then returns the aggregate loss over all factors.
-  Expr Logits::applyLossFunction(Expr indices, const std::function<Expr(Expr/*logits*/, Expr/*indices*/)>& lossFn) const {
+  Expr Logits::applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/, Expr/*indices*/)>& lossFn) const {
+    ABORT_IF(logits_.empty(), "Empty logits object??");
+    auto graph = logits_.front()->loss()->graph();
+    Expr indices = graph->indices(toWordIndexVector(labels));
     LOG_ONCE(info, "[logits] applyLossFunction() for {} factors", logits_.size());
     ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
     if (!embeddingFactorMapping_) {
@@ -190,7 +194,6 @@ namespace marian {
     // accumulate all CEs for all words that have the factor
     // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors.
-    auto graph = indices->graph();
     Expr loss;
     auto numGroups = embeddingFactorMapping_->getNumGroups();
     for (size_t g = 0; g < numGroups; g++) {
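
Editor's note: the signature change above moves the Words-to-indices conversion inside applyLossFunction(), so callers hand over typed labels and the graph constant is built once, centrally. A toy stand-alone analogy of that refactoring pattern (stand-in types and names, not marian's API):

    #include <functional>
    #include <iostream>
    #include <vector>

    // Stand-in types (not marian's): a typed token wrapping a raw index.
    struct Word { unsigned int id; };
    using Words = std::vector<Word>;

    // Before this commit the caller passed pre-converted raw indices; now the
    // unwrap happens once, inside the loss entry point.
    float applyLoss(const Words& labels,
                    const std::function<float(const std::vector<unsigned int>&)>& lossFn) {
      std::vector<unsigned int> indices;
      indices.reserve(labels.size());
      for (Word w : labels) indices.push_back(w.id);  // central conversion
      return lossFn(indices);
    }

    int main() {
      Words labels = {{3}, {7}};
      float loss = applyLoss(labels, [](const std::vector<unsigned int>& idx) {
        float s = 0.f;
        for (unsigned int i : idx) s += (float)i;   // dummy "loss" for the demo
        return s;
      });
      std::cout << loss << "\n";                    // prints 10
    }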

View File

@@ -75,7 +75,7 @@ public:
       : logits_(std::move(logits)), embeddingFactorMapping_(embeddingFactorMapping) {}
   Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
   Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
-  Expr applyLossFunction(Expr indices, const std::function<Expr(Expr/*logits*/,Expr/*indices*/)>& lossFn) const;
+  Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/,Expr/*indices*/)>& lossFn) const;
   void assign(const Logits& other) {
     //ABORT_IF(!empty() && getNumFactors() != other.getNumFactors(),
     //  "Logits assignment cannot change number of factors");

View File

@@ -2,6 +2,7 @@
 #include "graph/expression_operators.h"
 #include "layers/generic.h" // for Logits (Frank's factor hack)
+#include "data/types.h"

 namespace marian {
@@ -270,7 +271,7 @@ class LabelwiseLoss {
 protected:
   std::vector<int> axes_;

-  virtual Expr compute(Logits logits, Expr labelIndices,
+  virtual Expr compute(Logits logits, const Words& labels,
                        Expr mask = nullptr, Expr labelWeights = nullptr) = 0;

   // label counts are available, reduce together with loss to obtain counts
@@ -305,9 +306,9 @@ public:
   LabelwiseLoss(const std::vector<int>& axes)
       : axes_(axes) { }

-  virtual RationalLoss apply(Logits logits, Expr labelIndices,
+  virtual RationalLoss apply(Logits logits, const Words& labels,
                              Expr mask = nullptr, Expr labelWeights = nullptr) {
-    Expr loss = compute(logits, labelIndices, mask, labelWeights);
+    Expr loss = compute(logits, labels, mask, labelWeights);

     if(mask)
       return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting
@@ -332,10 +333,10 @@ public:
 protected:
   float labelSmoothing_; // interpolation factor for label smoothing, see below

-  virtual Expr compute(Logits logits, Expr labelIndices,
+  virtual Expr compute(Logits logits, const Words& labels,
                        Expr mask = nullptr, Expr labelWeights = nullptr) override {
     // logits may be factored; in that case, the getLoss() function computes one loss for each, and sums them up
-    auto ce = logits.applyLossFunction(labelIndices, [&](Expr logits, Expr indices) {
+    auto ce = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
       Expr ce = cross_entropy(logits, indices);
       if (labelSmoothing_ > 0) {
         // ce = -sum_i y^_i log y_i(h)
@@ -368,9 +369,9 @@ public:
   // sentence-wise CE, hence reduce only over time axis. CE reduces over last axis (-1)
   RescorerLoss() : CrossEntropyLoss(/*axes=*/{-3}, /*smoothing=*/0.f) {}

-  virtual RationalLoss apply(Logits logits, Expr labelIndices,
+  virtual RationalLoss apply(Logits logits, const Words& labels,
                              Expr mask = nullptr, Expr labelWeights = nullptr) override {
-    auto ce = CrossEntropyLoss::apply(logits, labelIndices, mask, labelWeights);
+    auto ce = CrossEntropyLoss::apply(logits, labels, mask, labelWeights);
     return RationalLoss(ce.loss(), ce.count());
   }
 };
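
Editor's note: the "ce = -sum_i y^_i log y_i(h)" comment above refers to the interpolation controlled by labelSmoothing_. As a worked example of the standard formulation (assumed here, not extracted from marian): smoothing mixes the one-hot target with the uniform distribution over the vocabulary, giving ce' = (1 - ls) * ce + ls * mean_i(-log y_i(h)).

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Standard label-smoothed cross-entropy over one softmax output (sketch,
    // not marian code). logProbs holds log y_i(h) for the whole vocabulary.
    float smoothedCrossEntropy(const std::vector<float>& logProbs,
                               std::size_t target, float labelSmoothing) {
      assert(target < logProbs.size());
      float ce = -logProbs[target];  // -log y_target(h), the unsmoothed CE
      float uniformCe = 0.f;         // mean_i(-log y_i(h))
      for (float lp : logProbs)
        uniformCe -= lp;
      uniformCe /= (float)logProbs.size();
      // interpolate the one-hot target with the uniform distribution
      return (1.f - labelSmoothing) * ce + labelSmoothing * uniformCe;
    }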

View File

@@ -295,7 +295,7 @@ public:
     // Filled externally, for BERT these are NextSentence prediction labels
     const auto& classLabels = (*batch)[batchIndex_]->data();
-    state->setTargetIndices(graph->indices(toWordIndexVector(classLabels)));
+    state->setTargetWords(classLabels);

     return state;
   }
@@ -320,8 +320,8 @@ public:
     auto context = encoderStates[0]->getContext();

-    auto bertMaskedPositions = graph->indices(bertBatch->bertMaskedPositions()); // positions in batch of masked entries
-    auto bertMaskedWords = graph->indices(toWordIndexVector(bertBatch->bertMaskedWords())); // vocab ids of entries that have been masked
+    auto bertMaskedPositions = graph->indices(bertBatch->bertMaskedPositions()); // positions in batch of masked entries
+    const auto& bertMaskedWords = bertBatch->bertMaskedWords(); // vocab ids of entries that have been masked

     int dimModel = context->shape()[-1];
     int dimBatch = context->shape()[-2];
@@ -366,7 +366,7 @@ public:
     auto state = New<ClassifierState>();
     state->setLogProbs(logits);
-    state->setTargetIndices(bertMaskedWords);
+    state->setTargetWords(bertMaskedWords);

     return state;
   }

View File

@@ -69,7 +69,7 @@ public:
     // @TODO: adapt to multi-objective training with multiple decoders
     auto partialLoss = loss_->apply(state->getLogProbs().getLogits(),
-                                    state->getTargetIndices(),
+                                    state->getTargetWords(),
                                     state->getTargetMask(),
                                     weights);
     multiLoss->push_back(partialLoss);
@@ -119,7 +119,7 @@ public:
     Ptr<MultiRationalLoss> multiLoss = newMultiLoss(options_);
     for(int i = 0; i < states.size(); ++i) {
       auto partialLoss = loss_->apply(states[i]->getLogProbs(),
-                                      states[i]->getTargetIndices(),
+                                      states[i]->getTargetWords(),
                                       /*mask=*/nullptr,
                                       /*weights=*/nullptr);
       multiLoss->push_back(partialLoss);

View File

@@ -71,18 +71,18 @@ public:
     Expr y, yMask; std::tie
     (y, yMask) = embedding_[batchIndex_]->apply(subBatch);

-    Expr yData;
-    if(shortlist_) {
-      yData = graph->indices(shortlist_->mappedIndices());
-    } else {
-      yData = graph->indices(toWordIndexVector(subBatch->data()));
-    }
+    const Words& data =
+        /*if*/ (shortlist_) ?
+           shortlist_->mappedIndices()
+        /*else*/ :
+           subBatch->data();
+    Expr yData = graph->indices(toWordIndexVector(data));

     auto yShifted = shift(y, {1, 0, 0});

     state->setTargetEmbeddings(yShifted);
     state->setTargetMask(yMask);
-    state->setTargetIndices(yData);
+    state->setTargetWords(data);
   }

   virtual void embeddingsFromPrediction(Ptr<ExpressionGraph> graph,
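
Editor's note: one detail worth spelling out in the hunk above: rewriting the if/else as a conditional expression lets data be a single const Words& that binds to whichever source was selected, with no copy, because both operands of ?: are lvalues of the same type. A minimal stand-alone demonstration (stand-in types, not marian code):

    #include <cassert>
    #include <vector>

    using Words = std::vector<int>;  // stand-in for marian's Words

    int main() {
      Words shortlisted{1, 2}, full{3, 4, 5};
      bool haveShortlist = true;
      // Both branches are lvalues of type Words, so the conditional
      // expression is itself an lvalue and the reference binds directly
      // to the selected vector -- no temporary, no copy.
      const Words& data = haveShortlist ? shortlisted : full;
      assert(&data == &shortlisted);
      return 0;
    }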

View File

@@ -179,7 +179,7 @@ Ptr<DecoderState> EncoderDecoder::stepAll(Ptr<ExpressionGraph> graph,
   decoders_[0]->embeddingsFromBatch(graph, state, batch);
   auto nextState = decoders_[0]->step(graph, state);
   nextState->setTargetMask(state->getTargetMask());
-  nextState->setTargetIndices(state->getTargetIndices());
+  nextState->setTargetWords(state->getTargetWords());

   return nextState;
 }

View File

@@ -36,7 +36,7 @@ protected:
   Expr targetEmbeddings_;
   Expr targetMask_;
-  Expr targetIndices_;
+  Words targetWords_;

   // Keep track of current target token position during translation
   size_t position_{0};
@@ -76,11 +76,8 @@ public:
     targetEmbeddings_ = targetEmbeddings;
   }

-  virtual Expr getTargetIndices() const { return targetIndices_; };
-  virtual void setTargetIndices(Expr targetIndices) {
-    targetIndices_ = targetIndices;
-  }
+  virtual const Words& getTargetWords() const { return targetWords_; };
+  virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }

   virtual Expr getTargetMask() const { return targetMask_; };
@@ -112,15 +109,14 @@ private:
   Ptr<data::CorpusBatch> batch_;

   Expr targetMask_;
-  Expr targetIndices_;
+  Words targetWords_;

 public:
   virtual Expr getLogProbs() const { return logProbs_; }
   virtual void setLogProbs(Expr logProbs) { logProbs_ = logProbs; }

-  virtual Expr getTargetIndices() const { return targetIndices_; };
-  virtual void setTargetIndices(Expr targetIndices) { targetIndices_ = targetIndices; }
+  virtual const Words& getTargetWords() const { return targetWords_; };
+  virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }

   virtual Expr getTargetMask() const { return targetMask_; };

View File

@@ -701,6 +701,7 @@
     <ClInclude Include="..\src\common\timer.h" />
     <ClInclude Include="..\src\common\types.h" />
     <ClInclude Include="..\src\common\version.h" />
+    <ClInclude Include="..\src\data\vocab_base.h" />
     <ClInclude Include="..\src\examples\mnist\dataset.h" />
     <ClInclude Include="..\src\examples\mnist\model.h" />
     <ClInclude Include="..\src\examples\mnist\model_lenet.h" />

View File

@@ -1528,6 +1528,9 @@
     <ClInclude Include="..\src\models\bert.h">
       <Filter>models</Filter>
     </ClInclude>
+    <ClInclude Include="..\src\data\vocab_base.h">
+      <Filter>data</Filter>
+    </ClInclude>
   </ItemGroup>
   <ItemGroup>
     <Filter Include="3rd_party">