ModelBase::build() now returns an Expr, not a RationalLoss

This commit is contained in:
Frank Seide 2019-02-22 13:21:45 -08:00
parent 03806ff60f
commit 23ece0040a
12 changed files with 97 additions and 85 deletions

28
src/examples/mnist/model.h Normal file → Executable file
View File

@ -36,27 +36,23 @@ public:
// Define a top-level node for training
// use CE loss
auto loss = sum(cross_entropy(top->loss(), labels), /*axis =*/ 0);
auto loss = sum(cross_entropy(top, labels), /*axis =*/ 0);
auto multiLoss = New<SumMultiRationalLoss>();
multiLoss->push_back({loss, (float)vLabels.size()});
return multiLoss;
}
};
class MNISTLogsoftmax : public CostBase {
class MNISTLogsoftmax : public LogProbBase {
public:
MNISTLogsoftmax() {}
Ptr<MultiRationalLoss> apply(Ptr<ModelBase> model,
Expr apply(Ptr<ModelBase> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto top = model->build(graph, batch, clearGraph);
// @TODO: simplify this
auto multiLoss = New<SumMultiRationalLoss>();
multiLoss->push_back({logsoftmax(top->loss()), top->count()});
return multiLoss;
return logsoftmax(top);
}
};
@ -68,14 +64,11 @@ public:
MnistFeedForwardNet(Ptr<Options> options, Args... args)
: options_(options), inference_(options->get<bool>("inference", false)) {}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool /*clean*/ = false) override {
auto loss = construct(graph, batch, inference_); // @TODO: unify nomenclature, e.g. rather use apply
auto count = graph->constant({(int)batch->size(), 1}, inits::from_value(1.f));
return New<RationalLoss>(loss, count);
return apply(graph, batch, inference_);
}
void load(Ptr<ExpressionGraph> /*graph*/, const std::string& /*name*/, bool) override {
@ -103,8 +96,7 @@ protected:
bool inference_{false};
/**
* @brief Constructs an expression graph representing a feed-forward
* classifier.
* @brief Builds an expression graph representing a feed-forward classifier.
*
* @param dims number of nodes in each layer of the feed-forward classifier
* @param batch a batch of training or testing examples
@ -112,9 +104,9 @@ protected:
*
* @return the resulting expression (Expr) of the newly constructed classifier graph
*/
virtual Expr construct(Ptr<ExpressionGraph> g,
Ptr<data::Batch> batch,
bool /*inference*/ = false) {
virtual Expr apply(Ptr<ExpressionGraph> g,
Ptr<data::Batch> batch,
bool /*inference*/ = false) {
const std::vector<int> dims = {784, 2048, 2048, 10};
// Start with an empty expression graph

View File

@ -12,11 +12,11 @@ using namespace marian;
namespace marian {
class MNISTAccuracyValidator : public Validator<data::MNISTData> {
class MNISTAccuracyValidator : public Validator<data::MNISTData, models::ModelBase> {
public:
MNISTAccuracyValidator(Ptr<Options> options) : Validator(std::vector<Ptr<Vocab>>(), options, false) {
createBatchGenerator(/*isTranslating=*/false);
builder_ = models::createModelFromOptions(options, models::usage::scoring);
builder_ = models::createModelFromOptions(options, models::usage::translation);
}
virtual void keepBest(const std::vector<Ptr<ExpressionGraph>>& graphs) override {
@ -35,7 +35,7 @@ protected:
graphs[0]->forward();
std::vector<float> scores;
probs->loss(scores);
probs->val()->get(scores);
correct += countCorrect(scores, batch->labels());
samples += batch->size();

View File

@ -159,17 +159,25 @@ public:
// Delegates clearing for `graph` to the wrapped model_.
virtual void clear(Ptr<ExpressionGraph> graph) override {
  model_->clear(graph);
}
};
// Abstract interface that produces an Expr from a model applied to a batch.
// Scorer holds one of these (as Ptr<LogProbBase>) and forwards its build()
// call to apply().
// NOTE(review): the name suggests the returned Expr holds log-probabilities;
// the interface itself does not enforce that -- confirm against implementers.
class LogProbBase {
public:
  // Objects are owned through a base-class pointer, so give the interface a
  // virtual destructor to keep destruction through the base type well-defined.
  virtual ~LogProbBase() = default;

  // Computes the output expression for `batch` using `model` on `graph`.
  // `clearGraph` defaults to true; presumably it is passed on to the model's
  // build() by implementations -- verify against each subclass.
  virtual Expr apply(Ptr<ModelBase> model,
                     Ptr<ExpressionGraph> graph,
                     Ptr<data::Batch> batch,
                     bool clearGraph = true) = 0;
};
// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for the ground truth?
// Beam search uses it in the former sense, while 'marian score' and validation use it in the latter.
// This class is for the former use. The latter is done using Trainer.
class Scorer : public ModelBase {
protected:
Ptr<ModelBase> model_;
Ptr<CostBase> cost_;
Ptr<LogProbBase> logProb_;
public:
Scorer(Ptr<ModelBase> model, Ptr<CostBase> cost)
: model_(model), cost_(cost) {}
Scorer(Ptr<ModelBase> model, Ptr<LogProbBase> cost)
: model_(model), logProb_(cost) {}
// Accessor for the model wrapped by this Scorer.
Ptr<ModelBase> getModel() {
  return model_;
}
@ -185,10 +193,10 @@ public:
model_->save(graph, name, saveTranslatorConfig);
}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
return cost_->apply(model_, graph, batch, clearGraph);
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
return logProb_->apply(model_, graph, batch, clearGraph);
};
// Delegates clearing for `graph` to the wrapped model_.
virtual void clear(Ptr<ExpressionGraph> graph) override {
  model_->clear(graph);
}
@ -259,9 +267,9 @@ public:
// Delegates clearing for `graph` to the wrapped encdec_ model.
virtual void clear(Ptr<ExpressionGraph> graph) override {
  encdec_->clear(graph);
}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
return build(graph, corpusBatch, clearGraph);
}
@ -282,9 +290,9 @@ public:
return cost_->apply(nextState);
}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> /*graph*/,
Ptr<data::CorpusBatch> /*batch*/,
bool /*clearGraph*/ = true) override {
virtual Expr build(Ptr<ExpressionGraph> /*graph*/,
Ptr<data::CorpusBatch> /*batch*/,
bool /*clearGraph*/ = true) override {
ABORT("Wrong wrapper. Use models::Trainer or models::Scorer");
return nullptr;
}

View File

@ -41,13 +41,13 @@ public:
virtual std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, bool) = 0;
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override = 0;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override = 0;
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) = 0;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) = 0;
virtual Ptr<Options> getOptions() = 0;
};
@ -206,17 +206,17 @@ public:
return classifierStates;
}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) override {
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) override {
auto states = apply(graph, batch, clearGraph);
// returns raw logits
return New<RationalLoss>(states[0]->getLogProbs(), nullptr); // @TODO: Check if this is actually used
return states[0]->getLogProbs();
}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
return build(graph, corpusBatch, clearGraph);
}

