mirror of
https://github.com/marian-nmt/marian.git
synced 2024-10-26 09:09:10 +03:00
sync with internal branch
This commit is contained in:
commit
f4ea8239c4
@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
- Gradient-checkpointing
|
||||
|
||||
### Fixed
|
||||
- Replace value for INVALID_PATH_SCORE with std::numer_limits<float>::lowest()
|
||||
to avoid overflow with long sequences
|
||||
- Break up potential circular references for GraphGroup*
|
||||
- Fix empty source batch entries with batch purging
|
||||
- Clear RNN chache in transformer model, add correct hash functions to nodes
|
||||
- Gather-operation for all index sizes
|
||||
@ -68,6 +71,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
- Dropped support for g++-4.9
|
||||
- Simplified file stream and temporary file handling
|
||||
- Unified node intializers, same function API.
|
||||
- Remove overstuff/understuff code
|
||||
|
||||
## [1.8.0] - 2019-09-04
|
||||
|
||||
|
@ -772,12 +772,6 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) {
|
||||
{"0"});
|
||||
cli.add<bool>("--mini-batch-track-lr",
|
||||
"Dynamically track mini-batch size inverse to actual learning rate (not considering lr-warmup)");
|
||||
cli.add<size_t>("--mini-batch-overstuff",
|
||||
"[experimental] Stuff this much more data into a minibatch, but scale down the LR and progress counter",
|
||||
1);
|
||||
cli.add<size_t>("--mini-batch-understuff",
|
||||
"[experimental] Break each batch into this many updates",
|
||||
1);
|
||||
}
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ static double roundUpRatio(double ratio) {
|
||||
// helper routine that handles accumulation and load-balancing of sub-batches to fill all devices
|
||||
// It adds 'newBatch' to 'pendingBatches_', and if sufficient batches have been queued, then
|
||||
// returns 'pendingBatches_' in 'subBatches' and resets it. If not, it returns false.
|
||||
bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch, size_t overstuff,
|
||||
bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch,
|
||||
std::vector<Ptr<data::Batch>>& subBatches, size_t& numReadBatches) {
|
||||
// The reader delivers in chunks of these sizes, according to case:
|
||||
// - no dynamic MB-size scaling:
|
||||
@ -199,9 +199,6 @@ bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch, size_t overstuf
|
||||
ratio *= (double)refBatchLabels / (double)(typicalTrgBatchWords_ * updateMultiplier_);
|
||||
}
|
||||
|
||||
// overstuff: blow up ratio by a factor, which we later factor into the learning rate
|
||||
ratio *= (double)overstuff;
|
||||
|
||||
// round up to full batches if within a certain error margin --@BUGBUG: Not invariant w.r.t. GPU size, as ratio is relative to what fits into 1 GPU
|
||||
ratio = roundUpRatio(ratio);
|
||||
|
||||
@ -267,41 +264,18 @@ bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch, size_t overstuf
|
||||
void SyncGraphGroup::update(Ptr<data::Batch> newBatch) /*override*/ {
|
||||
validate();
|
||||
|
||||
size_t overstuff = options_->get<size_t>("mini-batch-overstuff");
|
||||
if (overstuff != 1)
|
||||
LOG_ONCE(info, "Overstuffing minibatches by a factor of {}", overstuff);
|
||||
std::vector<Ptr<data::Batch>> subBatches;
|
||||
size_t numReadBatches; // actual #batches delivered by reader, for restoring from checkpoint --@TODO: reader should checkpoint itself; should not go via the scheduler
|
||||
bool gotSubBatches = tryGetSubBatches(newBatch, overstuff, subBatches, numReadBatches);
|
||||
bool gotSubBatches = tryGetSubBatches(newBatch, subBatches, numReadBatches);
|
||||
|
||||
// not enough data yet: return right away
|
||||
if (!gotSubBatches)
|
||||
return;
|
||||
|
||||
// for testing the hypothesis that one can always go smaller. This is independent of overstuff.
|
||||
size_t understuff = options_->get<size_t>("mini-batch-understuff");
|
||||
if (understuff != 1)
|
||||
LOG_ONCE(info, "Understuffing minibatches by a factor of {}", understuff);
|
||||
if (understuff == 1)
|
||||
update(subBatches, numReadBatches);
|
||||
else {
|
||||
std::vector<Ptr<data::Batch>> subBatches1;
|
||||
for (auto& b : subBatches) {
|
||||
auto bbs = b->split(understuff);
|
||||
for (auto& bb : bbs)
|
||||
subBatches1.push_back(bb);
|
||||
}
|
||||
for (size_t i = 0; i < understuff; i++) {
|
||||
std::vector<Ptr<data::Batch>> subBatchRange(subBatches1.begin() + i * subBatches1.size() / understuff, subBatches1.begin() + (i+1) * subBatches1.size() / understuff);
|
||||
if (!subBatchRange.empty())
|
||||
update(subBatchRange, numReadBatches * (i+1) / understuff - numReadBatches * i / understuff);
|
||||
}
|
||||
}
|
||||
update(subBatches, numReadBatches);
|
||||
}
|
||||
|
||||
void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t numReadBatches) {
|
||||
size_t overstuff = options_->get<size_t>("mini-batch-overstuff");
|
||||
//size_t understuff = options_->get<size_t>("mini-batch-understuff");
|
||||
// determine num words for dynamic hyper-parameter adjustment
|
||||
// @TODO: We can return these directly from tryGetSubBatches()
|
||||
size_t batchSize = 0;
|
||||
@ -310,9 +284,6 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
|
||||
batchSize += batch->size();
|
||||
batchTrgWords += batch->wordsTrg();
|
||||
}
|
||||
// effective batch size: batch should be weighted like this. This will weight down the learning rate.
|
||||
size_t effectiveBatchTrgWords = (size_t)ceil(batchTrgWords / (double)overstuff);
|
||||
size_t effectiveBatchSize = (size_t)ceil(batchSize / (double)overstuff);
|
||||
|
||||
// Helper to access the subBatches array
|
||||
auto getSubBatch = [&](size_t warp, size_t localDeviceIndex, size_t rank) -> Ptr<data::Batch> {
|
||||
@ -353,32 +324,21 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
|
||||
auto rationalLoss = builders_[localDeviceIndex]->build(graph, subBatch);
|
||||
graph->forward();
|
||||
|
||||
StaticLoss tempLoss = *rationalLoss; // needed for overstuff
|
||||
tempLoss.loss /= (float)overstuff; // @TODO: @fseide: scale only loss? should this scale labels too?
|
||||
|
||||
localDeviceLosses[localDeviceIndex] += tempLoss;
|
||||
localDeviceLosses[localDeviceIndex] += *rationalLoss;
|
||||
graph->backward(/*zero=*/false); // (gradients are reset before we get here)
|
||||
}
|
||||
});
|
||||
// At this point, each device on each MPI process has a gradient aggregated over a subset of the sub-batches.
|
||||
|
||||
// only needed for overstuff now
|
||||
float div = (float)overstuff; // (note: with Adam, a constant here makes no difference)
|
||||
|
||||
// Update parameter shard with gradient shard
|
||||
auto update = [&](size_t idx, size_t begin, size_t end) {
|
||||
auto curGrad = graphs_[idx]->params()->grads()->subtensor(begin, end-begin);
|
||||
auto curParam = graphs_[idx]->params()->vals()->subtensor(begin, end-begin);
|
||||
|
||||
if(div != 1.f) {
|
||||
using namespace functional;
|
||||
Element(_1 = _1 / div, curGrad); // average if overstuffed
|
||||
}
|
||||
|
||||
// actual model update
|
||||
auto updateTrgWords =
|
||||
/*if*/(options_->get<std::string>("cost-type") == "ce-sum") ?
|
||||
effectiveBatchTrgWords // if overstuffing then bring the count back to the original value
|
||||
batchTrgWords
|
||||
/*else*/:
|
||||
OptimizerBase::mbSizeNotProvided;
|
||||
shardOpt_[idx]->update(curParam, curGrad, updateTrgWords);
|
||||
@ -405,7 +365,7 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
|
||||
|
||||
if(scheduler_) {
|
||||
// track and log localLoss
|
||||
scheduler_->update(localLoss, numReadBatches, effectiveBatchSize, effectiveBatchTrgWords, mpi_);
|
||||
scheduler_->update(localLoss, numReadBatches, batchSize, batchTrgWords, mpi_);
|
||||
|
||||
// save intermediate model (and optimizer state) to file
|
||||
if(scheduler_->saving())
|
||||
|
2
src/training/graph_group_sync.h
Normal file → Executable file
2
src/training/graph_group_sync.h
Normal file → Executable file
@ -35,7 +35,7 @@ class SyncGraphGroup : public GraphGroup, public ExponentialSmoothing {
|
||||
void barrier() const { mpi_->barrier(); } // (we need this several times)
|
||||
void swapParamsAvg() { if (mvAvg_ && paramsAvg_.size() > 0) comm_->swapParams(paramsAvg_); } // note: must call this on all MPI ranks in parallel
|
||||
|
||||
bool tryGetSubBatches(Ptr<data::Batch> newBatch, size_t overstuff, std::vector<Ptr<data::Batch>>& subBatches, size_t& numReadBatches);
|
||||
bool tryGetSubBatches(Ptr<data::Batch> newBatch, std::vector<Ptr<data::Batch>>& subBatches, size_t& numReadBatches);
|
||||
void update(std::vector<Ptr<data::Batch>> subBatches, size_t numReadBatches);
|
||||
|
||||
public:
|
||||
|
@ -196,8 +196,7 @@ public:
|
||||
|
||||
registerTrainingObserver(validators_.back());
|
||||
if(!state_->loaded) {
|
||||
state_->validators[validator->type()]["last-best"]
|
||||
= validator->initScore();
|
||||
state_->validators[validator->type()]["last-best"] = validator->initScore();
|
||||
state_->validators[validator->type()]["stalled"] = 0;
|
||||
}
|
||||
if(validators_.size() == 1)
|
||||
@ -215,12 +214,12 @@ public:
|
||||
}
|
||||
|
||||
void validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
|
||||
bool final = false) {
|
||||
bool isFinal = false) {
|
||||
// Do not validate if already validated (for instance, after the model is
|
||||
// loaded) or if validation is scheduled for another update, or when signal SIGTERM was received
|
||||
if(getSigtermFlag() // SIGTERM was received
|
||||
|| state_->validated // already validated (in resumed training, for example)
|
||||
|| (!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq")) && !final)) // not now
|
||||
|| (!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq")) && !isFinal)) // not now
|
||||
return;
|
||||
|
||||
bool firstValidator = true;
|
||||
|
@ -34,8 +34,6 @@ public:
|
||||
|
||||
dataset->prepare();
|
||||
|
||||
auto trainState = New<TrainingState>(options_->get<float>("learn-rate"));
|
||||
auto scheduler = New<Scheduler>(options_, trainState);
|
||||
auto mpi = initMPI(/*multiThreaded=*/!options_->get<bool>("sync-sgd")); // @TODO: do we need the multiThreaded distinction at all?
|
||||
|
||||
Ptr<BatchStats> stats;
|
||||
@ -46,11 +44,20 @@ public:
|
||||
// @TODO this should receive a function object that can generate a fake batch;
|
||||
// that way vocabs would not be exposed.
|
||||
auto model = New<ModelWrapper>(options_, mpi);
|
||||
model->setScheduler(scheduler); // collectStats() needs to know about dynamic MB scaling
|
||||
|
||||
// use temporary scheduler to make sure everything gets destroyed properly
|
||||
// otherwise the scheduler believes that registered objects still exist
|
||||
auto tempTrainState = New<TrainingState>(options_->get<float>("learn-rate"));
|
||||
auto tempScheduler = New<Scheduler>(options_, tempTrainState);
|
||||
|
||||
model->setScheduler(tempScheduler); // collectStats() needs to know about dynamic MB scaling
|
||||
stats = model->collectStats(dataset->getVocabs());
|
||||
LOG(info, "[batching] Done. Typical MB size is {} target words", stats->estimateTypicalTrgWords());
|
||||
}
|
||||
|
||||
auto trainState = New<TrainingState>(options_->get<float>("learn-rate"));
|
||||
auto scheduler = New<Scheduler>(options_, trainState);
|
||||
|
||||
if((options_->hasAndNotEmpty("valid-sets") || options_->hasAndNotEmpty("valid-script-path"))
|
||||
&& SchedulingParameter::parse(options_->get<std::string>("valid-freq"))) {
|
||||
for(auto validator : Validators(dataset->getVocabs(), options_))
|
||||
@ -77,12 +84,10 @@ public:
|
||||
restored = false;
|
||||
|
||||
// main training loop for one epoch
|
||||
for(auto batchIt = std::begin(*batchGenerator); // @TODO: try to use for(auto ...)
|
||||
batchIt != std::end(*batchGenerator);
|
||||
batchIt++) {
|
||||
for(auto batch : *batchGenerator) {
|
||||
if (!scheduler->keepGoing())
|
||||
break;
|
||||
model->update(*batchIt);
|
||||
model->update(batch);
|
||||
}
|
||||
|
||||
if(scheduler->keepGoing())
|
||||
|
@ -14,6 +14,7 @@ class TrainingState;
|
||||
class TrainingObserver {
|
||||
public:
|
||||
virtual ~TrainingObserver() {}
|
||||
|
||||
virtual void init(TrainingState&) {}
|
||||
virtual void actAfterEpoch(TrainingState&) {}
|
||||
virtual void actAfterBatches(TrainingState&) {}
|
||||
@ -130,8 +131,8 @@ public:
|
||||
}
|
||||
|
||||
void registerObserver(Ptr<TrainingObserver> observer) {
|
||||
observers_.push_back(observer);
|
||||
observers_.back()->init(*this);
|
||||
observer->init(*this);
|
||||
wObservers_.push_back(observer);
|
||||
}
|
||||
|
||||
// return the totals count that corresponds to the given unit (batches, labels, or epochs)
|
||||
@ -184,8 +185,11 @@ public:
|
||||
|
||||
void newEpoch() {
|
||||
++epochs;
|
||||
for(auto observer : observers_)
|
||||
for(auto wObserver : wObservers_) {
|
||||
auto observer = wObserver.lock();
|
||||
ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler life time");
|
||||
observer->actAfterEpoch(*this);
|
||||
}
|
||||
samplesEpoch = 0;
|
||||
batchesEpoch = 0;
|
||||
}
|
||||
@ -195,22 +199,31 @@ public:
|
||||
batchesEpoch += batchesInUpdate;
|
||||
loaded = false;
|
||||
validated = false;
|
||||
for(auto observer : observers_)
|
||||
for(auto wObserver : wObservers_) {
|
||||
auto observer = wObserver.lock();
|
||||
ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler life time");
|
||||
observer->actAfterBatches(*this);
|
||||
}
|
||||
}
|
||||
|
||||
void newStalled(size_t num) {
|
||||
stalled = num;
|
||||
if(num > maxStalled)
|
||||
++maxStalled;
|
||||
for(auto observer : observers_)
|
||||
for(auto wObserver : wObservers_) {
|
||||
auto observer = wObserver.lock();
|
||||
ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler life time");
|
||||
observer->actAfterStalled(*this);
|
||||
}
|
||||
}
|
||||
|
||||
void newLoad() {
|
||||
loaded = true;
|
||||
for(auto observer : observers_)
|
||||
for(auto wObserver : wObservers_) {
|
||||
auto observer = wObserver.lock();
|
||||
ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler life time");
|
||||
observer->actAfterLoaded(*this);
|
||||
}
|
||||
}
|
||||
|
||||
void load(const std::string& name) {
|
||||
@ -303,6 +316,9 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<Ptr<TrainingObserver>> observers_;
|
||||
// this needs to be a vector of weak pointers, otherwise
|
||||
// it is likely to cause circular dependencies.
|
||||
std::vector<Weak<TrainingObserver>> wObservers_;
|
||||
|
||||
};
|
||||
} // namespace marian
|
||||
|
@ -18,8 +18,8 @@ private:
|
||||
size_t beamSize_;
|
||||
Ptr<const Vocab> trgVocab_;
|
||||
|
||||
static constexpr auto INVALID_PATH_SCORE = -9999; // (@TODO: change to -9999.0 once C++ allows that)
|
||||
static constexpr auto PURGE_BATCH = true; // @TODO: diagnostic, to-be-removed once confirmed there are no issues.
|
||||
const float INVALID_PATH_SCORE = std::numeric_limits<float>::lowest(); // @TODO: observe this closely
|
||||
const bool PURGE_BATCH = true; // @TODO: diagnostic, to-be-removed once confirmed there are no issues.
|
||||
|
||||
public:
|
||||
BeamSearch(Ptr<Options> options,
|
||||
@ -74,10 +74,20 @@ public:
|
||||
const auto currentBatchIdx = (key / vocabSize) / nBestBeamSize;
|
||||
const auto origBatchIdx = reverseBatchIdxMap.empty() ? currentBatchIdx : reverseBatchIdxMap[currentBatchIdx]; // map currentBatchIdx back into original position within starting maximal batch size, required to find correct beam
|
||||
|
||||
bool dropHyp = !dropBatchEntries.empty() && dropBatchEntries[origBatchIdx];
|
||||
|
||||
// if we force=drop the hypothesis, assign EOS, otherwise the expected word id.
|
||||
const auto wordIdx = dropHyp ? trgVocab_->getEosId().toWordIndex() : (WordIndex)(key % vocabSize);
|
||||
bool dropHyp = !dropBatchEntries.empty() && dropBatchEntries[origBatchIdx] && factorGroup == 0;
|
||||
|
||||
WordIndex wordIdx;
|
||||
if(dropHyp) { // if we force=drop the hypothesis, assign EOS, otherwise the expected word id.
|
||||
if(factoredVocab) { // when using factoredVocab, extract the EOS lemma index from the word id, we predicting factors one by one here, hence lemma only
|
||||
std::vector<size_t> eosFactors;
|
||||
factoredVocab->word2factors(factoredVocab->getEosId(), eosFactors);
|
||||
wordIdx = (WordIndex)eosFactors[0];
|
||||
} else { // without factoredVocab lemma index and word index are the same. Safe cruising.
|
||||
wordIdx = trgVocab_->getEosId().toWordIndex();
|
||||
}
|
||||
} else { // we are not dropping anything, just assign the normal index
|
||||
wordIdx = (WordIndex)(key % vocabSize);
|
||||
}
|
||||
|
||||
// @TODO: We currently assign a log probability of 0 to all beam entries of the dropped batch entry, instead it might be a good idea to use
|
||||
// the per Hyp pathScore without the current expansion (a bit hard to obtain).
|
||||
@ -88,11 +98,12 @@ public:
|
||||
const auto& beam = beams[origBatchIdx];
|
||||
auto& newBeam = newBeams[origBatchIdx]; // extended hypotheses are going to be placed in this new beam
|
||||
|
||||
if (newBeam.size() >= beam.size()) // getNBestList() generates N for all batch entries incl. those that already have a narrower beam
|
||||
if(newBeam.size() >= beam.size()) // getNBestList() generates N for all batch entries incl. those that already have a narrower beam
|
||||
continue;
|
||||
if (pathScore <= INVALID_PATH_SCORE) // (dummy slot or word that cannot be expanded by current factor)
|
||||
if(pathScore == INVALID_PATH_SCORE) // (dummy slot or word that cannot be expanded by current factor)
|
||||
continue;
|
||||
|
||||
|
||||
ABORT_IF(pathScore < INVALID_PATH_SCORE, "Actual pathScore ({}) is lower than INVALID_PATH_SCORE ({})??", pathScore, INVALID_PATH_SCORE); // This should not happen in valid situations. Currently the only smaller value would be -inf (effect of overflow in summation?)
|
||||
ABORT_IF(beamHypIdx >= beam.size(), "Out of bounds beamHypIdx??"); // effectively this is equivalent to ABORT_IF(beams[origBatchIdx].empty(), ...)
|
||||
|
||||
// map wordIdx to word
|
||||
|
@ -55,7 +55,11 @@ public:
|
||||
return nbest;
|
||||
}
|
||||
|
||||
Result top() const { return nBest(1)[0]; }
|
||||
Result top() const {
|
||||
const NBestList& nbest = nBest(1);
|
||||
ABORT_IF(nbest.empty(), "No hypotheses in n-best list??");
|
||||
return nbest[0];
|
||||
}
|
||||
|
||||
size_t getLineNum() const { return lineNo_; }
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user