sync with internal branch

Marcin Junczys-Dowmunt 2020-03-06 20:54:40 -08:00
commit f4ea8239c4
9 changed files with 74 additions and 81 deletions


@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Gradient-checkpointing
### Fixed
- Replace value for INVALID_PATH_SCORE with std::numeric_limits<float>::lowest()
to avoid overflow with long sequences
- Break up potential circular references for GraphGroup*
- Fix empty source batch entries with batch purging
- Clear RNN cache in transformer model, add correct hash functions to nodes
- Gather-operation for all index sizes
@ -68,6 +71,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Dropped support for g++-4.9
- Simplified file stream and temporary file handling
- Unified node initializers, same function API.
- Remove overstuff/understuff code
## [1.8.0] - 2019-09-04


@ -772,12 +772,6 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) {
{"0"});
cli.add<bool>("--mini-batch-track-lr",
"Dynamically track mini-batch size inverse to actual learning rate (not considering lr-warmup)");
cli.add<size_t>("--mini-batch-overstuff",
"[experimental] Stuff this much more data into a minibatch, but scale down the LR and progress counter",
1);
cli.add<size_t>("--mini-batch-understuff",
"[experimental] Break each batch into this many updates",
1);
}
// clang-format on
}


@ -146,7 +146,7 @@ static double roundUpRatio(double ratio) {
// helper routine that handles accumulation and load-balancing of sub-batches to fill all devices
// It adds 'newBatch' to 'pendingBatches_', and if sufficient batches have been queued, then
// returns 'pendingBatches_' in 'subBatches' and resets it. If not, it returns false.
bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch, size_t overstuff,
bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch,
std::vector<Ptr<data::Batch>>& subBatches, size_t& numReadBatches) {
// The reader delivers in chunks of these sizes, according to case:
// - no dynamic MB-size scaling:
@ -199,9 +199,6 @@ bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch, size_t overstuf
ratio *= (double)refBatchLabels / (double)(typicalTrgBatchWords_ * updateMultiplier_);
}
// overstuff: blow up ratio by a factor, which we later factor into the learning rate
ratio *= (double)overstuff;
// round up to full batches if within a certain error margin --@BUGBUG: Not invariant w.r.t. GPU size, as ratio is relative to what fits into 1 GPU
ratio = roundUpRatio(ratio);
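Note on the hunk above: after this change the ratio is no longer blown up by the overstuff factor; it is only scaled by how much larger the reference batch is than the typical one and then snapped to a whole number of batches. A minimal standalone sketch of that idea, not Marian's actual roundUpRatio(): the 1% rounding margin and the function and variable names are illustrative assumptions.

#include <cmath>
#include <cstdio>

// Illustrative only: scale a base ratio by how much larger the reference batch
// is than the typical batch, then round up to a full batch if it is already
// within a small margin of one (the real roundUpRatio() may use a different rule).
static double scaleAndRound(double ratio, double refBatchLabels,
                            double typicalTrgBatchWords, double updateMultiplier) {
  ratio *= refBatchLabels / (typicalTrgBatchWords * updateMultiplier);
  double up = std::ceil(ratio);
  if (up - ratio < 0.01 * up)  // close enough to a full batch: round up
    ratio = up;
  return ratio;
}

int main() {
  // e.g. a 2985-label reference batch vs. a typical 1000-word batch and 3 updates:
  std::printf("%g\n", scaleAndRound(1.0, 2985, 1000, 3.0));  // 0.995 -> rounded up to 1
}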
@ -267,41 +264,18 @@ bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch, size_t overstuf
void SyncGraphGroup::update(Ptr<data::Batch> newBatch) /*override*/ {
validate();
size_t overstuff = options_->get<size_t>("mini-batch-overstuff");
if (overstuff != 1)
LOG_ONCE(info, "Overstuffing minibatches by a factor of {}", overstuff);
std::vector<Ptr<data::Batch>> subBatches;
size_t numReadBatches; // actual #batches delivered by reader, for restoring from checkpoint --@TODO: reader should checkpoint itself; should not go via the scheduler
bool gotSubBatches = tryGetSubBatches(newBatch, overstuff, subBatches, numReadBatches);
bool gotSubBatches = tryGetSubBatches(newBatch, subBatches, numReadBatches);
// not enough data yet: return right away
if (!gotSubBatches)
return;
// for testing the hypothesis that one can always go smaller. This is independent of overstuff.
size_t understuff = options_->get<size_t>("mini-batch-understuff");
if (understuff != 1)
LOG_ONCE(info, "Understuffing minibatches by a factor of {}", understuff);
if (understuff == 1)
update(subBatches, numReadBatches);
else {
std::vector<Ptr<data::Batch>> subBatches1;
for (auto& b : subBatches) {
auto bbs = b->split(understuff);
for (auto& bb : bbs)
subBatches1.push_back(bb);
}
for (size_t i = 0; i < understuff; i++) {
std::vector<Ptr<data::Batch>> subBatchRange(subBatches1.begin() + i * subBatches1.size() / understuff, subBatches1.begin() + (i+1) * subBatches1.size() / understuff);
if (!subBatchRange.empty())
update(subBatchRange, numReadBatches * (i+1) / understuff - numReadBatches * i / understuff);
}
}
update(subBatches, numReadBatches);
}
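For context on what the deleted branch did: understuff broke the collected sub-batches into that many smaller update steps and spread the read-batch count across them (the per-batch split(understuff) call is omitted here for brevity). A standalone sketch under that reading of the removed lines, with plain ints standing in for Ptr<data::Batch>:

#include <cstddef>
#include <cstdio>
#include <vector>

// Standalone sketch of the removed "understuff" path: hand out contiguous
// ranges of the collected batches and run one update per range, splitting
// numReadBatches proportionally so the per-update counts still sum to the
// original value.
template <typename UpdateFn>
void understuffedUpdate(const std::vector<int>& subBatches, size_t numReadBatches,
                        size_t understuff, UpdateFn update) {
  if (understuff == 1) {  // the only path left after this commit
    update(subBatches, numReadBatches);
    return;
  }
  for (size_t i = 0; i < understuff; i++) {
    auto begin = subBatches.begin() + i * subBatches.size() / understuff;
    auto end   = subBatches.begin() + (i + 1) * subBatches.size() / understuff;
    std::vector<int> range(begin, end);
    if (!range.empty())
      update(range, numReadBatches * (i + 1) / understuff - numReadBatches * i / understuff);
  }
}

int main() {
  std::vector<int> batches = {1, 2, 3, 4};
  understuffedUpdate(batches, /*numReadBatches=*/10, /*understuff=*/2,
                     [](const std::vector<int>& range, size_t reads) {
                       std::printf("update on %zu batches, %zu batches read\n", range.size(), reads);
                     });
}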
void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t numReadBatches) {
size_t overstuff = options_->get<size_t>("mini-batch-overstuff");
//size_t understuff = options_->get<size_t>("mini-batch-understuff");
// determine num words for dynamic hyper-parameter adjustment
// @TODO: We can return these directly from tryGetSubBatches()
size_t batchSize = 0;
@ -310,9 +284,6 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
batchSize += batch->size();
batchTrgWords += batch->wordsTrg();
}
// effective batch size: batch should be weighted like this. This will weight down the learning rate.
size_t effectiveBatchTrgWords = (size_t)ceil(batchTrgWords / (double)overstuff);
size_t effectiveBatchSize = (size_t)ceil(batchSize / (double)overstuff);
// Helper to access the subBatches array
auto getSubBatch = [&](size_t warp, size_t localDeviceIndex, size_t rank) -> Ptr<data::Batch> {
@ -353,32 +324,21 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
auto rationalLoss = builders_[localDeviceIndex]->build(graph, subBatch);
graph->forward();
StaticLoss tempLoss = *rationalLoss; // needed for overstuff
tempLoss.loss /= (float)overstuff; // @TODO: @fseide: scale only loss? should this scale labels too?
localDeviceLosses[localDeviceIndex] += tempLoss;
localDeviceLosses[localDeviceIndex] += *rationalLoss;
graph->backward(/*zero=*/false); // (gradients are reset before we get here)
}
});
// At this point, each device on each MPI process has a gradient aggregated over a subset of the sub-batches.
// only needed for overstuff now
float div = (float)overstuff; // (note: with Adam, a constant here makes no difference)
// Update parameter shard with gradient shard
auto update = [&](size_t idx, size_t begin, size_t end) {
auto curGrad = graphs_[idx]->params()->grads()->subtensor(begin, end-begin);
auto curParam = graphs_[idx]->params()->vals()->subtensor(begin, end-begin);
if(div != 1.f) {
using namespace functional;
Element(_1 = _1 / div, curGrad); // average if overstuffed
}
// actual model update
auto updateTrgWords =
/*if*/(options_->get<std::string>("cost-type") == "ce-sum") ?
effectiveBatchTrgWords // if overstuffing then bring the count back to the original value
batchTrgWords
/*else*/:
OptimizerBase::mbSizeNotProvided;
shardOpt_[idx]->update(curParam, curGrad, updateTrgWords);
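For reference, the overstuff bookkeeping removed above divided the loss, the accumulated gradient, and the target-word count handed to the optimizer by the stuffing factor, so that a k-times-stuffed batch still produced a normally scaled step. A standalone sketch of that arithmetic; the struct and function names are hypothetical and a plain std::vector replaces the gradient tensor.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative only: what dividing by the overstuff factor amounted to.
struct PendingUpdate {
  float loss;               // summed loss over the stuffed batch
  std::vector<float> grad;  // accumulated gradient (a tensor in Marian)
  size_t batchTrgWords;     // target words seen in the stuffed batch
};

static void applyOverstuffScaling(PendingUpdate& u, size_t overstuff) {
  float div = (float)overstuff;  // with Adam a constant factor changes little,
  u.loss /= div;                 // but it keeps reported losses comparable
  for (auto& g : u.grad)
    g /= div;                    // average the accumulated gradient
  u.batchTrgWords =              // "effective" count passed to the optimizer
      (size_t)std::ceil(u.batchTrgWords / (double)overstuff);
}

int main() {
  PendingUpdate u{8.0f, {2.0f, 4.0f}, 2000};
  applyOverstuffScaling(u, /*overstuff=*/2);
  std::printf("loss=%g grad[0]=%g trgWords=%zu\n", u.loss, u.grad[0], u.batchTrgWords);
  // loss=4 grad[0]=1 trgWords=1000
}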
@ -405,7 +365,7 @@ void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t num
if(scheduler_) {
// track and log localLoss
scheduler_->update(localLoss, numReadBatches, effectiveBatchSize, effectiveBatchTrgWords, mpi_);
scheduler_->update(localLoss, numReadBatches, batchSize, batchTrgWords, mpi_);
// save intermediate model (and optimizer state) to file
if(scheduler_->saving())

src/training/graph_group_sync.h Normal file → Executable file

@ -35,7 +35,7 @@ class SyncGraphGroup : public GraphGroup, public ExponentialSmoothing {
void barrier() const { mpi_->barrier(); } // (we need this several times)
void swapParamsAvg() { if (mvAvg_ && paramsAvg_.size() > 0) comm_->swapParams(paramsAvg_); } // note: must call this on all MPI ranks in parallel
bool tryGetSubBatches(Ptr<data::Batch> newBatch, size_t overstuff, std::vector<Ptr<data::Batch>>& subBatches, size_t& numReadBatches);
bool tryGetSubBatches(Ptr<data::Batch> newBatch, std::vector<Ptr<data::Batch>>& subBatches, size_t& numReadBatches);
void update(std::vector<Ptr<data::Batch>> subBatches, size_t numReadBatches);
public:


@ -196,8 +196,7 @@ public:
registerTrainingObserver(validators_.back());
if(!state_->loaded) {
state_->validators[validator->type()]["last-best"]
= validator->initScore();
state_->validators[validator->type()]["last-best"] = validator->initScore();
state_->validators[validator->type()]["stalled"] = 0;
}
if(validators_.size() == 1)
@ -215,12 +214,12 @@ public:
}
void validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
bool final = false) {
bool isFinal = false) {
// Do not validate if already validated (for instance, after the model is
// loaded) or if validation is scheduled for another update, or when signal SIGTERM was received
if(getSigtermFlag() // SIGTERM was received
|| state_->validated // already validated (in resumed training, for example)
|| (!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq")) && !final)) // not now
|| (!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq")) && !isFinal)) // not now
return;
bool firstValidator = true;


@ -34,8 +34,6 @@ public:
dataset->prepare();
auto trainState = New<TrainingState>(options_->get<float>("learn-rate"));
auto scheduler = New<Scheduler>(options_, trainState);
auto mpi = initMPI(/*multiThreaded=*/!options_->get<bool>("sync-sgd")); // @TODO: do we need the multiThreaded distinction at all?
Ptr<BatchStats> stats;
@ -46,11 +44,20 @@ public:
// @TODO this should receive a function object that can generate a fake batch;
// that way vocabs would not be exposed.
auto model = New<ModelWrapper>(options_, mpi);
model->setScheduler(scheduler); // collectStats() needs to know about dynamic MB scaling
// use temporary scheduler to make sure everything gets destroyed properly
// otherwise the scheduler believes that registered objects still exist
auto tempTrainState = New<TrainingState>(options_->get<float>("learn-rate"));
auto tempScheduler = New<Scheduler>(options_, tempTrainState);
model->setScheduler(tempScheduler); // collectStats() needs to know about dynamic MB scaling
stats = model->collectStats(dataset->getVocabs());
LOG(info, "[batching] Done. Typical MB size is {} target words", stats->estimateTypicalTrgWords());
}
auto trainState = New<TrainingState>(options_->get<float>("learn-rate"));
auto scheduler = New<Scheduler>(options_, trainState);
if((options_->hasAndNotEmpty("valid-sets") || options_->hasAndNotEmpty("valid-script-path"))
&& SchedulingParameter::parse(options_->get<std::string>("valid-freq"))) {
for(auto validator : Validators(dataset->getVocabs(), options_))
@ -77,12 +84,10 @@ public:
restored = false;
// main training loop for one epoch
for(auto batchIt = std::begin(*batchGenerator); // @TODO: try to use for(auto ...)
batchIt != std::end(*batchGenerator);
batchIt++) {
for(auto batch : *batchGenerator) {
if (!scheduler->keepGoing())
break;
model->update(*batchIt);
model->update(batch);
}
if(scheduler->keepGoing())
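The new comment about the temporary scheduler is the point of this hunk: collectStats() registers objects with whatever scheduler it is given, so the stats pass now gets a throwaway Scheduler/TrainingState pair that dies at the end of the block, and the real pair is created only afterwards. A standalone sketch of that scope-the-helper pattern, assuming Ptr behaves like std::shared_ptr; the type below is a stand-in, not Marian's Scheduler.

#include <cstdio>
#include <memory>

// Stand-in type: announces its own destruction so the scoping is visible.
struct FakeScheduler {
  explicit FakeScheduler(const char* name) : name_(name) {}
  ~FakeScheduler() { std::printf("%s destroyed\n", name_); }
  const char* name_;
};

int main() {
  {
    auto tempScheduler = std::make_shared<FakeScheduler>("temporary scheduler");
    // ... model->setScheduler(tempScheduler); stats = model->collectStats(...);
  } // everything registered only with tempScheduler is released here
  auto scheduler = std::make_shared<FakeScheduler>("training scheduler");
  std::printf("training loop would run now\n");
}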


@ -14,6 +14,7 @@ class TrainingState;
class TrainingObserver {
public:
virtual ~TrainingObserver() {}
virtual void init(TrainingState&) {}
virtual void actAfterEpoch(TrainingState&) {}
virtual void actAfterBatches(TrainingState&) {}
@ -130,8 +131,8 @@ public:
}
void registerObserver(Ptr<TrainingObserver> observer) {
observers_.push_back(observer);
observers_.back()->init(*this);
observer->init(*this);
wObservers_.push_back(observer);
}
// return the totals count that corresponds to the given unit (batches, labels, or epochs)
@ -184,8 +185,11 @@ public:
void newEpoch() {
++epochs;
for(auto observer : observers_)
for(auto wObserver : wObservers_) {
auto observer = wObserver.lock();
ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler lifetime");
observer->actAfterEpoch(*this);
}
samplesEpoch = 0;
batchesEpoch = 0;
}
@ -195,22 +199,31 @@ public:
batchesEpoch += batchesInUpdate;
loaded = false;
validated = false;
for(auto observer : observers_)
for(auto wObserver : wObservers_) {
auto observer = wObserver.lock();
ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler lifetime");
observer->actAfterBatches(*this);
}
}
void newStalled(size_t num) {
stalled = num;
if(num > maxStalled)
++maxStalled;
for(auto observer : observers_)
for(auto wObserver : wObservers_) {
auto observer = wObserver.lock();
ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler lifetime");
observer->actAfterStalled(*this);
}
}
void newLoad() {
loaded = true;
for(auto observer : observers_)
for(auto wObserver : wObservers_) {
auto observer = wObserver.lock();
ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler lifetime");
observer->actAfterLoaded(*this);
}
}
void load(const std::string& name) {
@ -303,6 +316,9 @@ public:
}
private:
std::vector<Ptr<TrainingObserver>> observers_;
// this needs to be a vector of weak pointers, otherwise
// it is likely to cause circular dependencies.
std::vector<Weak<TrainingObserver>> wObservers_;
};
} // namespace marian
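The switch to Weak<TrainingObserver> above breaks the circular ownership the new comment warns about: the state no longer keeps its observers alive, it only locks them for the duration of a notification and aborts if one has already expired. A standalone sketch of the same pattern, assuming Ptr/Weak map onto std::shared_ptr/std::weak_ptr; the class names are illustrative.

#include <cassert>
#include <cstdio>
#include <memory>
#include <vector>

// Standalone sketch of the weak-observer pattern introduced above.
struct Observer {
  virtual ~Observer() = default;
  virtual void init() { std::printf("observer initialized\n"); }
  virtual void actAfterEpoch() { std::printf("observer notified\n"); }
};

struct State {
  std::vector<std::weak_ptr<Observer>> wObservers_;  // non-owning: no reference cycle

  void registerObserver(std::shared_ptr<Observer> observer) {
    observer->init();                 // init is called on the strong pointer we were handed
    wObservers_.push_back(observer);  // only a weak reference is stored
  }
  void newEpoch() {
    for (auto& wObserver : wObservers_) {
      auto observer = wObserver.lock();        // promote to a strong pointer for the call
      assert(observer && "observer expired");  // Marian uses ABORT_IF with a message here
      observer->actAfterEpoch();
    }
  }
};

int main() {
  State state;
  auto scheduler = std::make_shared<Observer>();  // e.g. the Scheduler registers itself
  state.registerObserver(scheduler);
  state.newEpoch();
}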


@ -18,8 +18,8 @@ private:
size_t beamSize_;
Ptr<const Vocab> trgVocab_;
static constexpr auto INVALID_PATH_SCORE = -9999; // (@TODO: change to -9999.0 once C++ allows that)
static constexpr auto PURGE_BATCH = true; // @TODO: diagnostic, to-be-removed once confirmed there are no issues.
const float INVALID_PATH_SCORE = std::numeric_limits<float>::lowest(); // @TODO: observe this closely
const bool PURGE_BATCH = true; // @TODO: diagnostic, to-be-removed once confirmed there are no issues.
public:
BeamSearch(Ptr<Options> options,
@ -74,10 +74,20 @@ public:
const auto currentBatchIdx = (key / vocabSize) / nBestBeamSize;
const auto origBatchIdx = reverseBatchIdxMap.empty() ? currentBatchIdx : reverseBatchIdxMap[currentBatchIdx]; // map currentBatchIdx back into original position within starting maximal batch size, required to find correct beam
bool dropHyp = !dropBatchEntries.empty() && dropBatchEntries[origBatchIdx];
// if we force=drop the hypothesis, assign EOS, otherwise the expected word id.
const auto wordIdx = dropHyp ? trgVocab_->getEosId().toWordIndex() : (WordIndex)(key % vocabSize);
bool dropHyp = !dropBatchEntries.empty() && dropBatchEntries[origBatchIdx] && factorGroup == 0;
WordIndex wordIdx;
if(dropHyp) { // if we force=drop the hypothesis, assign EOS, otherwise the expected word id.
if(factoredVocab) { // when using factoredVocab, extract the EOS lemma index from the word id; we are predicting factors one by one here, hence lemma only
std::vector<size_t> eosFactors;
factoredVocab->word2factors(factoredVocab->getEosId(), eosFactors);
wordIdx = (WordIndex)eosFactors[0];
} else { // without factoredVocab lemma index and word index are the same. Safe cruising.
wordIdx = trgVocab_->getEosId().toWordIndex();
}
} else { // we are not dropping anything, just assign the normal index
wordIdx = (WordIndex)(key % vocabSize);
}
// @TODO: We currently assign a log probability of 0 to all beam entries of the dropped batch entry, instead it might be a good idea to use
// the per Hyp pathScore without the current expansion (a bit hard to obtain).
@ -88,11 +98,12 @@ public:
const auto& beam = beams[origBatchIdx];
auto& newBeam = newBeams[origBatchIdx]; // extended hypotheses are going to be placed in this new beam
if (newBeam.size() >= beam.size()) // getNBestList() generates N for all batch entries incl. those that already have a narrower beam
if(newBeam.size() >= beam.size()) // getNBestList() generates N for all batch entries incl. those that already have a narrower beam
continue;
if (pathScore <= INVALID_PATH_SCORE) // (dummy slot or word that cannot be expanded by current factor)
if(pathScore == INVALID_PATH_SCORE) // (dummy slot or word that cannot be expanded by current factor)
continue;
ABORT_IF(pathScore < INVALID_PATH_SCORE, "Actual pathScore ({}) is lower than INVALID_PATH_SCORE ({})??", pathScore, INVALID_PATH_SCORE); // This should not happen in valid situations. Currently the only smaller value would be -inf (effect of overflow in summation?)
ABORT_IF(beamHypIdx >= beam.size(), "Out of bounds beamHypIdx??"); // effectively this is equivalent to ABORT_IF(beams[origBatchIdx].empty(), ...)
// map wordIdx to word
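The sentinel change is easy to miss but is what the CHANGELOG entry refers to: a long hypothesis can legitimately accumulate a path score below -9999, so the old pathScore <= INVALID_PATH_SCORE test could silently prune it, whereas lowest() can only be matched exactly (and anything even lower now trips the new ABORT_IF). A tiny standalone illustration; the token count and average log-probability are made-up numbers.

#include <cstdio>
#include <limits>

int main() {
  const float oldSentinel = -9999.f;
  const float newSentinel = std::numeric_limits<float>::lowest();

  // Made-up but plausible: ~1500 target tokens at an average log-prob of -7
  // push a perfectly valid path score below the old sentinel.
  float longValidPathScore = 1500 * -7.0f;  // = -10500
  std::printf("old check prunes it: %d\n", longValidPathScore <= oldSentinel);  // 1 (wrongly dropped)
  std::printf("new check prunes it: %d\n", longValidPathScore == newSentinel);  // 0 (kept)
}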


@ -55,7 +55,11 @@ public:
return nbest;
}
Result top() const { return nBest(1)[0]; }
Result top() const {
const NBestList& nbest = nBest(1);
ABORT_IF(nbest.empty(), "No hypotheses in n-best list??");
return nbest[0];
}
size_t getLineNum() const { return lineNo_; }
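top() previously indexed nBest(1)[0] unconditionally; the new guard matters because the n-best list can come back empty (for example when no hypothesis was produced for a line), and indexing an empty container is undefined behavior. A minimal standalone sketch of the guarded accessor; std::vector<std::string> stands in for the real NBestList and the abort is simplified.

#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

// Illustrative stand-in for the guarded top(): check before indexing.
static std::string top(const std::vector<std::string>& nbest) {
  if (nbest.empty()) {
    std::fprintf(stderr, "No hypotheses in n-best list\n");  // Marian uses ABORT_IF here
    std::exit(1);
  }
  return nbest[0];
}

int main() {
  std::printf("%s\n", top({"best hypothesis", "second best"}).c_str());
}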