10
src/models/encoder_decoder.cpp Normal file → Executable file
View File

@ -196,16 +196,16 @@ Ptr<DecoderState> EncoderDecoder::stepAll(Ptr<ExpressionGraph> graph,
return nextState;
}
Ptr<RationalLoss> EncoderDecoder::build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph) {
Expr EncoderDecoder::build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph) {
auto state = stepAll(graph, batch, clearGraph);
// returns raw logits
return New<RationalLoss>(state->getLogProbs(), state->getTargetMask()); // @TODO: hacky hack hack
return state->getLogProbs(); // , state->getTargetMask()); // @TODO: hacky hack hack
}
Ptr<RationalLoss> EncoderDecoder::build(Ptr<ExpressionGraph> graph,
Expr EncoderDecoder::build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph) {
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);

24
src/models/encoder_decoder.h Normal file → Executable file
View File

@ -28,13 +28,13 @@ public:
virtual void clear(Ptr<ExpressionGraph> graph) override = 0;
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override = 0;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override = 0;
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) = 0;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) = 0;
virtual Ptr<DecoderState> startState(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch) = 0;
@ -156,13 +156,13 @@ public:
Ptr<data::CorpusBatch> batch,
bool clearGraph = true);
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) override;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) override;
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override;
};
} // namespace marian

View File

@ -28,9 +28,9 @@ public:
bool saveTranslatorConfig = false)
= 0;
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true)
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true)
= 0;
virtual void clear(Ptr<ExpressionGraph> graph) = 0;

View File

@ -19,11 +19,11 @@ using namespace data;
class Rescorer {
private:
Ptr<models::ModelBase> builder_;
Ptr<models::CriterionBase> builder_;
public:
Rescorer(Ptr<Options> options)
: builder_(models::createModelFromOptions(options, models::usage::scoring)) {}
: builder_(models::createCriterionFromOptions(options, models::usage::scoring)) {}
void load(Ptr<ExpressionGraph> graph, const std::string& modelFile) {
builder_->load(graph, modelFile);
@ -34,7 +34,7 @@ public:
}
data::SoftAlignment getAlignment() {
auto model = std::static_pointer_cast<models::Scorer>(builder_)->getModel();
auto model = std::static_pointer_cast<models::Trainer>(builder_)->getModel();
return std::static_pointer_cast<EncoderDecoderBase>(model)->getAlignment();
}
};

4
src/training/validator.cpp Normal file → Executable file
View File

