merged from master

Frank Seide 2019-03-14 17:30:44 -07:00
commit ac04bbfa16
26 changed files with 166 additions and 172 deletions


@@ -17,11 +17,11 @@ namespace models {
// @TODO: looking at this file, simplify the new RationalLoss idea. Here it gets too complicated
-class MNISTCrossEntropyCost : public CostBase {
+class MNISTCrossEntropyCost : public ICost {
public:
MNISTCrossEntropyCost() {}
-Ptr<MultiRationalLoss> apply(Ptr<ModelBase> model,
+Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
@@ -43,11 +43,11 @@ public:
}
};
-class MNISTLogsoftmax : public LogProbBase {
+class MNISTLogsoftmax : public ILogProb {
public:
MNISTLogsoftmax() {}
-Logits apply(Ptr<ModelBase> model,
+Logits apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
@@ -56,7 +56,7 @@ public:
}
};
-class MnistFeedForwardNet : public ModelBase {
+class MnistFeedForwardNet : public IModel {
public:
typedef data::MNISTData dataset_type;


@@ -15,9 +15,9 @@ public:
virtual void clear(Ptr<ExpressionGraph> graph) override { graph->clear(); };
protected:
-virtual Expr construct(Ptr<ExpressionGraph> g,
-Ptr<data::Batch> batch,
-bool inference = false) override {
+virtual Expr apply(Ptr<ExpressionGraph> g,
+Ptr<data::Batch> batch,
+bool inference = false) override {
const std::vector<int> dims = {784, 128, 10};
// Start with an empty expression graph


@@ -12,7 +12,7 @@ using namespace marian;
namespace marian {
-class MNISTAccuracyValidator : public Validator<data::MNISTData, models::ModelBase> {
+class MNISTAccuracyValidator : public Validator<data::MNISTData, models::IModel> {
public:
MNISTAccuracyValidator(Ptr<Options> options) : Validator(std::vector<Ptr<Vocab>>(), options, false) {
createBatchGenerator(/*isTranslating=*/false);


@@ -19,15 +19,15 @@ namespace models {
// Other functions return RationalLoss directly without Ptr<...>, but also
// they do not need polymorphism here.
-class CostBase {
+class ICost {
public:
-virtual Ptr<MultiRationalLoss> apply(Ptr<ModelBase> model,
-Ptr<ExpressionGraph> graph,
+virtual Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
+Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model?
Ptr<data::Batch> batch,
bool clearGraph = true) = 0;
};
-class EncoderDecoderCE : public CostBase {
+class EncoderDecoderCECost : public ICost {
protected:
Ptr<Options> options_;
@@ -39,7 +39,7 @@ protected:
Ptr<WeightingBase> weighter_;
public:
-EncoderDecoderCE(Ptr<Options> options)
+EncoderDecoderCECost(Ptr<Options> options)
: options_(options), inference_(options->get<bool>("inference", false)) {
loss_ = newLoss(options_, inference_);
@@ -51,7 +51,7 @@ public:
weighter_ = WeightingFactory(options_);
}
-Ptr<MultiRationalLoss> apply(Ptr<ModelBase> model,
+Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
@@ -89,7 +89,7 @@ public:
};
// Wraps an EncoderClassifier so it can produce a cost from raw logits. @TODO: Needs refactoring
-class EncoderClassifierCE : public CostBase {
+class EncoderClassifierCECost : public ICost {
protected:
Ptr<Options> options_;
bool inference_{false};
@@ -99,12 +99,12 @@ protected:
Ptr<LabelwiseLoss> loss_;
public:
-EncoderClassifierCE(Ptr<Options> options)
+EncoderClassifierCECost(Ptr<Options> options)
: options_(options), inference_(options->get<bool>("inference", false)) {
loss_ = newLoss(options_, inference_);
}
-Ptr<MultiRationalLoss> apply(Ptr<ModelBase> model,
+Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
@@ -127,16 +127,16 @@ public:
}
};
-class Trainer : public CriterionBase {
+class Trainer : public ICriterionFunction {
protected:
-Ptr<ModelBase> model_;
-Ptr<CostBase> cost_;
+Ptr<IModel> model_;
+Ptr<ICost> cost_;
public:
-Trainer(Ptr<ModelBase> model, Ptr<CostBase> cost)
+Trainer(Ptr<IModel> model, Ptr<ICost> cost)
: model_(model), cost_(cost) {}
-Ptr<ModelBase> getModel() { return model_; }
+Ptr<IModel> getModel() { return model_; }
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
@@ -159,9 +159,9 @@ public:
virtual void clear(Ptr<ExpressionGraph> graph) override { model_->clear(graph); };
};
-class LogProbBase {
+class ILogProb {
public:
-virtual Logits apply(Ptr<ModelBase> model,
+virtual Logits apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) = 0;
@@ -170,16 +170,16 @@ public:
// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for the ground truth?
// Beam search uses it in the former meaning, while 'marian score' and validation use it in the latter.
// This class is for the former use. The latter is done using Trainer.
-class Scorer : public ModelBase {
+class Scorer : public IModel {
protected:
-Ptr<ModelBase> model_;
-Ptr<LogProbBase> logProb_;
+Ptr<IModel> model_;
+Ptr<ILogProb> logProb_;
public:
-Scorer(Ptr<ModelBase> model, Ptr<LogProbBase> cost)
+Scorer(Ptr<IModel> model, Ptr<ILogProb> cost)
: model_(model), logProb_(cost) {}
-Ptr<ModelBase> getModel() { return model_; }
+Ptr<IModel> getModel() { return model_; }
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
@@ -202,14 +202,14 @@ public:
virtual void clear(Ptr<ExpressionGraph> graph) override { model_->clear(graph); };
};
-class CostStep {
+class ILogProbStep {
public:
// @BUGBUG: This is not a function application. Rather, it updates 'state' in-place.
// Suggest to call it updateState, and not return the state object.
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) = 0;
};
-class LogSoftmaxStep : public CostStep {
+class LogSoftmaxStep : public ILogProbStep {
public:
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
@@ -221,7 +221,7 @@ public:
// Gumbel-max noising for sampling during beam-search
// Seems to work well enough with beam-size=1. Turn on
// with --output-sampling during translation with marian-decoder
-class GumbelSoftmaxStep : public CostStep {
+class GumbelSoftmaxStep : public ILogProbStep {
public:
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
// @TODO: @HACK must know about individual parts; make it a loop
@@ -234,16 +234,16 @@ public:
}
};
-// class to wrap an EncoderDecoderBase and a CostStep that are executed in sequence,
+// class to wrap an EncoderDecoderBase and a ILogProbStep that are executed in sequence,
// wrapped again in the EncoderDecoderBase interface
// @TODO: seems we are conflating an interface definition with its implementation?
class Stepwise : public EncoderDecoderBase {
protected:
Ptr<EncoderDecoderBase> encdec_;
-Ptr<CostStep> cost_;
+Ptr<ILogProbStep> cost_;
public:
-Stepwise(Ptr<EncoderDecoderBase> encdec, Ptr<CostStep> cost)
+Stepwise(Ptr<EncoderDecoderBase> encdec, Ptr<ILogProbStep> cost)
: encdec_(encdec), cost_(cost) {}
virtual void load(Ptr<ExpressionGraph> graph,
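For context: GumbelSoftmaxStep above relies on the Gumbel-max trick, which is how --output-sampling samples during translation: adding independent Gumbel(0,1) noise to log-probabilities and taking the argmax draws an exact sample from the softmax distribution. A minimal standalone sketch of the trick (illustrative only; not code from this commit):

    #include <cmath>
    #include <cstddef>
    #include <limits>
    #include <random>
    #include <vector>

    // Sample an index distributed according to softmax(logProbs)
    // via the Gumbel-max trick.
    size_t gumbelMaxSample(const std::vector<float>& logProbs, std::mt19937& rng) {
      std::uniform_real_distribution<float> uniform(1e-8f, 1.0f);
      size_t best = 0;
      float bestScore = -std::numeric_limits<float>::infinity();
      for (size_t i = 0; i < logProbs.size(); ++i) {
        float g = -std::log(-std::log(uniform(rng))); // Gumbel(0,1) noise
        if (logProbs[i] + g > bestScore) { bestScore = logProbs[i] + g; best = i; }
      }
      return best;
    }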


@@ -17,7 +17,7 @@ namespace marian {
* @TODO: this should probably be unified somehow with EncoderDecoder which could allow for decoder/classifier
* multi-objective training.
*/
-class EncoderClassifierBase : public models::ModelBase {
+class EncoderClassifierBase : public models::IModel {
public:
virtual ~EncoderClassifierBase() {}


@@ -202,7 +202,7 @@ Logits EncoderDecoder::build(Ptr<ExpressionGraph> graph,
auto state = stepAll(graph, batch, clearGraph);
// returns raw logits
-return state->getLogProbs(); // , state->getTargetMask()); // @TODO: hacky hack hack
+return state->getLogProbs();
}
Logits EncoderDecoder::build(Ptr<ExpressionGraph> graph,


@@ -9,7 +9,7 @@
namespace marian {
-class EncoderDecoderBase : public models::ModelBase {
+class EncoderDecoderBase : public models::IModel {
public:
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,


@@ -18,7 +18,7 @@ namespace marian {
namespace models {
// model = input -> predictions
-class ModelBase {
+class IModel {
public:
virtual void load(Ptr<ExpressionGraph>,
const std::string&,
@@ -38,7 +38,8 @@ public:
};
// criterion = (input, reference) -> loss
-class CriterionBase {
+// @TODO: Is there a better name?
+class ICriterionFunction {
public:
virtual void load(Ptr<ExpressionGraph>,
const std::string&,
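The two comments above ("model = input -> predictions", "criterion = (input, reference) -> loss") capture the split that this commit's renames make explicit. A simplified, hypothetical sketch of the two interfaces (the real classes in model_base.h also declare load/save/clear, and the exact signatures differ):

    // model = input -> predictions
    struct IModel {
      virtual Logits build(Ptr<ExpressionGraph> graph,
                           Ptr<data::Batch> batch,
                           bool clearGraph = true) = 0;
    };

    // criterion = (input, reference) -> loss; the reference is part of the batch
    struct ICriterionFunction {
      virtual Ptr<MultiRationalLoss> build(Ptr<ExpressionGraph> graph,
                                           Ptr<data::Batch> batch,
                                           bool clearGraph = true) = 0;
    };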


@@ -60,7 +60,7 @@ Ptr<ClassifierBase> ClassifierFactory::construct(Ptr<ExpressionGraph> /*graph*/)
ABORT("Unknown classifier type");
}
-Ptr<ModelBase> EncoderDecoderFactory::construct(Ptr<ExpressionGraph> graph) {
+Ptr<IModel> EncoderDecoderFactory::construct(Ptr<ExpressionGraph> graph) {
Ptr<EncoderDecoder> encdec;
if(options_->get<std::string>("type") == "amun")
@@ -80,7 +80,7 @@ Ptr<ModelBase> EncoderDecoderFactory::construct(Ptr<ExpressionGraph> graph) {
return encdec;
}
-Ptr<ModelBase> EncoderClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
+Ptr<IModel> EncoderClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
Ptr<EncoderClassifier> enccls;
if(options_->get<std::string>("type") == "bert") {
enccls = New<BertEncoderClassifier>(options_);
@@ -99,19 +99,19 @@ Ptr<ModelBase> EncoderClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
return enccls;
}
-Ptr<ModelBase> createBaseModelByType(std::string type, usage use, Ptr<Options> options) {
+Ptr<IModel> createBaseModelByType(std::string type, usage use, Ptr<Options> options) {
Ptr<ExpressionGraph> graph = nullptr; // graph unknown at this stage
// clang-format off
if(type == "s2s" || type == "amun" || type == "nematus") {
return models::encoder_decoder()(options)
("usage", use)
("original-type", type)
-.push_back(models::encoder()("type", "s2s"))
-.push_back(models::decoder()("type", "s2s"))
-.construct(graph);
+.push_back(models::encoder()("type", "s2s"))
+.push_back(models::decoder()("type", "s2s"))
+.construct(graph);
}
-if(type == "transformer") {
+else if(type == "transformer") {
return models::encoder_decoder()(options)
("usage", use)
.push_back(models::encoder()("type", "transformer"))
@@ -119,16 +119,16 @@ Ptr<ModelBase> createBaseModelByType(std::string type, usage use, Ptr<Options> o
.construct(graph);
}
-if(type == "transformer_s2s") {
+else if(type == "transformer_s2s") {
return models::encoder_decoder()(options)
("usage", use)
("original-type", type)
-.push_back(models::encoder()("type", "transformer"))
-.push_back(models::decoder()("type", "s2s"))
-.construct(graph);
+.push_back(models::encoder()("type", "transformer"))
+.push_back(models::decoder()("type", "s2s"))
+.construct(graph);
}
-if(type == "lm") {
+else if(type == "lm") {
auto idx = options->has("index") ? options->get<size_t>("index") : 0;
std::vector<int> dimVocabs = options->get<std::vector<int>>("dim-vocabs");
int vocab = dimVocabs[0];
@@ -139,13 +139,13 @@ Ptr<ModelBase> createBaseModelByType(std::string type, usage use, Ptr<Options> o
("usage", use)
("type", "s2s")
("original-type", type)
-.push_back(models::decoder()
-("index", idx)
-("dim-vocabs", dimVocabs))
-.construct(graph);
+.push_back(models::decoder()
+("index", idx)
+("dim-vocabs", dimVocabs))
+.construct(graph);
}
-if(type == "multi-s2s") {
+else if(type == "multi-s2s") {
size_t numEncoders = 2;
auto ms2sFactory = models::encoder_decoder()(options)
("usage", use)
@@ -162,7 +162,7 @@ Ptr<ModelBase> createBaseModelByType(std::string type, usage use, Ptr<Options> o
return ms2sFactory.construct(graph);
}
-if(type == "shared-multi-s2s") {
+else if(type == "shared-multi-s2s") {
size_t numEncoders = 2;
auto ms2sFactory = models::encoder_decoder()(options)
("usage", use)
@@ -179,7 +179,7 @@ Ptr<ModelBase> createBaseModelByType(std::string type, usage use, Ptr<Options> o
return ms2sFactory.construct(graph);
}
-if(type == "multi-transformer") {
+else if(type == "multi-transformer") {
size_t numEncoders = 2;
auto mtransFactory = models::encoder_decoder()(options)
("usage", use)
@@ -195,7 +195,7 @@ Ptr<ModelBase> createBaseModelByType(std::string type, usage use, Ptr<Options> o
return mtransFactory.construct(graph);
}
-if(type == "shared-multi-transformer") {
+else if(type == "shared-multi-transformer") {
size_t numEncoders = 2;
auto mtransFactory = models::encoder_decoder()(options)
("usage", use)
@@ -211,7 +211,7 @@ Ptr<ModelBase> createBaseModelByType(std::string type, usage use, Ptr<Options> o
return mtransFactory.construct(graph);
}
-if(type == "lm-transformer") {
+else if(type == "lm-transformer") {
auto idx = options->has("index") ? options->get<size_t>("index") : 0;
std::vector<int> dimVocabs = options->get<std::vector<int>>("dim-vocabs");
int vocab = dimVocabs[0];
@@ -222,59 +222,68 @@ Ptr<ModelBase> createBaseModelByType(std::string type, usage use, Ptr<Options> o
("usage", use)
("type", "transformer")
("original-type", type)
-.push_back(models::decoder()
-("index", idx)
-("dim-vocabs", dimVocabs))
-.construct(graph);
+.push_back(models::decoder()
+("index", idx)
+("dim-vocabs", dimVocabs))
+.construct(graph);
}
-if(type == "bert") { // for full BERT training
+else if(type == "bert") { // for full BERT training
return models::encoder_classifier()(options) //
("original-type", "bert") // so we can query this
("usage", use) //
.push_back(models::encoder() //
-("type", "bert-encoder") // close to original transformer encoder
-("index", 0)) //
+("type", "bert-encoder") // close to original transformer encoder
+("index", 0)) //
.push_back(models::classifier() //
-("prefix", "masked-lm") // prefix for parameter names
-("type", "bert-masked-lm") //
-("index", 0)) // multi-task learning with MaskedLM
+("prefix", "masked-lm") // prefix for parameter names
+("type", "bert-masked-lm") //
+("index", 0)) // multi-task learning with MaskedLM
.push_back(models::classifier() //
-("prefix", "next-sentence") // prefix for parameter names
-("type", "bert-classifier") //
-("index", 1)) // next sentence prediction
+("prefix", "next-sentence") // prefix for parameter names
+("type", "bert-classifier") //
+("index", 1)) // next sentence prediction
.construct(graph);
}
-if(type == "bert-classifier") { // for BERT fine-tuning on non-BERT classification task
+else if(type == "bert-classifier") { // for BERT fine-tuning on non-BERT classification task
return models::encoder_classifier()(options) //
("original-type", "bert-classifier") // so we can query this if needed
("usage", use) //
.push_back(models::encoder() //
-("type", "bert-encoder") //
-("index", 0)) // close to original transformer encoder
+("type", "bert-encoder") //
+("index", 0)) // close to original transformer encoder
.push_back(models::classifier() //
-("type", "bert-classifier") //
-("index", 1)) // next sentence prediction
+("type", "bert-classifier") //
+("index", 1)) // next sentence prediction
.construct(graph);
}
+#ifdef COMPILE_EXAMPLES
+else if(type == "mnist-ffnn")
+return New<MnistFeedForwardNet>(options);
+#endif
#ifdef CUDNN
-if(type == "char-s2s") {
+#ifdef COMPILE_EXAMPLES
+else if(type == "mnist-lenet")
+return New<MnistLeNet>(options);
+#endif
+else if(type == "char-s2s") {
return models::encoder_decoder()(options)
("usage", use)
("original-type", type)
-.push_back(models::encoder()("type", "char-s2s"))
-.push_back(models::decoder()("type", "s2s"))
-.construct(graph);
+.push_back(models::encoder()("type", "char-s2s"))
+.push_back(models::decoder()("type", "s2s"))
+.construct(graph);
}
#endif
// clang-format on
-ABORT("Unknown model type: {}", type);
+else
+ABORT("Unknown model type: {}", type);
}
-Ptr<ModelBase> createModelFromOptions(Ptr<Options> options, usage use) {
+Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
std::string type = options->get<std::string>("type");
auto baseModel = createBaseModelByType(type, use, options);
@@ -304,7 +313,7 @@ Ptr<ModelBase> createModelFromOptions(Ptr<Options> options, usage use) {
ABORT("'Usage' parameter must be 'translation' or 'raw'");
}
-Ptr<CriterionBase> createCriterionFromOptions(Ptr<Options> options, usage use) {
+Ptr<ICriterionFunction> createCriterionFunctionFromOptions(Ptr<Options> options, usage use) {
std::string type = options->get<std::string>("type");
auto baseModel = createBaseModelByType(type, use, options);
@@ -313,9 +322,9 @@ Ptr<CriterionBase> createCriterionFromOptions(Ptr<Options> options, usage use) {
// note: usage::scoring means "score the loss function", hence it uses a Trainer (not Scorer, which is for decoding)
// @TODO: Should we define a new class that does not compute gradients?
if (std::dynamic_pointer_cast<EncoderDecoder>(baseModel))
-return New<Trainer>(baseModel, New<EncoderDecoderCE>(options));
+return New<Trainer>(baseModel, New<EncoderDecoderCECost>(options));
else if (std::dynamic_pointer_cast<EncoderClassifier>(baseModel))
-return New<Trainer>(baseModel, New<EncoderClassifierCE>(options));
+return New<Trainer>(baseModel, New<EncoderClassifierCECost>(options));
#ifdef COMPILE_EXAMPLES
// @TODO: examples should be compiled optionally
else if (std::dynamic_pointer_cast<MnistFeedForwardNet>(baseModel))


@@ -56,7 +56,7 @@ public:
return Accumulator<EncoderDecoderFactory>(*this);
}
-virtual Ptr<ModelBase> construct(Ptr<ExpressionGraph> graph);
+virtual Ptr<IModel> construct(Ptr<ExpressionGraph> graph);
};
typedef Accumulator<EncoderDecoderFactory> encoder_decoder;
@@ -80,15 +80,15 @@ public:
return Accumulator<EncoderClassifierFactory>(*this);
}
-virtual Ptr<ModelBase> construct(Ptr<ExpressionGraph> graph);
+virtual Ptr<IModel> construct(Ptr<ExpressionGraph> graph);
};
typedef Accumulator<EncoderClassifierFactory> encoder_classifier;
-Ptr<ModelBase> createBaseModelByType(std::string type, usage, Ptr<Options> options);
+Ptr<IModel> createBaseModelByType(std::string type, usage, Ptr<Options> options);
-Ptr<ModelBase> createModelFromOptions(Ptr<Options> options, usage);
+Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage);
-Ptr<CriterionBase> createCriterionFromOptions(Ptr<Options> options, usage);
+Ptr<ICriterionFunction> createCriterionFunctionFromOptions(Ptr<Options> options, usage);
} // namespace models
} // namespace marian


@@ -19,11 +19,11 @@ using namespace data;
class Rescorer {
private:
-Ptr<models::CriterionBase> builder_;
+Ptr<models::ICriterionFunction> builder_;
public:
Rescorer(Ptr<Options> options)
-: builder_(models::createCriterionFromOptions(options, models::usage::scoring)) {}
+: builder_(models::createCriterionFunctionFromOptions(options, models::usage::scoring)) {}
void load(Ptr<ExpressionGraph> graph, const std::string& modelFile) {
builder_->load(graph, modelFile);


@@ -15,7 +15,7 @@ namespace marian {
namespace cpu {
void IsNan(const Tensor in, Ptr<Allocator> allocator, bool& isNan, bool& isInf, bool zero) {
-zero; isInf; isNan;
+isNan; isInf; zero;
ABORT("Not implemented");
}

src/training/communicator.cpp (file mode changed: Normal file → Executable file)

@@ -79,13 +79,12 @@ public:
HANDLE_MPI_ERROR(MPI_Init_thread(&argc, &argvp, MPI_THREAD_MULTIPLE, &providedThreadingMode));
MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN); // have errors reported as return codes
-ABORT_IF(
-providedThreadingMode < requiredThreadingMode,
-"Your version of MPI does not support multi-threaded communication.");
MPI_Comm_size(MPI_COMM_WORLD, &comm_world_size_);
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank_);
+ABORT_IF(comm_world_size_ > 1 && providedThreadingMode < requiredThreadingMode,
+"Your version of MPI does not support multi-threaded communication.");
// patch logging pattern to include the MPI rank, so that we can associate error messages with nodes
if (numMPIProcesses() > 1) {
std::string rankStr = std::to_string(MPIWrapper::myMPIRank());


@@ -55,7 +55,7 @@ public:
*/
// @TODO: Can this be made const? It seems wrong to have a stateful method that still returns a result.
virtual Ptr<data::BatchStats> collectStats(Ptr<ExpressionGraph> graph,
-Ptr<models::CriterionBase> model,
+Ptr<models::ICriterionFunction> model,
const std::vector<Ptr<Vocab>>& vocabs,
double multiplier = 1.) {
auto stats = New<data::BatchStats>();
@@ -140,7 +140,7 @@ protected:
std::vector<size_t> devices_; // [num local GPUs]
/** Graph builders for clients (which run forward and backward passes). */
-std::vector<Ptr<models::CriterionBase>> clientBuilders_;
+std::vector<Ptr<models::ICriterionFunction>> clientBuilders_;
/** Graphs of clients. One entry per GPU on this node. */
std::vector<Ptr<ExpressionGraph>> clientGraphs_; // [num local GPUs]
@@ -160,7 +160,7 @@ public:
clientGraphs_.push_back(New<ExpressionGraph>());
clientGraphs_[i]->setDevice({ devices_[i], DeviceType::gpu });
clientGraphs_[i]->reserveWorkspaceMB(options_->get<size_t>("workspace"));
-clientBuilders_.push_back(models::createCriterionFromOptions(options_, models::usage::training));
+clientBuilders_.push_back(models::createCriterionFunctionFromOptions(options_, models::usage::training));
}
}


@@ -23,7 +23,7 @@ AsyncGraphGroup::AsyncGraphGroup(Ptr<Options> config, Ptr<IMPIWrapper> mpi)
graphs_.push_back(graph);
shardOpt_.push_back(Optimizer(options_));
-builders_.push_back(models::createCriterionFromOptions(options_, models::usage::training));
+builders_.push_back(models::createCriterionFunctionFromOptions(options_, models::usage::training));
}
}
@@ -189,7 +189,7 @@ void AsyncGraphGroup::execute(Ptr<data::Batch> batch) {
auto task = [this](Ptr<data::Batch> batch) {
static size_t i = 0;
thread_local Ptr<ExpressionGraph> graph;
-thread_local Ptr<models::CriterionBase> builder;
+thread_local Ptr<models::ICriterionFunction> builder;
thread_local size_t t = 0;
thread_local size_t num_seen_words = 0;
thread_local size_t num_seen_sentences = 0;


@@ -16,7 +16,7 @@ public:
protected:
bool first_{true};
-std::vector<Ptr<models::CriterionBase>> builders_;
+std::vector<Ptr<models::ICriterionFunction>> builders_;
std::vector<Ptr<ExpressionGraph>> graphs_;
std::vector<DeviceId> devices_;


@@ -512,7 +512,7 @@ void MultiNodeGraphGroup::execute(Ptr<data::Batch> batch) {
auto task = [this](Ptr<data::Batch> batch) {
static size_t i = 0;
thread_local Ptr<ExpressionGraph> graph;
-thread_local Ptr<models::CriterionBase> builder;
+thread_local Ptr<models::ICriterionFunction> builder;
thread_local size_t my_id = 0;
thread_local size_t t = 0;
// only for scheduler statistic


@@ -16,7 +16,7 @@ public:
virtual void setScheduler(Ptr<Scheduler> scheduler) override;
private:
-Ptr<models::CriterionBase> builder_;
+Ptr<models::ICriterionFunction> builder_;
Ptr<ExpressionGraph> graph_;
Ptr<ExpressionGraph> graphAvg_;
@@ -37,7 +37,7 @@ public:
graph_->getBackend()->setClip(options_->get<float>("clip-gemm"));
graph_->reserveWorkspaceMB(options_->get<size_t>("workspace"));
opt_ = Optimizer(options_);
-builder_ = models::createCriterionFromOptions(options_, models::usage::training);
+builder_ = models::createCriterionFunctionFromOptions(options_, models::usage::training);
}
void update(Ptr<data::Batch> batch) override {


@@ -15,7 +15,7 @@ SyncGraphGroup::SyncGraphGroup(Ptr<Options> config, Ptr<IMPIWrapper> mpi)
graphs_.push_back(graph);
shardOpt_.push_back(Optimizer(options_));
-builders_.push_back(models::createCriterionFromOptions(options_, models::usage::training));
+builders_.push_back(models::createCriterionFunctionFromOptions(options_, models::usage::training));
}
// Note: We may well end up with only one MPI process or only one graph per worker.


@@ -13,12 +13,11 @@ class SyncGraphGroup : public GraphGroup, public ExponentialSmoothing {
Ptr<ICommunicator> comm_; // [not null] communicator, e.g. NCCLCommunicator
Ptr<IMPIWrapper> mpi_; // [not null] all MPI-like communication goes through this (this is a dummy implementation if no MPI run)
-std::vector<DeviceId> devices_; // [deviceIndex]
-std::vector<Ptr<models::CriterionBase>> builders_; // [deviceIndex]
-std::vector<Ptr<ExpressionGraph>> graphs_; // [deviceIndex]
+std::vector<DeviceId> devices_; // [deviceIndex]
+std::vector<Ptr<models::ICriterionFunction>> builders_; // [deviceIndex]
+std::vector<Ptr<ExpressionGraph>> graphs_; // [deviceIndex]
std::vector<Ptr<OptimizerBase>> shardOpt_; // [deviceIndex]
std::vector<Tensor> paramsAvg_; // [deviceIndex] exponentially smoothed parameters, sharded
// @TODO: instead, create an array of ExponentialSmoothing objects, and don't use ExponentialSmoothing as a base class
std::vector<Ptr<TensorAllocator>> paramsAllocs_; // [deviceIndex] we must hold a reference to the memory until this class dies


@@ -148,7 +148,7 @@ protected:
}
};
-class CrossEntropyValidator : public Validator<data::Corpus, models::CriterionBase> {
+class CrossEntropyValidator : public Validator<data::Corpus, models::ICriterionFunction> {
public:
CrossEntropyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options) {
@@ -159,7 +159,7 @@ public:
opts->merge(options);
opts->set("inference", true);
opts->set("cost-type", "ce-sum");
-builder_ = models::createCriterionFromOptions(opts, models::usage::scoring);
+builder_ = models::createCriterionFunctionFromOptions(opts, models::usage::scoring);
}
std::string type() override { return options_->get<std::string>("cost-type"); }
@@ -180,7 +180,7 @@ protected:
for(auto batch : *batchGenerator_) {
auto task = [=, &loss, &samples](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
-thread_local auto builder = models::createCriterionFromOptions(options_, models::usage::scoring);
+thread_local auto builder = models::createCriterionFunctionFromOptions(options_, models::usage::scoring);
if(!graph) {
graph = graphs[id % graphs.size()];
@@ -216,7 +216,7 @@ protected:
};
// Used for validating with classifiers. Compute prediction accuracy versus ground truth for a set of classes
-class AccuracyValidator : public Validator<data::Corpus, models::ModelBase> {
+class AccuracyValidator : public Validator<data::Corpus, models::IModel> {
public:
AccuracyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options, /*lowerIsBetter=*/false) {
@@ -305,7 +305,7 @@ protected:
}
};
-class BertAccuracyValidator : public Validator<data::Corpus, models::ModelBase> {
+class BertAccuracyValidator : public Validator<data::Corpus, models::IModel> {
private:
bool evalMaskedLM_{true};
@@ -415,7 +415,7 @@ protected:
};
-class ScriptValidator : public Validator<data::Corpus, models::ModelBase> {
+class ScriptValidator : public Validator<data::Corpus, models::IModel> {
public:
ScriptValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options, false) {
@@ -446,7 +446,7 @@ protected:
}
};
-class TranslationValidator : public Validator<data::Corpus, models::ModelBase> {
+class TranslationValidator : public Validator<data::Corpus, models::IModel> {
public:
TranslationValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
: Validator(vocabs, options, false),
@@ -578,7 +578,7 @@ protected:
};
// @TODO: combine with TranslationValidator (above) to avoid code duplication
-class BleuValidator : public Validator<data::Corpus, models::ModelBase> {
+class BleuValidator : public Validator<data::Corpus, models::IModel> {
private:
bool detok_{false};


@@ -19,7 +19,7 @@ private:
Word trgEosId_{Word::NONE};
Word trgUnkId_{Word::NONE};
-static constexpr auto INVALID_PATH_SCORE = -9999;
+static constexpr auto INVALID_PATH_SCORE = -9999; // (@TODO: change to -9999.0 once C++ allows that)
public:
BeamSearch(Ptr<Options> options,
@@ -37,7 +37,7 @@ public:
// combine new expandedPathScores and previous beams into new set of beams
Beams toHyps(const std::vector<unsigned int>& nBestKeys, // [dimBatch, beamSize] flattened -> ((batchIdx, beamHypIdx) flattened, word idx) flattened
const std::vector<float>& nBestPathScores, // [dimBatch, beamSize] flattened
-const size_t inputBeamSize, // for interpretation of nBestKeys
+const size_t nBestBeamSize, // for interpretation of nBestKeys
const size_t vocabSize, // ditto.
const Beams& beams,
const std::vector<Ptr<ScorerState /*const*/>>& states,
@@ -48,19 +48,19 @@ public:
align = scorers_[0]->getAlignment(); // use alignments from the first scorer, even if ensemble
const auto dimBatch = beams.size();
-Beams newBeams(dimBatch);
+Beams newBeams(dimBatch); // return value of this function goes here
for(size_t i = 0; i < nBestKeys.size(); ++i) { // [dimBatch, beamSize] flattened
// Keys encode batchIdx, beamHypIdx, and word index in the entire beam.
-// They can be between 0 and beamSize * vocabSize-1.
-const float pathScore = nBestPathScores[i];
-const auto key = nBestKeys[i]; // key = pathScore's tensor location, as (batchIdx, beamHypIdx, word idx) flattened
+// They can be between 0 and (vocabSize * nBestBeamSize * batchSize)-1.
+// (beamHypIdx refers to the GPU tensors, *not* the beams[] array; they are not the same in case of purging)
+const auto key = nBestKeys[i];
+const float pathScore = nBestPathScores[i]; // expanded path score for (batchIdx, beamHypIdx, word)
// decompose key into individual indices (batchIdx, beamHypIdx, wordIdx)
const auto wordIdx = (WordIndex)(key % vocabSize);
-const auto beamHypIdx = (key / vocabSize) % inputBeamSize;
-const auto batchIdx = (key / vocabSize) / inputBeamSize;
-//LOG(info, "key = (batch {}, beam {}, word {}) -> {}", batchIdx, beamHypIdx, wordIdx, pathScore);
+const auto beamHypIdx = (key / vocabSize) % nBestBeamSize;
+const auto batchIdx = (key / vocabSize) / nBestBeamSize;
const auto& beam = beams[batchIdx];
auto& newBeam = newBeams[batchIdx];
@@ -70,7 +70,7 @@ public:
if (pathScore <= INVALID_PATH_SCORE) // (dummy slot or word that cannot be expanded by current factor)
continue;
-ABORT_IF(beamHypIdx >= (int)beam.size(), "Out of bounds beamHypIdx value {} in key?? word={}, batch={}, pathScore={}", beamHypIdx, wordIdx, batchIdx, pathScore);
+ABORT_IF(beamHypIdx >= beam.size(), "Out of bounds beamHypIdx??");
// map wordIdx to word
auto prevHyp = beam[beamHypIdx];
@@ -81,23 +81,6 @@ public:
auto shortlist = scorers_[0]->getShortlist();
if (shortlist)
word = Word::fromWordIndex(shortlist->reverseMap(wordIdx));
-else if (factoredVocab) {
-// For factored decoding, the word is built over multiple decoding steps,
-// starting with the lemma, then adding factors one by one.
-if (factorGroup == 0) {
-word = factoredVocab->lemma2Word(wordIdx);
-//LOG(info, "new lemma {}={}", word.toWordIndex(), factoredVocab->word2string(word));
-}
-else {
-//LOG(info, "expand word {}={} with factor[{}] {}", beam[beamHypIdx]->getWord().toWordIndex(),
-// factoredVocab->word2string(beam[beamHypIdx]->getWord()), factorGroup, wordIdx);
-word = beam[beamHypIdx]->getWord();
-ABORT_IF(!factoredVocab->canExpandFactoredWord(word, factorGroup), "A word without this factor snuck through to here??");
-word = factoredVocab->expandFactoredWord(word, factorGroup, wordIdx);
-prevBeamHypIdx = prevHyp->getPrevStateIndex();
-prevHyp = prevHyp->getPrevHyp(); // short-circuit the backpointer, so that the traceback doesnot contain partially factored words
-}
-}
else
word = Word::fromWordIndex(wordIdx);
@@ -244,7 +227,7 @@ public:
for(int i = 0; i < dimBatch; ++i)
histories[i]->add(beams[i], trgEosId_);
-// the decoder updates the following state information in each output time step:
+// the decoding process updates the following state information in each output time step:
// - beams: array [dimBatch] of array [localBeamSize] of Hypothesis
// - current output time step's set of active hypotheses, aka active search space
// - states[.]: ScorerState
@@ -381,7 +364,7 @@ public:
// combine N-best sets with existing search space (beams) to updated search space
beams = toHyps(nBestKeys, nBestPathScores,
-/*inputBeamSize*/expandedPathScores->shape()[-2], // used for interpretation of keys
+/*nBestBeamSize*/expandedPathScores->shape()[-2], // used for interpretation of keys
/*vocabSize=*/expandedPathScores->shape()[-1], // used for interpretation of keys
beams,
states, // used for keeping track of per-ensemble-member path score
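To make the key arithmetic above concrete: nBestKeys flattens (batchIdx, beamHypIdx, wordIdx) over a tensor of shape [dimBatch, nBestBeamSize, vocabSize], and toHyps() inverts that flattening. A self-contained sketch of the encode/decode pair (illustrative only; not code from this commit):

    #include <cstddef>

    struct KeyParts { size_t batchIdx, beamHypIdx, wordIdx; };

    // invert the flattening, exactly as toHyps() does above
    KeyParts decomposeKey(size_t key, size_t vocabSize, size_t nBestBeamSize) {
      return { (key / vocabSize) / nBestBeamSize,   // batchIdx
               (key / vocabSize) % nBestBeamSize,   // beamHypIdx
               key % vocabSize };                   // wordIdx
    }

    // the flattening itself: row-major layout over [dimBatch, nBestBeamSize, vocabSize]
    size_t composeKey(const KeyParts& p, size_t vocabSize, size_t nBestBeamSize) {
      return (p.batchIdx * nBestBeamSize + p.beamHypIdx) * vocabSize + p.wordIdx;
    }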


@@ -6,7 +6,7 @@
namespace marian {
-// one single (possibly partial) hypothesis in beam search
+// one single (partial or full) hypothesis in beam search
// key elements:
// - the word that this hyp ends with
// - the aggregate score up to and including the word
@@ -29,7 +29,7 @@ public:
float getPathScore() const { return pathScore_; }
-std::vector<float>& getScoreBreakdown() { return scoreBreakdown_; }
+std::vector<float>& getScoreBreakdown() { return scoreBreakdown_; } // @TODO: make this const
void setScoreBreakdown(const std::vector<float>& scoreBreaddown) { scoreBreakdown_ = scoreBreaddown; }
const std::vector<float>& getAlignment() { return alignment_; }
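Per the comment at the top of this file, a hypothesis carries the word it ends with, the aggregate path score, and a back pointer for tracing the full output sequence. A simplified sketch of that shape (not Marian's actual class, which also tracks the previous state index, score breakdown, and alignment):

    struct HypothesisSketch {
      Ptr<HypothesisSketch> prevHyp; // back pointer to the previous time step's hypothesis
      Word word;                     // the word this hypothesis ends with
      float pathScore;               // aggregate score up to and including 'word'
    };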


@@ -367,7 +367,7 @@ public:
ABORT_IF(inputN != (isFirst ? 1 : N), "Input tensor has wrong beam dim??"); // @TODO: Remove isFirst argument altogether
ABORT_IF(vocabSize > MAX_VOCAB_SIZE, "GetNBestList(): actual vocab size {} exceeds MAX_VOCAB_SIZE of {}", vocabSize, MAX_VOCAB_SIZE);
ABORT_IF(dimBatch > maxBatchSize_, "GetNBestList(): actual batch size {} exceeds initialization parameter {}", dimBatch, maxBatchSize_);
-ABORT_IF(N > maxBeamSize_, "GetNBestList(): actual beam size {} exceeds initialization parameter {}", N, maxBeamSize_); // @TODO: or inputN?
+ABORT_IF(std::max(N, (size_t)inputN) > maxBeamSize_, "GetNBestList(): actual beam size {} exceeds initialization parameter {}", N, maxBeamSize_);
const std::vector<size_t> beamSizes(dimBatch, N);
std::vector<int> cumulativeBeamSizes(beamSizes.size() + 1, 0);
@@ -440,19 +440,19 @@ private:
const int BLOCK_SIZE = 512;
const int NUM_BLOCKS;
-int* d_ind;
-float* d_out;
+int* d_ind; // [maxBatchSize * NUM_BLOCKS]
+float* d_out; // [maxBatchSize * NUM_BLOCKS]
-int* d_res_idx;
-float* d_res;
+int* d_res_idx; // [maxBatchSize * maxBeamSize]
+float* d_res; // [maxBatchSize * maxBeamSize]
-int* h_res_idx;
-float* h_res;
+int* h_res_idx; // [maxBeamSize * maxBatchSize]
+float* h_res; // [maxBeamSize * maxBatchSize]
-float* d_breakdown;
-int* d_batchPosition;
-int* d_cumBeamSizes;
+float* d_breakdown; // [maxBeamSize]
+int* d_batchPosition; // [maxBatchSize + 1]
+int* d_cumBeamSizes; // [maxBatchSize + 1]
//size_t lastN;
};
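The shape annotations added above document the intended sizes of the top-k work buffers. A hypothetical illustration of the allocations they imply (not the actual initialization code; error handling omitted):

    // device buffers
    cudaMalloc(&d_ind,     maxBatchSize_ * NUM_BLOCKS   * sizeof(int));
    cudaMalloc(&d_out,     maxBatchSize_ * NUM_BLOCKS   * sizeof(float));
    cudaMalloc(&d_res_idx, maxBatchSize_ * maxBeamSize_ * sizeof(int));
    cudaMalloc(&d_res,     maxBatchSize_ * maxBeamSize_ * sizeof(float));
    cudaMalloc(&d_breakdown,     maxBeamSize_        * sizeof(float));
    cudaMalloc(&d_batchPosition, (maxBatchSize_ + 1) * sizeof(int));
    cudaMalloc(&d_cumBeamSizes,  (maxBatchSize_ + 1) * sizeof(int));
    // pinned host buffers for the copied-back results
    cudaMallocHost(&h_res_idx, maxBeamSize_ * maxBatchSize_ * sizeof(int));
    cudaMallocHost(&h_res,     maxBeamSize_ * maxBatchSize_ * sizeof(float));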


@@ -72,7 +72,7 @@ private:
const void* ptr_;
public:
-ScorerWrapper(Ptr<models::ModelBase> encdec,
+ScorerWrapper(Ptr<models::IModel> encdec,
const std::string& name,
float weight,
const std::string& fname)
@@ -81,7 +81,7 @@ public:
fname_(fname),
ptr_{0} {}
-ScorerWrapper(Ptr<models::ModelBase> encdec,
+ScorerWrapper(Ptr<models::IModel> encdec,
const std::string& name,
float weight,
const void* ptr)

View File

@@ -1522,12 +1522,6 @@
<ClInclude Include="..\src\examples\mnist\validator.h">
<Filter>examples\mnist</Filter>
</ClInclude>
-<ClInclude Include="..\src\models\classifier.h">
-<Filter>models</Filter>
-</ClInclude>
-<ClInclude Include="..\src\models\encoder_classifier.h">
-<Filter>models</Filter>
-</ClInclude>
<ClInclude Include="..\src\command\marian_train.cpp">
<Filter>command</Filter>
</ClInclude>
@@ -1543,6 +1537,15 @@
<ClInclude Include="..\src\data\factored_vocab.h">
<Filter>data</Filter>
</ClInclude>
+<ClInclude Include="..\src\models\classifier.h">
+<Filter>models</Filter>
+</ClInclude>
+<ClInclude Include="..\src\models\encoder_classifier.h">
+<Filter>models</Filter>
+</ClInclude>
+<ClInclude Include="..\src\models\transformer.h">
+<Filter>models</Filter>
+</ClInclude>
</ItemGroup>
<ItemGroup>
<Filter Include="3rd_party">