mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-19 18:59:18 +03:00
fix delay between saving and validation for async SGD
This commit is contained in:
parent
69b37bae7e
commit
958222ba9c
@ -223,14 +223,8 @@ void AsyncGraphGroup::execute(Ptr<data::Batch> batch) {
|
|||||||
|
|
||||||
scheduler_->update(cost, batch);
|
scheduler_->update(cost, batch);
|
||||||
|
|
||||||
if(scheduler_->saving()) {
|
if(scheduler_->saving() || scheduler_->validating()) {
|
||||||
if(movingAvg_)
|
// Wait with validation or saving until all other threads are done with update.
|
||||||
fetchParams(graph->params()->vals(), paramsAvg_, t_id);
|
|
||||||
this->save(graph);
|
|
||||||
}
|
|
||||||
|
|
||||||
if(scheduler_->validating()) {
|
|
||||||
// Wait with validation until all other threads are done with update.
|
|
||||||
// We want to reuse the graphs for validation, so they need to be in
|
// We want to reuse the graphs for validation, so they need to be in
|
||||||
// a safe state.
|
// a safe state.
|
||||||
pool_.wait_for_others(lock);
|
pool_.wait_for_others(lock);
|
||||||
@ -238,9 +232,14 @@ void AsyncGraphGroup::execute(Ptr<data::Batch> batch) {
|
|||||||
if(movingAvg_)
|
if(movingAvg_)
|
||||||
for(auto g : graphs_)
|
for(auto g : graphs_)
|
||||||
fetchParams(g->params()->vals(), paramsAvg_, t_id);
|
fetchParams(g->params()->vals(), paramsAvg_, t_id);
|
||||||
scheduler_->validate(graphs_);
|
|
||||||
|
|
||||||
// Validation is done, tell other threads to continue work.
|
if(scheduler_->saving())
|
||||||
|
this->save(graph);
|
||||||
|
|
||||||
|
if(scheduler_->validating())
|
||||||
|
scheduler_->validate(graphs_);
|
||||||
|
|
||||||
|
// Validation or saving is done, tell other threads to continue work.
|
||||||
pool_.notify_others();
|
pool_.notify_others();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user