merged from fseide/criterionfunction

This commit is contained in:
Frank Seide 2019-03-13 09:45:47 -07:00
commit f6304a43e8
26 changed files with 307 additions and 242 deletions

34
src/examples/mnist/model.h Normal file → Executable file
View File

@ -17,11 +17,11 @@ namespace models {
// @TODO: looking at this file, simplify the new RationalLoss idea. Here it gets too complicated
class MNISTCrossEntropyCost : public CostBase {
class MNISTCrossEntropyCost : public ICost {
public:
MNISTCrossEntropyCost() {}
Ptr<MultiRationalLoss> apply(Ptr<ModelBase> model,
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
@ -36,31 +36,27 @@ public:
// Define a top-level node for training
// use CE loss
auto loss = sum(cross_entropy(top->loss(), labels), /*axis =*/ 0);
auto loss = sum(cross_entropy(top, labels), /*axis =*/ 0);
auto multiLoss = New<SumMultiRationalLoss>();
multiLoss->push_back({loss, (float)vLabels.size()});
return multiLoss;
}
};
class MNISTLogsoftmax : public CostBase {
class MNISTLogsoftmax : public ILogProb {
public:
MNISTLogsoftmax() {}
Ptr<MultiRationalLoss> apply(Ptr<ModelBase> model,
Expr apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto top = model->build(graph, batch, clearGraph);
// @TODO: simplify this
auto multiLoss = New<SumMultiRationalLoss>();
multiLoss->push_back({logsoftmax(top->loss()), top->count()});
return multiLoss;
return logsoftmax(top);
}
};
class MnistFeedForwardNet : public ModelBase {
class MnistFeedForwardNet : public IModel {
public:
typedef data::MNISTData dataset_type;
@ -68,14 +64,11 @@ public:
MnistFeedForwardNet(Ptr<Options> options, Args... args)
: options_(options), inference_(options->get<bool>("inference", false)) {}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool /*clean*/ = false) override {
auto loss = construct(graph, batch, inference_); // @TODO: unify nomenclature, e.g. rather use apply
auto count = graph->constant({(int)batch->size(), 1}, inits::from_value(1.f));
return New<RationalLoss>(loss, count);
return apply(graph, batch, inference_);
}
void load(Ptr<ExpressionGraph> /*graph*/, const std::string& /*name*/, bool) override {
@ -103,8 +96,7 @@ protected:
bool inference_{false};
/**
* @brief Constructs an expression graph representing a feed-forward
* classifier.
* @brief Builds an expression graph representing a feed-forward classifier.
*
* @param dims number of nodes in each layer of the feed-forward classifier
* @param batch a batch of training or testing examples
@ -112,9 +104,9 @@ protected:
*
* @return a shared pointer to the newly constructed expression graph
*/
virtual Expr construct(Ptr<ExpressionGraph> g,
Ptr<data::Batch> batch,
bool /*inference*/ = false) {
virtual Expr apply(Ptr<ExpressionGraph> g,
Ptr<data::Batch> batch,
bool /*inference*/ = false) {
const std::vector<int> dims = {784, 2048, 2048, 10};
// Start with an empty expression graph

View File

@ -15,9 +15,9 @@ public:
virtual void clear(Ptr<ExpressionGraph> graph) override { graph->clear(); };
protected:
virtual Expr construct(Ptr<ExpressionGraph> g,
Ptr<data::Batch> batch,
bool inference = false) override {
virtual Expr apply(Ptr<ExpressionGraph> g,
Ptr<data::Batch> batch,
bool inference = false) override {
const std::vector<int> dims = {784, 128, 10};
// Start with an empty expression graph

6
src/examples/mnist/validator.h Normal file → Executable file
View File

@ -12,11 +12,11 @@ using namespace marian;
namespace marian {
class MNISTAccuracyValidator : public Validator<data::MNISTData> {
class MNISTAccuracyValidator : public Validator<data::MNISTData, models::IModel> {
public:
MNISTAccuracyValidator(Ptr<Options> options) : Validator(std::vector<Ptr<Vocab>>(), options, false) {
createBatchGenerator(/*isTranslating=*/false);
builder_ = models::from_options(options, models::usage::scoring);
builder_ = models::createModelFromOptions(options, models::usage::translation);
}
virtual void keepBest(const std::vector<Ptr<ExpressionGraph>>& graphs) override {
@ -35,7 +35,7 @@ protected:
graphs[0]->forward();
std::vector<float> scores;
probs->loss(scores);
probs->val()->get(scores);
correct += countCorrect(scores, batch->labels());
samples += batch->size();

View File

@ -72,7 +72,7 @@ public:
std::cerr << modelOpts->str() << std::flush;
auto encdec = models::from_options(modelOpts, models::usage::translation);
auto encdec = models::createModelFromOptions(modelOpts, models::usage::translation);
if(io::isBin(models[i]) && ptrs_[i] != nullptr) {
// if file ends in *.bin and has been mapped by QuickSAND

128
src/models/costs.h Normal file → Executable file
View File

@ -19,15 +19,15 @@ namespace models {
// Other functions return RationalLoss directly without Ptr<...>, but also
// they do not need polymorphism here.
class CostBase {
class ICost {
public:
virtual Ptr<MultiRationalLoss> apply(Ptr<ModelBase> model,
Ptr<ExpressionGraph> graph,
virtual Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model?
Ptr<data::Batch> batch,
bool clearGraph = true) = 0;
};
class EncoderDecoderCE : public CostBase {
class EncoderDecoderCECost : public ICost {
protected:
Ptr<Options> options_;
@ -39,7 +39,7 @@ protected:
Ptr<WeightingBase> weighter_;
public:
EncoderDecoderCE(Ptr<Options> options)
EncoderDecoderCECost(Ptr<Options> options)
: options_(options), inference_(options->get<bool>("inference", false)) {
loss_ = newLoss(options_, inference_);
@ -51,7 +51,7 @@ public:
weighter_ = WeightingFactory(options_);
}
Ptr<MultiRationalLoss> apply(Ptr<ModelBase> model,
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
@ -89,7 +89,7 @@ public:
};
// Wraps an EncoderClassifier so it can produce a cost from raw logits. @TODO: Needs refactoring
class EncoderClassifierCE : public CostBase {
class EncoderClassifierCECost : public ICost {
protected:
Ptr<Options> options_;
bool inference_{false};
@ -99,12 +99,12 @@ protected:
Ptr<LabelwiseLoss> loss_;
public:
EncoderClassifierCE(Ptr<Options> options)
EncoderClassifierCECost(Ptr<Options> options)
: options_(options), inference_(options->get<bool>("inference", false)) {
loss_ = newLoss(options_, inference_);
}
Ptr<MultiRationalLoss> apply(Ptr<ModelBase> model,
Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
@ -127,16 +127,16 @@ public:
}
};
class Trainer : public ModelBase {
class Trainer : public ICriterionFunction {
protected:
Ptr<ModelBase> model_;
Ptr<CostBase> cost_;
Ptr<IModel> model_;
Ptr<ICost> cost_;
public:
Trainer(Ptr<ModelBase> model, Ptr<CostBase> cost)
Trainer(Ptr<IModel> model, Ptr<ICost> cost)
: model_(model), cost_(cost) {}
Ptr<ModelBase> getModel() { return model_; }
Ptr<IModel> getModel() { return model_; }
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
@ -159,14 +159,55 @@ public:
virtual void clear(Ptr<ExpressionGraph> graph) override { model_->clear(graph); };
};
typedef Trainer Scorer;
class ILogProb {
public:
virtual Expr apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) = 0;
};
class CostStep {
// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for the ground truth?
// Beam search uses it for the former meaning, while 'marian score' and validation in the latter.
// This class is for the former use. The latter is done using Trainer.
class Scorer : public IModel {
protected:
Ptr<IModel> model_;
Ptr<ILogProb> logProb_;
public:
Scorer(Ptr<IModel> model, Ptr<ILogProb> cost)
: model_(model), logProb_(cost) {}
Ptr<IModel> getModel() { return model_; }
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded = true) override {
model_->load(graph, name, markedReloaded);
};
virtual void save(Ptr<ExpressionGraph> graph,
const std::string& name,
bool saveTranslatorConfig = false) override {
model_->save(graph, name, saveTranslatorConfig);
}
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
return logProb_->apply(model_, graph, batch, clearGraph);
};
virtual void clear(Ptr<ExpressionGraph> graph) override { model_->clear(graph); };
};
class ILogProbStep {
public:
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) = 0;
};
class LogSoftmaxStep : public CostStep {
class LogSoftmaxStep : public ILogProbStep {
public:
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
@ -182,7 +223,7 @@ public:
// Gumbel-max noising for sampling during beam-search
// Seems to work well enough with beam-size=1. Turn on
// with --output-sampling during translation with marian-decoder
class GumbelSoftmaxStep : public CostStep {
class GumbelSoftmaxStep : public ILogProbStep {
public:
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
auto logits = state->getLogProbs();
@ -194,16 +235,16 @@ public:
}
};
// class to wrap an EncoderDecoderBase and a CostStep that are executed in sequence,
// class to wrap an EncoderDecoderBase and an ILogProbStep that are executed in sequence,
// wrapped again in the EncoderDecoderBase interface
// @TODO: seems we are conflating an interface definition with its implementation?
class Stepwise : public EncoderDecoderBase {
protected:
Ptr<EncoderDecoderBase> encdec_;
Ptr<CostStep> cost_;
Ptr<ILogProbStep> cost_;
public:
Stepwise(Ptr<EncoderDecoderBase> encdec, Ptr<CostStep> cost)
Stepwise(Ptr<EncoderDecoderBase> encdec, Ptr<ILogProbStep> cost)
: encdec_(encdec), cost_(cost) {}
virtual void load(Ptr<ExpressionGraph> graph,
@ -226,9 +267,9 @@ public:
virtual void clear(Ptr<ExpressionGraph> graph) override { encdec_->clear(graph); }
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
return build(graph, corpusBatch, clearGraph);
}
@ -249,9 +290,9 @@ public:
return cost_->apply(nextState);
}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> /*graph*/,
Ptr<data::CorpusBatch> /*batch*/,
bool /*clearGraph*/ = true) override {
virtual Expr build(Ptr<ExpressionGraph> /*graph*/,
Ptr<data::CorpusBatch> /*batch*/,
bool /*clearGraph*/ = true) override {
ABORT("Wrong wrapper. Use models::Trainer or models::Scorer");
return nullptr;
}
@ -270,38 +311,5 @@ public:
virtual data::SoftAlignment getAlignment() override { return encdec_->getAlignment(); }
};
inline Ptr<ModelBase> add_cost(Ptr<EncoderDecoder> encdec,
Ptr<Options> options) {
switch(options->get<usage>("usage", usage::raw)) {
case usage::training:
return New<Trainer>(encdec, New<EncoderDecoderCE>(options));
case usage::scoring:
return New<Scorer>(encdec, New<EncoderDecoderCE>(options));
case usage::translation:
if(options->get<bool>("output-sampling", false))
return New<Stepwise>(encdec, New<GumbelSoftmaxStep>());
else
return New<Stepwise>(encdec, New<LogSoftmaxStep>());
case usage::raw:
default:
return encdec;
}
}
inline Ptr<ModelBase> add_cost(Ptr<EncoderClassifier> enccls,
Ptr<Options> options) {
switch(options->get<usage>("usage", usage::raw)) {
case usage::training:
return New<Trainer>(enccls, New<EncoderClassifierCE>(options));
case usage::scoring:
return New<Scorer>(enccls, New<EncoderClassifierCE>(options));
case usage::translation:
ABORT("Classifier cannot be used for translation");
case usage::raw:
default:
return enccls;
}
}
} // namespace models
} // namespace marian

View File

@ -17,7 +17,7 @@ namespace marian {
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for decoder/classifier
* multi-objective training.
*/
class EncoderClassifierBase : public models::ModelBase {
class EncoderClassifierBase : public models::IModel {
public:
virtual ~EncoderClassifierBase() {}
@ -41,13 +41,13 @@ public:
virtual std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, bool) = 0;
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override = 0;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override = 0;
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) = 0;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) = 0;
virtual Ptr<Options> getOptions() = 0;
};
@ -206,17 +206,17 @@ public:
return classifierStates;
}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) override {
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) override {
auto states = apply(graph, batch, clearGraph);
// returns raw logits
return New<RationalLoss>(states[0]->getLogProbs(), nullptr); // @TODO: Check if this is actually used
return states[0]->getLogProbs();
}
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
return build(graph, corpusBatch, clearGraph);
}

View File

@ -195,16 +195,16 @@ Ptr<DecoderState> EncoderDecoder::stepAll(Ptr<ExpressionGraph> graph,
return nextState;
}
Ptr<RationalLoss> EncoderDecoder::build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph) {
Expr EncoderDecoder::build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph) {
auto state = stepAll(graph, batch, clearGraph);
// returns raw logits
return New<RationalLoss>(state->getLogProbs(), state->getTargetMask()); // @TODO: hacky hack hack
return state->getLogProbs();
}
Ptr<RationalLoss> EncoderDecoder::build(Ptr<ExpressionGraph> graph,
Expr EncoderDecoder::build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph) {
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);

View File

@ -9,7 +9,7 @@
namespace marian {
class EncoderDecoderBase : public models::ModelBase {
class EncoderDecoderBase : public models::IModel {
public:
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
@ -28,13 +28,13 @@ public:
virtual void clear(Ptr<ExpressionGraph> graph) override = 0;
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override = 0;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override = 0;
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) = 0;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) = 0;
virtual Ptr<DecoderState> startState(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch) = 0;
@ -156,13 +156,13 @@ public:
Ptr<data::CorpusBatch> batch,
bool clearGraph = true);
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) override;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true) override;
virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override;
};
} // namespace marian

24
src/models/model_base.h Normal file → Executable file
View File

@ -16,7 +16,29 @@ YAML_REGISTER_TYPE(marian::models::usage, int)
namespace marian {
namespace models {
class ModelBase {
// model = input -> predictions
class IModel {
public:
virtual void load(Ptr<ExpressionGraph>,
const std::string&,
bool markReloaded = true)
= 0;
virtual void save(Ptr<ExpressionGraph>,
const std::string&,
bool saveTranslatorConfig = false)
= 0;
virtual Expr build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true)
= 0;
virtual void clear(Ptr<ExpressionGraph> graph) = 0;
};
// criterion = (input, reference) -> loss
// @TODO: Is there a better name?
class ICriterionFunction {
public:
virtual void load(Ptr<ExpressionGraph>,
const std::string&,

View File

@ -60,7 +60,7 @@ Ptr<ClassifierBase> ClassifierFactory::construct(Ptr<ExpressionGraph> /*graph*/)
ABORT("Unknown classifier type");
}
Ptr<ModelBase> EncoderDecoderFactory::construct(Ptr<ExpressionGraph> graph) {
Ptr<IModel> EncoderDecoderFactory::construct(Ptr<ExpressionGraph> graph) {
Ptr<EncoderDecoder> encdec;
if(options_->get<std::string>("type") == "amun")
@ -77,10 +77,10 @@ Ptr<ModelBase> EncoderDecoderFactory::construct(Ptr<ExpressionGraph> graph) {
for(auto& df : decoders_)
encdec->push_back(df(options_).construct(graph));
return add_cost(encdec, options_);
return encdec;
}
Ptr<ModelBase> EncoderClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
Ptr<IModel> EncoderClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
Ptr<EncoderClassifier> enccls;
if(options_->get<std::string>("type") == "bert") {
enccls = New<BertEncoderClassifier>(options_);
@ -96,22 +96,22 @@ Ptr<ModelBase> EncoderClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
for(auto& cf : classifiers_)
enccls->push_back(cf(options_).construct(graph));
return add_cost(enccls, options_);
return enccls;
}
Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
Ptr<IModel> createBaseModelByType(std::string type, usage use, Ptr<Options> options) {
Ptr<ExpressionGraph> graph = nullptr; // graph unknown at this stage
// clang-format off
if(type == "s2s" || type == "amun" || type == "nematus") {
return models::encoder_decoder()(options)
("usage", use)
("original-type", type)
.push_back(models::encoder()("type", "s2s"))
.push_back(models::decoder()("type", "s2s"))
.construct(graph);
.push_back(models::encoder()("type", "s2s"))
.push_back(models::decoder()("type", "s2s"))
.construct(graph);
}
if(type == "transformer") {
else if(type == "transformer") {
return models::encoder_decoder()(options)
("usage", use)
.push_back(models::encoder()("type", "transformer"))
@ -119,16 +119,16 @@ Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
.construct(graph);
}
if(type == "transformer_s2s") {
else if(type == "transformer_s2s") {
return models::encoder_decoder()(options)
("usage", use)
("original-type", type)
.push_back(models::encoder()("type", "transformer"))
.push_back(models::decoder()("type", "s2s"))
.construct(graph);
.push_back(models::encoder()("type", "transformer"))
.push_back(models::decoder()("type", "s2s"))
.construct(graph);
}
if(type == "lm") {
else if(type == "lm") {
auto idx = options->has("index") ? options->get<size_t>("index") : 0;
std::vector<int> dimVocabs = options->get<std::vector<int>>("dim-vocabs");
int vocab = dimVocabs[0];
@ -139,13 +139,13 @@ Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
("usage", use)
("type", "s2s")
("original-type", type)
.push_back(models::decoder()
("index", idx)
("dim-vocabs", dimVocabs))
.construct(graph);
.push_back(models::decoder()
("index", idx)
("dim-vocabs", dimVocabs))
.construct(graph);
}
if(type == "multi-s2s") {
else if(type == "multi-s2s") {
size_t numEncoders = 2;
auto ms2sFactory = models::encoder_decoder()(options)
("usage", use)
@ -162,7 +162,7 @@ Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
return ms2sFactory.construct(graph);
}
if(type == "shared-multi-s2s") {
else if(type == "shared-multi-s2s") {
size_t numEncoders = 2;
auto ms2sFactory = models::encoder_decoder()(options)
("usage", use)
@ -179,7 +179,7 @@ Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
return ms2sFactory.construct(graph);
}
if(type == "multi-transformer") {
else if(type == "multi-transformer") {
size_t numEncoders = 2;
auto mtransFactory = models::encoder_decoder()(options)
("usage", use)
@ -195,7 +195,7 @@ Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
return mtransFactory.construct(graph);
}
if(type == "shared-multi-transformer") {
else if(type == "shared-multi-transformer") {
size_t numEncoders = 2;
auto mtransFactory = models::encoder_decoder()(options)
("usage", use)
@ -211,7 +211,7 @@ Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
return mtransFactory.construct(graph);
}
if(type == "lm-transformer") {
else if(type == "lm-transformer") {
auto idx = options->has("index") ? options->get<size_t>("index") : 0;
std::vector<int> dimVocabs = options->get<std::vector<int>>("dim-vocabs");
int vocab = dimVocabs[0];
@ -222,85 +222,120 @@ Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
("usage", use)
("type", "transformer")
("original-type", type)
.push_back(models::decoder()
("index", idx)
("dim-vocabs", dimVocabs))
.construct(graph);
.push_back(models::decoder()
("index", idx)
("dim-vocabs", dimVocabs))
.construct(graph);
}
if(type == "bert") { // for full BERT training
else if(type == "bert") { // for full BERT training
return models::encoder_classifier()(options) //
("original-type", "bert") // so we can query this
("usage", use) //
.push_back(models::encoder() //
("type", "bert-encoder") // close to original transformer encoder
("index", 0)) //
("type", "bert-encoder") // close to original transformer encoder
("index", 0)) //
.push_back(models::classifier() //
("prefix", "masked-lm") // prefix for parameter names
("type", "bert-masked-lm") //
("index", 0)) // multi-task learning with MaskedLM
("prefix", "masked-lm") // prefix for parameter names
("type", "bert-masked-lm") //
("index", 0)) // multi-task learning with MaskedLM
.push_back(models::classifier() //
("prefix", "next-sentence") // prefix for parameter names
("type", "bert-classifier") //
("index", 1)) // next sentence prediction
("prefix", "next-sentence") // prefix for parameter names
("type", "bert-classifier") //
("index", 1)) // next sentence prediction
.construct(graph);
}
if(type == "bert-classifier") { // for BERT fine-tuning on non-BERT classification task
else if(type == "bert-classifier") { // for BERT fine-tuning on non-BERT classification task
return models::encoder_classifier()(options) //
("original-type", "bert-classifier") // so we can query this if needed
("usage", use) //
.push_back(models::encoder() //
("type", "bert-encoder") //
("index", 0)) // close to original transformer encoder
("type", "bert-encoder") //
("index", 0)) // close to original transformer encoder
.push_back(models::classifier() //
("type", "bert-classifier") //
("index", 1)) // next sentence prediction
("type", "bert-classifier") //
("index", 1)) // next sentence prediction
.construct(graph);
}
#ifdef COMPILE_EXAMPLES
// @TODO: examples should be compiled optionally
if(type == "mnist-ffnn") {
auto mnist = New<MnistFeedForwardNet>(options);
if(use == usage::scoring)
return New<Scorer>(mnist, New<MNISTLogsoftmax>());
else if(use == usage::training)
return New<Trainer>(mnist, New<MNISTCrossEntropyCost>());
else
return mnist;
}
else if(type == "mnist-ffnn")
return New<MnistFeedForwardNet>(options);
#endif
#ifdef CUDNN
#ifdef COMPILE_EXAMPLES
if(type == "mnist-lenet") {
auto mnist = New<MnistLeNet>(options);
if(use == usage::scoring)
return New<Scorer>(mnist, New<MNISTLogsoftmax>());
else if(use == usage::training)
return New<Trainer>(mnist, New<MNISTCrossEntropyCost>());
else
return mnist;
}
else if(type == "mnist-lenet")
return New<MnistLeNet>(options);
#endif
if(type == "char-s2s") {
else if(type == "char-s2s") {
return models::encoder_decoder()(options)
("usage", use)
("original-type", type)
.push_back(models::encoder()("type", "char-s2s"))
.push_back(models::decoder()("type", "s2s"))
.construct(graph);
.push_back(models::encoder()("type", "char-s2s"))
.push_back(models::decoder()("type", "s2s"))
.construct(graph);
}
#endif
// clang-format on
ABORT("Unknown model type: {}", type);
else
ABORT("Unknown model type: {}", type);
}
Ptr<ModelBase> from_options(Ptr<Options> options, usage use) {
Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
std::string type = options->get<std::string>("type");
return by_type(type, use, options);
auto baseModel = createBaseModelByType(type, use, options);
// add (log)softmax if requested
if (use == usage::translation) {
if(std::dynamic_pointer_cast<EncoderDecoder>(baseModel)) {
if(options->get<bool>("output-sampling", false))
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
else
return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
}
#ifdef COMPILE_EXAMPLES
// note: 'usage::translation' here means 'inference'
else if (std::dynamic_pointer_cast<MnistFeedForwardNet>(baseModel))
return New<Scorer>(baseModel, New<MNISTLogsoftmax>());
#ifdef CUDNN
else if (std::dynamic_pointer_cast<MnistLeNet>(baseModel))
return New<Scorer>(baseModel, New<MNISTLogsoftmax>());
#endif
#endif
else
ABORT("'usage' parameter 'translation' cannot be applied to model type: {}", type);
}
else if (use == usage::raw)
return baseModel;
else
ABORT("'Usage' parameter must be 'translation' or 'raw'");
}
Ptr<ICriterionFunction> createCriterionFunctionFromOptions(Ptr<Options> options, usage use) {
std::string type = options->get<std::string>("type");
auto baseModel = createBaseModelByType(type, use, options);
// add cost function
ABORT_IF(use != usage::training && use != usage::scoring, "'Usage' parameter must be 'training' or 'scoring'");
// note: usage::scoring means "score the loss function", hence it uses a Trainer (not Scorer, which is for decoding)
// @TODO: Should we define a new class that does not compute gradients?
if (std::dynamic_pointer_cast<EncoderDecoder>(baseModel))
return New<Trainer>(baseModel, New<EncoderDecoderCECost>(options));
else if (std::dynamic_pointer_cast<EncoderClassifier>(baseModel))
return New<Trainer>(baseModel, New<EncoderClassifierCECost>(options));
#ifdef COMPILE_EXAMPLES
// @TODO: examples should be compiled optionally
else if (std::dynamic_pointer_cast<MnistFeedForwardNet>(baseModel))
return New<Trainer>(baseModel, New<MNISTCrossEntropyCost>());
#ifdef CUDNN
else if (std::dynamic_pointer_cast<MnistLeNet>(baseModel))
return New<Trainer>(baseModel, New<MNISTCrossEntropyCost>());
#endif
#endif
else
ABORT("Criterion function unknown for model type: {}", type);
}
} // namespace models

View File

@ -56,7 +56,7 @@ public:
return Accumulator<EncoderDecoderFactory>(*this);
}
virtual Ptr<ModelBase> construct(Ptr<ExpressionGraph> graph);
virtual Ptr<IModel> construct(Ptr<ExpressionGraph> graph);
};
typedef Accumulator<EncoderDecoderFactory> encoder_decoder;
@ -80,13 +80,15 @@ public:
return Accumulator<EncoderClassifierFactory>(*this);
}
virtual Ptr<ModelBase> construct(Ptr<ExpressionGraph> graph);
virtual Ptr<IModel> construct(Ptr<ExpressionGraph> graph);
};
typedef Accumulator<EncoderClassifierFactory> encoder_classifier;
Ptr<ModelBase> by_type(std::string type, usage, Ptr<Options> options);
Ptr<IModel> createBaseModelByType(std::string type, usage, Ptr<Options> options);
Ptr<ModelBase> from_options(Ptr<Options> options, usage);
Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage);
Ptr<ICriterionFunction> createCriterionFunctionFromOptions(Ptr<Options> options, usage);
} // namespace models
} // namespace marian

6
src/rescorer/rescorer.h Normal file → Executable file
View File

@ -19,11 +19,11 @@ using namespace data;
class Rescorer {
private:
Ptr<models::ModelBase> builder_;
Ptr<models::ICriterionFunction> builder_;
public:
Rescorer(Ptr<Options> options)
: builder_(models::from_options(options, models::usage::scoring)) {}
: builder_(models::createCriterionFunctionFromOptions(options, models::usage::scoring)) {}
void load(Ptr<ExpressionGraph> graph, const std::string& modelFile) {
builder_->load(graph, modelFile);
@ -34,7 +34,7 @@ public:
}
data::SoftAlignment getAlignment() {
auto model = std::static_pointer_cast<models::Scorer>(builder_)->getModel();
auto model = std::static_pointer_cast<models::Trainer>(builder_)->getModel();
return std::static_pointer_cast<EncoderDecoderBase>(model)->getAlignment();
}
};

7
src/training/communicator.cpp Normal file → Executable file
View File

@ -79,13 +79,12 @@ public:
HANDLE_MPI_ERROR(MPI_Init_thread(&argc, &argvp, MPI_THREAD_MULTIPLE, &providedThreadingMode));
MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN); // have errors reported as return codes
ABORT_IF(
providedThreadingMode < requiredThreadingMode,
"Your version of MPI does not support multi-threaded communication.");
MPI_Comm_size(MPI_COMM_WORLD, &comm_world_size_);
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank_);
ABORT_IF(comm_world_size_ > 1 && providedThreadingMode < requiredThreadingMode,
"Your version of MPI does not support multi-threaded communication.");
// patch logging pattern to include the MPI rank, so that we can associate error messages with nodes
if (numMPIProcesses() > 1) {
std::string rankStr = std::to_string(MPIWrapper::myMPIRank());

6
src/training/graph_group.h Normal file → Executable file
View File

@ -55,7 +55,7 @@ public:
*/
// @TODO: Can this be made const? It seems wrong to have a stateful method that still returns a result.
virtual Ptr<data::BatchStats> collectStats(Ptr<ExpressionGraph> graph,
Ptr<models::ModelBase> model,
Ptr<models::ICriterionFunction> model,
const std::vector<Ptr<Vocab>>& vocabs,
double multiplier = 1.) {
auto stats = New<data::BatchStats>();
@ -141,7 +141,7 @@ protected:
std::vector<size_t> devices_; // [num local GPUs]
/** Graph builders for clients (which run forward and backward passes). */
std::vector<Ptr<models::ModelBase>> clientBuilders_;
std::vector<Ptr<models::ICriterionFunction>> clientBuilders_;
/** Graphs of clients. One entry per GPU on this node. */
std::vector<Ptr<ExpressionGraph>> clientGraphs_; // [num local GPUs]
@ -161,7 +161,7 @@ public:
clientGraphs_.push_back(New<ExpressionGraph>());
clientGraphs_[i]->setDevice({ devices_[i], DeviceType::gpu });
clientGraphs_[i]->reserveWorkspaceMB(options_->get<size_t>("workspace"));
clientBuilders_.push_back(models::from_options(options_, models::usage::training));
clientBuilders_.push_back(models::createCriterionFunctionFromOptions(options_, models::usage::training));
}
}

4
src/training/graph_group_async.cpp Normal file → Executable file
View File

@ -23,7 +23,7 @@ AsyncGraphGroup::AsyncGraphGroup(Ptr<Options> config, Ptr<IMPIWrapper> mpi)
graphs_.push_back(graph);
shardOpt_.push_back(Optimizer(options_));
builders_.push_back(models::from_options(options_, models::usage::training));
builders_.push_back(models::createCriterionFunctionFromOptions(options_, models::usage::training));
}
}
@ -189,7 +189,7 @@ void AsyncGraphGroup::execute(Ptr<data::Batch> batch) {
auto task = [this](Ptr<data::Batch> batch) {
static size_t i = 0;
thread_local Ptr<ExpressionGraph> graph;
thread_local Ptr<models::ModelBase> builder;
thread_local Ptr<models::ICriterionFunction> builder;
thread_local size_t t = 0;
thread_local size_t num_seen_words = 0;
thread_local size_t num_seen_sentences = 0;

2
src/training/graph_group_async.h Normal file → Executable file
View File

@ -16,7 +16,7 @@ public:
protected:
bool first_{true};
std::vector<Ptr<models::ModelBase>> builders_;
std::vector<Ptr<models::ICriterionFunction>> builders_;
std::vector<Ptr<ExpressionGraph>> graphs_;
std::vector<DeviceId> devices_;

2
src/training/graph_group_multinode.cpp Normal file → Executable file
View File

@ -512,7 +512,7 @@ void MultiNodeGraphGroup::execute(Ptr<data::Batch> batch) {
auto task = [this](Ptr<data::Batch> batch) {
static size_t i = 0;
thread_local Ptr<ExpressionGraph> graph;
thread_local Ptr<models::ModelBase> builder;
thread_local Ptr<models::ICriterionFunction> builder;
thread_local size_t my_id = 0;
thread_local size_t t = 0;
// only for scheduler statistic

4
src/training/graph_group_singleton.h Normal file → Executable file
View File

@ -16,7 +16,7 @@ public:
virtual void setScheduler(Ptr<Scheduler> scheduler) override;
private:
Ptr<models::ModelBase> builder_;
Ptr<models::ICriterionFunction> builder_;
Ptr<ExpressionGraph> graph_;
Ptr<ExpressionGraph> graphAvg_;
@ -37,7 +37,7 @@ public:
graph_->getBackend()->setClip(options_->get<float>("clip-gemm"));
graph_->reserveWorkspaceMB(options_->get<size_t>("workspace"));
opt_ = Optimizer(options_);
builder_ = models::from_options(options_, models::usage::training);
builder_ = models::createCriterionFunctionFromOptions(options_, models::usage::training);
}
void update(Ptr<data::Batch> batch) override {

2
src/training/graph_group_sync.cpp Normal file → Executable file
View File

@ -15,7 +15,7 @@ SyncGraphGroup::SyncGraphGroup(Ptr<Options> config, Ptr<IMPIWrapper> mpi)
graphs_.push_back(graph);
shardOpt_.push_back(Optimizer(options_));
builders_.push_back(models::from_options(options_, models::usage::training));
builders_.push_back(models::createCriterionFunctionFromOptions(options_, models::usage::training));
}
// Note: We may well end up with only one MPI process or only one graph per worker.

7
src/training/graph_group_sync.h Normal file → Executable file
View File

@ -13,12 +13,11 @@ class SyncGraphGroup : public GraphGroup, public ExponentialSmoothing {
Ptr<ICommunicator> comm_; // [not null] communicator, e.g. NCCLCommunicator
Ptr<IMPIWrapper> mpi_; // [not null] all MPI-like communication goes through this (this is a dummy implementation if no MPI run)
std::vector<DeviceId> devices_; // [deviceIndex]
std::vector<Ptr<models::ModelBase>> builders_; // [deviceIndex]
std::vector<Ptr<ExpressionGraph>> graphs_; // [deviceIndex]
std::vector<DeviceId> devices_; // [deviceIndex]
std::vector<Ptr<models::ICriterionFunction>> builders_; // [deviceIndex]
std::vector<Ptr<ExpressionGraph>> graphs_; // [deviceIndex]
std::vector<Ptr<OptimizerBase>> shardOpt_; // [deviceIndex]
std::vector<Tensor> paramsAvg_; // [deviceIndex] exponentially smoothed parameters, sharded
// @TODO: instead, create an array of ExponentialSmoothing objects, and don't use ExponentialSmoothing as a base class
std::vector<Ptr<TensorAllocator>> paramsAllocs_; // [deviceIndex] we must hold a reference to the memory until this class dies

4
src/training/validator.cpp Normal file → Executable file
View File

@ -2,10 +2,10 @@
namespace marian {
std::vector<Ptr<Validator<data::Corpus>>> Validators(
std::vector<Ptr<ValidatorBase/*<data::Corpus>*/>> Validators(
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> config) {
std::vector<Ptr<Validator<data::Corpus>>> validators;
std::vector<Ptr<ValidatorBase/*<data::Corpus>*/>> validators;
auto validMetrics = config->get<std::vector<std::string>>("valid-metrics");

View File

@ -56,7 +56,7 @@ public:
}
};
template <class DataSet>
template <class DataSet, class BuilderType>
class Validator : public ValidatorBase {
public:
Validator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options, bool lowerIsBetter = true)
@ -123,7 +123,7 @@ public:
protected:
std::vector<Ptr<Vocab>> vocabs_;
Ptr<Options> options_;
Ptr<models::ModelBase> builder_;
Ptr<BuilderType> builder_;
Ptr<data::BatchGenerator<DataSet>> batchGenerator_;
virtual float validateBG(const std::vector<Ptr<ExpressionGraph>>&)
@ -148,7 +148,7 @@ protected:
}
};
class CrossEntropyValidator : public Validator<data::Corpus> {
class CrossEntropyValidator : public Validator<data::Corpus, models::ICriterionFunction> {
public:
CrossEntropyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options) {
@ -159,7 +159,7 @@ public:
opts->merge(options);
opts->set("inference", true);
opts->set("cost-type", "ce-sum");
builder_ = models::from_options(opts, models::usage::scoring);
builder_ = models::createCriterionFunctionFromOptions(opts, models::usage::scoring);
}
std::string type() override { return options_->get<std::string>("cost-type"); }
@ -180,7 +180,7 @@ protected:
for(auto batch : *batchGenerator_) {
auto task = [=, &loss, &samples](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local auto builder = models::from_options(options_, models::usage::scoring);
thread_local auto builder = models::createCriterionFunctionFromOptions(options_, models::usage::scoring);
if(!graph) {
graph = graphs[id % graphs.size()];
@ -215,8 +215,8 @@ protected:
}
};
// Used for validating with classifiers. Compute prediction accuary versus groundtruth for a set of classes
class AccuracyValidator : public Validator<data::Corpus> {
// Used for validating with classifiers. Compute prediction accuracy versus ground truth for a set of classes
class AccuracyValidator : public Validator<data::Corpus, models::IModel> {
public:
AccuracyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options, /*lowerIsBetter=*/false) {
@ -226,7 +226,7 @@ public:
Ptr<Options> opts = New<Options>();
opts->merge(options);
opts->set("inference", true);
builder_ = models::from_options(opts, models::usage::raw);
builder_ = models::createModelFromOptions(opts, models::usage::raw);
}
std::string type() override { return "accuracy"; }
@ -245,7 +245,7 @@ protected:
for(auto batch : *batchGenerator_) {
auto task = [=, &correct, &totalLabels](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local auto builder = models::from_options(options_, models::usage::raw);
thread_local auto builder = models::createModelFromOptions(options_, models::usage::raw);
if(!graph) {
graph = graphs[id % graphs.size()];
@ -263,7 +263,7 @@ protected:
// correct += correct->scalar<IndexType>();
builder->clear(graph);
Expr logits = builder->build(graph, batch)->loss();
Expr logits = builder->build(graph, batch);
graph->forward();
std::vector<float> vLogits;
@ -305,7 +305,7 @@ protected:
}
};
class BertAccuracyValidator : public Validator<data::Corpus> {
class BertAccuracyValidator : public Validator<data::Corpus, models::IModel> {
private:
bool evalMaskedLM_{true};
@ -319,7 +319,7 @@ public:
Ptr<Options> opts = New<Options>();
opts->merge(options);
opts->set("inference", true);
builder_ = models::from_options(opts, models::usage::raw);
builder_ = models::createModelFromOptions(opts, models::usage::raw);
}
std::string type() override {
@ -343,7 +343,7 @@ protected:
for(auto batch : *batchGenerator_) {
auto task = [=, &correct, &totalLabels](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local auto builder = models::from_options(options_, models::usage::raw);
thread_local auto builder = models::createModelFromOptions(options_, models::usage::raw);
thread_local std::unique_ptr<std::mt19937> engine;
if(!graph) {
@ -415,11 +415,11 @@ protected:
};
class ScriptValidator : public Validator<data::Corpus> {
class ScriptValidator : public Validator<data::Corpus, models::IModel> {
public:
ScriptValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options, false) {
builder_ = models::from_options(options_, models::usage::raw);
builder_ = models::createModelFromOptions(options_, models::usage::raw);
ABORT_IF(!options_->hasAndNotEmpty("valid-script-path"),
"valid-script metric but no script given");
@ -446,12 +446,12 @@ protected:
}
};
class TranslationValidator : public Validator<data::Corpus> {
class TranslationValidator : public Validator<data::Corpus, models::IModel> {
public:
TranslationValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options, false),
quiet_(options_->get<bool>("quiet-translation")) {
builder_ = models::from_options(options_, models::usage::translation);
builder_ = models::createModelFromOptions(options_, models::usage::translation);
if(!options_->hasAndNotEmpty("valid-script-path"))
LOG_VALID(warn, "No post-processing script given for validating translator");
@ -475,7 +475,7 @@ public:
std::vector<Ptr<Scorer>> scorers;
for(auto graph : graphs) {
auto builder = models::from_options(options_, models::usage::translation);
auto builder = models::createModelFromOptions(options_, models::usage::translation);
Ptr<Scorer> scorer = New<ScorerWrapper>(builder, "", 1.0f, model);
scorers.push_back(scorer);
}
@ -578,7 +578,7 @@ protected:
};
// @TODO: combine with TranslationValidator (above) to avoid code duplication
class BleuValidator : public Validator<data::Corpus> {
class BleuValidator : public Validator<data::Corpus, models::IModel> {
private:
bool detok_{false};
@ -587,7 +587,7 @@ public:
: Validator(vocabs, options, false),
detok_(detok),
quiet_(options_->get<bool>("quiet-translation")) {
builder_ = models::from_options(options_, models::usage::translation);
builder_ = models::createModelFromOptions(options_, models::usage::translation);
#ifdef USE_SENTENCEPIECE
auto vocab = vocabs_.back();
@ -619,7 +619,7 @@ public:
std::vector<Ptr<Scorer>> scorers;
for(auto graph : graphs) {
auto builder = models::from_options(options_, models::usage::translation);
auto builder = models::createModelFromOptions(options_, models::usage::translation);
Ptr<Scorer> scorer = New<ScorerWrapper>(builder, "", 1.0f, model);
scorers.push_back(scorer);
}
@ -844,7 +844,7 @@ protected:
*
* @return Vector of validator objects
*/
std::vector<Ptr<Validator<data::Corpus>>> Validators(
std::vector<Ptr<ValidatorBase/*<data::Corpus>*/>> Validators(
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> config);
} // namespace marian

4
src/translator/scorers.cpp Normal file → Executable file
View File

@ -17,7 +17,7 @@ Ptr<Scorer> scorerByType(const std::string& fname,
}
bool skipCost = options->get<bool>("skip-cost");
auto encdec = models::from_options(
auto encdec = models::createModelFromOptions(
options, skipCost ? models::usage::raw : models::usage::translation);
LOG(info, "Loading scorer of type {} as feature {}", type, fname);
@ -39,7 +39,7 @@ Ptr<Scorer> scorerByType(const std::string& fname,
}
bool skipCost = options->get<bool>("skip-cost");
auto encdec = models::from_options(
auto encdec = models::createModelFromOptions(
options, skipCost ? models::usage::raw : models::usage::translation);
LOG(info, "Loading scorer of type {} as feature {}", type, fname);

View File

@ -72,7 +72,7 @@ private:
const void* ptr_;
public:
ScorerWrapper(Ptr<models::ModelBase> encdec,
ScorerWrapper(Ptr<models::IModel> encdec,
const std::string& name,
float weight,
const std::string& fname)
@ -81,7 +81,7 @@ public:
fname_(fname),
ptr_{0} {}
ScorerWrapper(Ptr<models::ModelBase> encdec,
ScorerWrapper(Ptr<models::IModel> encdec,
const std::string& name,
float weight,
const void* ptr)

View File

@ -898,9 +898,11 @@
<ClInclude Include="..\src\models\amun.h" />
<ClInclude Include="..\src\models\bert.h" />
<ClInclude Include="..\src\models\char_s2s.h" />
<ClInclude Include="..\src\models\classifier.h" />
<ClInclude Include="..\src\models\costs.h" />
<ClInclude Include="..\src\models\decoder.h" />
<ClInclude Include="..\src\models\encoder.h" />
<ClInclude Include="..\src\models\encoder_classifier.h" />
<ClInclude Include="..\src\models\encoder_decoder.h" />
<ClInclude Include="..\src\models\model_base.h" />
<ClInclude Include="..\src\models\model_factory.h" />

View File

@ -1528,6 +1528,12 @@
<ClInclude Include="..\src\data\vocab_base.h">
<Filter>data</Filter>
</ClInclude>
<ClInclude Include="..\src\models\classifier.h">
<Filter>models</Filter>
</ClInclude>
<ClInclude Include="..\src\models\encoder_classifier.h">
<Filter>models</Filter>
</ClInclude>
<ClInclude Include="..\src\models\transformer.h">
<Filter>models</Filter>
</ClInclude>