@ -2,10 +2,10 @@
namespace marian {
std::vector<Ptr<Validator<data::Corpus>>> Validators(
std::vector<Ptr<ValidatorBase/*<data::Corpus>*/>> Validators(
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> config) {
std::vector<Ptr<Validator<data::Corpus>>> validators;
std::vector<Ptr<ValidatorBase/*<data::Corpus>*/>> validators;
auto validMetrics = config->get<std::vector<std::string>>("valid-metrics");

View File

@ -56,7 +56,7 @@ public:
}
};
template <class DataSet>
template <class DataSet, class BuilderType>
class Validator : public ValidatorBase {
public:
Validator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options, bool lowerIsBetter = true)
@ -123,7 +123,7 @@ public:
protected:
std::vector<Ptr<Vocab>> vocabs_;
Ptr<Options> options_;
Ptr<models::ModelBase> builder_;
Ptr<BuilderType> builder_;
Ptr<data::BatchGenerator<DataSet>> batchGenerator_;
virtual float validateBG(const std::vector<Ptr<ExpressionGraph>>&)
@ -148,7 +148,7 @@ protected:
}
};
class CrossEntropyValidator : public Validator<data::Corpus> {
class CrossEntropyValidator : public Validator<data::Corpus, models::CriterionBase> {
public:
CrossEntropyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options) {
@ -159,7 +159,7 @@ public:
opts->merge(options);
opts->set("inference", true);
opts->set("cost-type", "ce-sum");
builder_ = models::createModelFromOptions(opts, models::usage::scoring);
builder_ = models::createCriterionFromOptions(opts, models::usage::scoring);
}
// Reports this validator's type: the value stored under "cost-type" in options_.
std::string type() override {
  return options_->get<std::string>("cost-type");
}
@ -180,7 +180,7 @@ protected:
for(auto batch : *batchGenerator_) {
auto task = [=, &loss, &samples](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local auto builder = models::createModelFromOptions(options_, models::usage::scoring);
thread_local auto builder = models::createCriterionFromOptions(options_, models::usage::scoring);
if(!graph) {
graph = graphs[id % graphs.size()];
@ -215,8 +215,8 @@ protected:
}
};
// Used for validating with classifiers. Compute prediction accuary versus groundtruth for a set of classes
class AccuracyValidator : public Validator<data::Corpus> {
// Used for validating with classifiers. Compute prediction accuracy versus ground truth for a set of classes
class AccuracyValidator : public Validator<data::Corpus, models::ModelBase> {
public:
AccuracyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options, /*lowerIsBetter=*/false) {
@ -263,7 +263,7 @@ protected:
// correct += correct->scalar<IndexType>();
builder->clear(graph);
Expr logits = builder->build(graph, batch)->loss();
Expr logits = builder->build(graph, batch);
graph->forward();
std::vector<float> vLogits;
@ -305,7 +305,7 @@ protected:
}
};
class BertAccuracyValidator : public Validator<data::Corpus> {
class BertAccuracyValidator : public Validator<data::Corpus, models::ModelBase> {
private:
bool evalMaskedLM_{true};
@ -415,7 +415,7 @@ protected:
};
class ScriptValidator : public Validator<data::Corpus> {
class ScriptValidator : public Validator<data::Corpus, models::ModelBase> {
public:
ScriptValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options, false) {
@ -446,7 +446,7 @@ protected:
}
};
class TranslationValidator : public Validator<data::Corpus> {
class TranslationValidator : public Validator<data::Corpus, models::ModelBase> {
public:
TranslationValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options, false),
@ -578,7 +578,7 @@ protected:
};
// @TODO: combine with TranslationValidator (above) to avoid code duplication
class BleuValidator : public Validator<data::Corpus> {
class BleuValidator : public Validator<data::Corpus, models::ModelBase> {
private:
bool detok_{false};
@ -844,7 +844,7 @@ protected:
*
* @return Vector of validator objects
*/
std::vector<Ptr<Validator<data::Corpus>>> Validators(
std::vector<Ptr<ValidatorBase/*<data::Corpus>*/>> Validators(
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> config);
} // namespace marian

3
vs/Marian.vcxproj Normal file → Executable file
View File

@ -894,10 +894,13 @@
<ClInclude Include="..\src\layers\word2vec_reader.h" />
<ClInclude Include="..\src\microsoft\quicksand.h" />
<ClInclude Include="..\src\models\amun.h" />
<ClInclude Include="..\src\models\bert.h" />
<ClInclude Include="..\src\models\char_s2s.h" />
<ClInclude Include="..\src\models\classifier.h" />
<ClInclude Include="..\src\models\costs.h" />
<ClInclude Include="..\src\models\decoder.h" />
<ClInclude Include="..\src\models\encoder.h" />
<ClInclude Include="..\src\models\encoder_classifier.h" />
<ClInclude Include="..\src\models\encoder_decoder.h" />
<ClInclude Include="..\src\models\model_base.h" />
<ClInclude Include="..\src\models\model_factory.h" />

9
vs/Marian.vcxproj.filters Normal file → Executable file
View File

@ -1517,6 +1517,15 @@
<ClInclude Include="..\src\examples\mnist\validator.h">
<Filter>examples\mnist</Filter>
</ClInclude>
<ClInclude Include="..\src\models\bert.h">
<Filter>models</Filter>
</ClInclude>
<ClInclude Include="..\src\models\classifier.h">
<Filter>models</Filter>
</ClInclude>
<ClInclude Include="..\src\models\encoder_classifier.h">
<Filter>models</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<Filter Include="3rd_party">