mirror of
https://github.com/marian-nmt/marian.git
synced 2024-10-05 19:17:10 +03:00
Merged PR 25686: Loading checkpoints from main node only via MPI
Enables loading of model checkpoints from main node only via MPI. Until now the checkpoint needed to present in the same location on all nodes. That could be done either via writing to a shared filesystem (problematic due to bad syncing) or by manual copying to the same local location, e.g. /tmp on each node (while writing only happened to one main location). Now, marian can resume training from only one location on the main node. The remaining nodes do not need to have access. E.g. local /tmp on the main node can be used, or race conditons on shared storage are avoided. Also avoids creating files for logging on more than one node. This is a bit wonky, done via environment variable lookup.
This commit is contained in:
parent
76964791ad
commit
7d2045a907
@ -9,8 +9,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- `--force-decode` option for marian-decoder
|
||||
- `--output-sampling` now works with ensembles (requires proper normalization via e.g `--weights 0.5 0.5`)
|
||||
|
||||
### Fixed
|
||||
- Read/restore checkpoints from main process only when training with MPI
|
||||
- Multi-loss casts type to first loss-type before accumulation (aborted before due to missing cast)
|
||||
- Throw `ShapeSizeException` if total expanded shape size exceeds numeric capacity of the maximum int value (2^31-1)
|
||||
- During mini-batch-fitting, catch `ShapeSizeException` and use another sizing hint. Aborts outside mini-batch-fitting.
|
||||
|
@ -19,24 +19,6 @@ static inline std::string interpolateEnvVars(std::string str) {
|
||||
return str;
|
||||
}
|
||||
|
||||
#if 1
|
||||
if(getenv("PHILLY_JOB_ID")) {
|
||||
const char* cluster = getenv("PHILLY_CLUSTER");
|
||||
const char* vc = getenv("PHILLY_VC");
|
||||
// this environment variable exists when running on the cluster
|
||||
if(cluster && vc) {
|
||||
static const std::string s_gfsPrefix
|
||||
= std::string("/gfs/") + cluster + "/" + vc + "/";
|
||||
static const std::string s_hdfsPrefix
|
||||
= std::string("/hdfs/") + cluster + "/" + vc + "/";
|
||||
if(str.find(s_gfsPrefix) == 0)
|
||||
str = std::string("/hdfs/") + vc + "/" + str.substr(s_gfsPrefix.size());
|
||||
else if(str.find(s_hdfsPrefix) == 0)
|
||||
str = std::string("/hdfs/") + vc + "/"
|
||||
+ str.substr(s_hdfsPrefix.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
for(;;) {
|
||||
const auto pos = str.find("${");
|
||||
if(pos == std::string::npos)
|
||||
|
@ -74,6 +74,12 @@ std::string InputFileStream::getFileName() const {
|
||||
return file_.string();
|
||||
}
|
||||
|
||||
std::string InputFileStream::readToString() const {
|
||||
std::stringstream ss;
|
||||
ss << this->rdbuf();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
// wrapper around std::getline() that handles Windows input files with extra CR
|
||||
// chars at the line end
|
||||
std::istream &getline(std::istream &in, std::string &line) {
|
||||
@ -85,6 +91,7 @@ std::istream &getline(std::istream &in, std::string &line) {
|
||||
line.pop_back();
|
||||
return in;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////
|
||||
OutputFileStream::OutputFileStream(const std::string &file)
|
||||
: std::ostream(NULL), file_(file) {
|
||||
@ -119,7 +126,7 @@ TemporaryFile::TemporaryFile(const std::string &base, bool earlyUnlink)
|
||||
NormalizeTempPrefix(baseTemp);
|
||||
MakeTemp(baseTemp);
|
||||
|
||||
inSteam_ = UPtr<io::InputFileStream>(new io::InputFileStream(file_.string()));
|
||||
inStream_ = UPtr<io::InputFileStream>(new io::InputFileStream(file_.string()));
|
||||
if(unlink_) {
|
||||
ABORT_IF(remove(file_.string().c_str()), "Error while deleting '{}'", file_.string());
|
||||
}
|
||||
@ -190,7 +197,7 @@ void TemporaryFile::MakeTemp(const std::string &base) {
|
||||
}
|
||||
|
||||
UPtr<InputFileStream> TemporaryFile::getInputStream() {
|
||||
return std::move(inSteam_);
|
||||
return std::move(inStream_);
|
||||
}
|
||||
|
||||
std::string TemporaryFile::getFileName() const {
|
||||
|
@ -46,6 +46,7 @@ public:
|
||||
bool empty();
|
||||
void setbufsize(size_t size);
|
||||
std::string getFileName() const;
|
||||
std::string readToString() const;
|
||||
|
||||
protected:
|
||||
marian::filesystem::Path file_;
|
||||
@ -92,7 +93,7 @@ public:
|
||||
|
||||
protected:
|
||||
bool unlink_;
|
||||
UPtr<InputFileStream> inSteam_;
|
||||
UPtr<InputFileStream> inStream_;
|
||||
|
||||
void NormalizeTempPrefix(std::string& base) const;
|
||||
void MakeTemp(const std::string& base);
|
||||
|
@ -1,5 +1,7 @@
|
||||
#include "logging.h"
|
||||
#include "common/config.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
#include "spdlog/sinks/null_sink.h"
|
||||
#include "3rd_party/ExceptionWithCallStack.h"
|
||||
#include <time.h>
|
||||
@ -30,9 +32,14 @@ std::shared_ptr<spdlog::logger> createStderrLogger(const std::string& name,
|
||||
if(!quiet)
|
||||
sinks.push_back(stderr_sink);
|
||||
|
||||
for(auto&& file : files) {
|
||||
auto file_sink = std::make_shared<spdlog::sinks::simple_file_sink_st>(file, true);
|
||||
sinks.push_back(file_sink);
|
||||
// @TODO: think how to solve this better than using OMPI_COMM_WORLD_RANK env variable
|
||||
// only create output files if we are the main process or if MPI rank is not defined
|
||||
int rank = marian::utils::getMPIRankEnv(); // this function looks up OMPI_COMM_WORLD_RANK env variable
|
||||
if(rank == 0) {
|
||||
for(auto&& file : files) {
|
||||
auto file_sink = std::make_shared<spdlog::sinks::simple_file_sink_st>(file, true);
|
||||
sinks.push_back(file_sink);
|
||||
}
|
||||
}
|
||||
|
||||
auto logger = std::make_shared<spdlog::logger>(name, begin(sinks), end(sinks));
|
||||
|
@ -180,7 +180,8 @@ std::string exec(const std::string& cmd, const std::vector<std::string>& args /*
|
||||
|
||||
std::pair<std::string, int> hostnameAndProcessId() { // helper to get hostname:pid
|
||||
#ifdef _WIN32
|
||||
std::string hostname = getenv("COMPUTERNAME");
|
||||
const char* res = getenv("COMPUTERNAME");
|
||||
std::string hostname = res ? std::string(res) : "";
|
||||
auto processId = (int)GetCurrentProcessId();
|
||||
#else
|
||||
static std::string hostname = []() { // not sure if gethostname() is expensive. This way we call it only once.
|
||||
@ -193,6 +194,15 @@ std::pair<std::string, int> hostnameAndProcessId() { // helper to get hostname:
|
||||
return {hostname, processId};
|
||||
}
|
||||
|
||||
// returns MPI rank from environment variable if set, otherwise 0
|
||||
int getMPIRankEnv() {
|
||||
const char* rank = getenv("OMPI_COMM_WORLD_RANK");
|
||||
if(rank)
|
||||
return std::atoi(rank);
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
// format a long number with comma separators
|
||||
std::string withCommas(size_t n) {
|
||||
std::string res = std::to_string(n);
|
||||
|
@ -43,6 +43,9 @@ std::string exec(const std::string& cmd, const std::vector<std::string>& args =
|
||||
|
||||
std::pair<std::string, int> hostnameAndProcessId();
|
||||
|
||||
// returns MPI rank from environment variable if set, otherwise 0
|
||||
int getMPIRankEnv();
|
||||
|
||||
std::string withCommas(size_t n);
|
||||
bool beginsWith(const std::string& text, const std::string& prefix);
|
||||
bool endsWith(const std::string& text, const std::string& suffix);
|
||||
|
@ -141,16 +141,6 @@ public:
|
||||
output->Write((long)batch->getSentenceIds()[i],
|
||||
sentVector);
|
||||
}
|
||||
|
||||
// progress heartbeat for MS-internal Philly compute cluster
|
||||
// otherwise this job may be killed prematurely if no log for 4 hrs
|
||||
if (getenv("PHILLY_JOB_ID") // this environment variable exists when running on the cluster
|
||||
&& id % 1000 == 0) // hard beat once every 1000 batches
|
||||
{
|
||||
auto progress = id / 10000.f; //fake progress for now, becomes >100 after 1M batches
|
||||
fprintf(stderr, "PROGRESS: %.2f%%\n", progress);
|
||||
fflush(stderr);
|
||||
}
|
||||
};
|
||||
|
||||
pool.enqueue(task, batchId++);
|
||||
|
@ -75,6 +75,10 @@ public:
|
||||
return Logits(apply(graph, batch, inference_));
|
||||
}
|
||||
|
||||
void load(Ptr<ExpressionGraph> /*graph*/, const std::vector<io::Item>& /*items*/, bool) override {
|
||||
LOG(critical, "Loading MNIST model is not supported");
|
||||
}
|
||||
|
||||
void load(Ptr<ExpressionGraph> /*graph*/, const std::string& /*name*/, bool) override {
|
||||
LOG(critical, "Loading MNIST model is not supported");
|
||||
}
|
||||
|
@ -217,6 +217,12 @@ public:
|
||||
|
||||
Ptr<IModel> getModel() { return model_; }
|
||||
|
||||
void load(Ptr<ExpressionGraph> graph,
|
||||
const std::vector<io::Item>& items,
|
||||
bool markedReloaded) override {
|
||||
model_->load(graph, items, markedReloaded);
|
||||
}
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool markedReloaded = true) override {
|
||||
@ -263,6 +269,12 @@ public:
|
||||
|
||||
Ptr<IModel> getModel() { return model_; }
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::vector<io::Item>& items,
|
||||
bool markReloaded = true) override {
|
||||
model_->load(graph, items, markReloaded);
|
||||
}
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool markedReloaded = true) override {
|
||||
|
@ -152,6 +152,12 @@ public:
|
||||
void push_back(Ptr<EncoderBase> encoder) { encoders_.push_back(encoder); }
|
||||
void push_back(Ptr<ClassifierBase> classifier) { classifiers_.push_back(classifier); }
|
||||
|
||||
void load(Ptr<ExpressionGraph> graph,
|
||||
const std::vector<io::Item>& items,
|
||||
bool markedReloaded) override {
|
||||
graph->load(items, markedReloaded && !opt<bool>("ignore-model-config", false));
|
||||
}
|
||||
|
||||
void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool markedReloaded) override {
|
||||
|
@ -15,7 +15,8 @@ public:
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::vector<io::Item>& items,
|
||||
bool markedReloaded = true) = 0;
|
||||
bool markedReloaded = true) override
|
||||
= 0;
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
|
@ -25,6 +25,11 @@ class EncoderPoolerBase : public models::IModel {
|
||||
public:
|
||||
virtual ~EncoderPoolerBase() {}
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::vector<io::Item>& items,
|
||||
bool markedReloaded = true) override
|
||||
= 0;
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool markedReloaded = true) override
|
||||
@ -162,6 +167,12 @@ public:
|
||||
void push_back(Ptr<EncoderBase> encoder) { encoders_.push_back(encoder); }
|
||||
void push_back(Ptr<PoolerBase> pooler) { poolers_.push_back(pooler); }
|
||||
|
||||
void load(Ptr<ExpressionGraph> graph,
|
||||
const std::vector<io::Item>& items,
|
||||
bool markedReloaded) override {
|
||||
graph->load(items, markedReloaded && !opt<bool>("ignore-model-config", false));
|
||||
}
|
||||
|
||||
void load(Ptr<ExpressionGraph> graph,
|
||||
const std::string& name,
|
||||
bool markedReloaded) override {
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <string>
|
||||
#include "marian.h"
|
||||
#include "common/io_item.h"
|
||||
#include "layers/loss.h"
|
||||
#include "layers/generic.h"
|
||||
|
||||
@ -24,6 +25,12 @@ public:
|
||||
const std::string&,
|
||||
bool markReloaded = true)
|
||||
= 0;
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph>,
|
||||
const std::vector<io::Item>&,
|
||||
bool markReloaded = true)
|
||||
= 0;
|
||||
|
||||
virtual void save(Ptr<ExpressionGraph>,
|
||||
const std::string&,
|
||||
bool saveTranslatorConfig = false)
|
||||
@ -47,6 +54,12 @@ public:
|
||||
const std::string&,
|
||||
bool markReloaded = true)
|
||||
= 0;
|
||||
|
||||
virtual void load(Ptr<ExpressionGraph>,
|
||||
const std::vector<io::Item>&,
|
||||
bool markReloaded = true)
|
||||
= 0;
|
||||
|
||||
virtual void save(Ptr<ExpressionGraph>,
|
||||
const std::string&,
|
||||
bool saveTranslatorConfig = false)
|
||||
|
@ -201,16 +201,6 @@ public:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// progress heartbeat for MS-internal Philly compute cluster
|
||||
// otherwise this job may be killed prematurely if no log for 4 hrs
|
||||
if (getenv("PHILLY_JOB_ID") // this environment variable exists when running on the cluster
|
||||
&& id % 1000 == 0) // hard beat once every 1000 batches
|
||||
{
|
||||
auto progress = id / 10000.f; //fake progress for now, becomes >100 after 1M batches
|
||||
fprintf(stdout, "PROGRESS: %.2f%%\n", progress);
|
||||
fflush(stdout);
|
||||
}
|
||||
};
|
||||
|
||||
pool.enqueue(task, batchId++);
|
||||
|
@ -134,11 +134,12 @@ public:
|
||||
|
||||
// get the limit for int count
|
||||
size_t limit = (size_t)std::numeric_limits<int>::max();
|
||||
size_t remaining = count, offset = 0;
|
||||
size_t remaining = count;
|
||||
size_t offset = 0;
|
||||
|
||||
// while there are elements that we have not sent yet, loop until all has been sent in chunks of at most `limit`.
|
||||
while(remaining > 0) {
|
||||
int intCount = (int)std::min(remaining, limit);
|
||||
int intCount = (int)std::min(remaining, limit);
|
||||
HANDLE_MPI_ERROR(MPI_Bcast((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)rootRank, comm));
|
||||
offset += (size_t)intCount;
|
||||
remaining -= (size_t)intCount;
|
||||
@ -193,6 +194,49 @@ public:
|
||||
virtual void finalize() override {
|
||||
HANDLE_MPI_ERROR(MPI_Finalize());
|
||||
}
|
||||
|
||||
virtual void bCast(io::Item& item, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const override {
|
||||
if(isMainProcess())
|
||||
ABORT_IF(item.bytes.empty(), "Broadcasting empty item via MPI should not happen. Please report.");
|
||||
|
||||
unsigned long long bytesLen = item.bytes.size();
|
||||
bCast(&bytesLen, 1, getDataType(&bytesLen), rootRank, comm);
|
||||
|
||||
item.bytes.resize(bytesLen);
|
||||
bCast(item.bytes.data(), bytesLen, getDataType(item.bytes.data()), rootRank, comm);
|
||||
|
||||
unsigned long long shapeLen = item.shape.size();
|
||||
bCast(&shapeLen, 1, getDataType(&shapeLen), rootRank, comm);
|
||||
item.shape.resize(shapeLen);
|
||||
bCast(item.shape.data(), shapeLen, getDataType(item.shape.data()), rootRank, comm);
|
||||
|
||||
bCast(item.name, rootRank, comm);
|
||||
|
||||
size_t type = (size_t)item.type;
|
||||
bCast(&type, 1, getDataType(&type), rootRank, comm);
|
||||
item.type = (Type)type;
|
||||
}
|
||||
|
||||
virtual void bCast(std::vector<io::Item>& items, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const override {
|
||||
size_t numItems = 0;
|
||||
if(isMainProcess())
|
||||
numItems = items.size();
|
||||
|
||||
bCast(&numItems, 1, getDataType(&numItems), rootRank, comm);
|
||||
items.resize(numItems);
|
||||
for(auto& item : items)
|
||||
bCast(item, rootRank, comm);
|
||||
}
|
||||
|
||||
virtual void bCast(std::string& str, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const override {
|
||||
size_t length = 0;
|
||||
if(isMainProcess())
|
||||
length = str.size();
|
||||
|
||||
bCast(&length, 1, getDataType(&length), rootRank, comm);
|
||||
str.resize(length);
|
||||
bCast(str.data(), length, getDataType(str.data()), rootRank, comm);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
@ -232,6 +276,19 @@ public:
|
||||
// to only accept one parameter, and remove this error check can be removed.
|
||||
ABORT_IF(sendbuf != recvbuf, "FakeMPIWrapper::allReduce() only implemented for in-place operation"); // otherwise it's not a no-op, we must copy data
|
||||
}
|
||||
|
||||
virtual void bCast(io::Item& item, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const override {
|
||||
item; rootRank; comm;
|
||||
}
|
||||
|
||||
virtual void bCast(std::vector<io::Item>& items, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const override {
|
||||
items; rootRank; comm;
|
||||
}
|
||||
|
||||
virtual void bCast(std::string& str, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const override {
|
||||
str; rootRank; comm;
|
||||
}
|
||||
|
||||
#pragma warning(pop)
|
||||
virtual void finalize() override { }
|
||||
};
|
||||
|
@ -68,7 +68,7 @@ public:
|
||||
#if MPI_FOUND
|
||||
#else
|
||||
enum MPI_Comm { MPI_COMM_WORLD };
|
||||
enum MPI_Datatype { MPI_FLOAT, MPI_UNSIGNED_LONG_LONG, MPI_UNSIGNED_LONG, MPI_BYTE, MPI_INT };
|
||||
enum MPI_Datatype { MPI_FLOAT, MPI_UNSIGNED_LONG_LONG, MPI_UNSIGNED_LONG, MPI_BYTE, MPI_INT, MPI_CXX_BOOL };
|
||||
enum MPI_Op { MPI_SUM };
|
||||
struct MPI_Status { int MPI_SOURCE; };
|
||||
#define MPI_ANY_SOURCE ((size_t)-2)
|
||||
@ -88,30 +88,16 @@ struct/*interface*/ IMPIWrapper {
|
||||
static const size_t RECV_ANY_SOURCE = (size_t)MPI_ANY_SOURCE;
|
||||
|
||||
static MPI_Datatype getDataType(const char*) { return MPI_BYTE; }
|
||||
static MPI_Datatype getDataType(const bool*) { return MPI_CXX_BOOL; }
|
||||
static MPI_Datatype getDataType(const int*) { return MPI_INT; }
|
||||
static MPI_Datatype getDataType(const float*) { return MPI_FLOAT; }
|
||||
static MPI_Datatype getDataType(const unsigned long*) { return MPI_UNSIGNED_LONG; }
|
||||
static MPI_Datatype getDataType(const unsigned long long*) { return MPI_UNSIGNED_LONG_LONG; }
|
||||
|
||||
void bCast(io::Item& item, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) {
|
||||
ABORT_IF(item.bytes.empty(), "Broadcasting empty item via MPI??");
|
||||
|
||||
unsigned long long bytesLen = item.bytes.size();
|
||||
bCast(&bytesLen, 1, getDataType(&bytesLen), rootRank, comm);
|
||||
|
||||
item.bytes.resize(bytesLen);
|
||||
bCast(item.bytes.data(), item.bytes.size(), getDataType(item.bytes.data()), rootRank, comm);
|
||||
|
||||
unsigned long long shapeLen = item.shape.size();
|
||||
bCast(&shapeLen, 1, getDataType(&shapeLen), rootRank, comm);
|
||||
|
||||
bCast(item.shape.data(), item.shape.size(), getDataType(item.shape.data()), rootRank, comm);
|
||||
|
||||
size_t type = (size_t)item.type;
|
||||
bCast(&type, 1, getDataType(&type), rootRank, comm);
|
||||
item.type = (Type)type;
|
||||
}
|
||||
|
||||
virtual void bCast(io::Item& item, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const = 0;
|
||||
virtual void bCast(std::vector<io::Item>& items, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const = 0;
|
||||
virtual void bCast(std::string& str, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const = 0;
|
||||
|
||||
std::string idStr() const;
|
||||
};
|
||||
|
||||
|
@ -283,25 +283,51 @@ void GraphGroup::load(const OptimizerBase::ScatterStateFunc& scatterFn) {
|
||||
*/
|
||||
if(!options_->get<bool>("no-reload")) {
|
||||
std::string modelFileName = options_->get<std::string>("model");
|
||||
bool foundModel = false;
|
||||
|
||||
if(filesystem::exists(modelFileName)) {
|
||||
// these are structures that get fill in the main process and then broadcasted to other MPI
|
||||
std::vector<io::Item> items;
|
||||
bool markReloaded = true;
|
||||
|
||||
if(isMainProcess()) {
|
||||
if(filesystem::exists(modelFileName)) {
|
||||
LOG(info, "Loading model from {}", modelFileName);
|
||||
foundModel = true;
|
||||
items = io::loadItems(modelFileName);
|
||||
markReloaded = true;
|
||||
} else if(options_->hasAndNotEmpty("pretrained-model")) {
|
||||
std::string pretrainedModelFileName = options_->get<std::string>("pretrained-model");
|
||||
LOG(info, "[training] Initializing model weights with pre-trained model {}", pretrainedModelFileName);
|
||||
foundModel = true;
|
||||
items = io::loadItems(pretrainedModelFileName);
|
||||
markReloaded = false;
|
||||
}
|
||||
}
|
||||
|
||||
// if a model file exists, the main process will find it and propagate this information to other MPI nodes
|
||||
if(mpi_)
|
||||
mpi_->bCast(&foundModel, 1, mpi_->getDataType(&foundModel));
|
||||
|
||||
if(foundModel) {
|
||||
// continue with checkpoint loading
|
||||
if(mpi_) {
|
||||
// broadcast model information to other processes
|
||||
mpi_->bCast(items);
|
||||
mpi_->bCast(&markReloaded, 1, mpi_->getDataType(&markReloaded));
|
||||
}
|
||||
|
||||
// handles MPI
|
||||
if(scheduler_)
|
||||
scheduler_->load(modelFileName);
|
||||
|
||||
// we just load it N times from disk (it'll be in disk cache after the first)
|
||||
// this also allocates memory correctly when calling forward() inside restoreFromCheckPoint
|
||||
size_t i = 0;
|
||||
for(auto graph : graphs_)
|
||||
models_[i++]->load(graph, modelFileName);
|
||||
models_[i++]->load(graph, items, markReloaded);
|
||||
|
||||
// try to restore everything from checkpoint now
|
||||
restoreFromCheckpoint(modelFileName, scatterFn);
|
||||
} else if(options_->hasAndNotEmpty("pretrained-model")) {
|
||||
std::string nameInit = options_->get<std::string>("pretrained-model");
|
||||
LOG(info, "[training] Initializing model weights with pre-trained model {}", nameInit);
|
||||
|
||||
size_t i = 0;
|
||||
for(auto graph : graphs_)
|
||||
models_[i++]->load(graph, nameInit, false);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -316,19 +342,26 @@ bool GraphGroup::restoreFromCheckpoint(const std::string& modelFileName,
|
||||
|
||||
std::string checkpointName = modelFileName + ".optimizer.npz"; // @TODO: change to .checkpoint.npz, would break backwards compat
|
||||
|
||||
if(!filesystem::exists(checkpointName)) {
|
||||
// if a checkpoint exists, the main process will find it and propagate this information to other MPI nodes
|
||||
bool foundCheckpoint = filesystem::exists(checkpointName);
|
||||
if(mpi_)
|
||||
mpi_->bCast(&foundCheckpoint, 1, mpi_->getDataType(&foundCheckpoint));
|
||||
|
||||
// all nodes will either continue or exit
|
||||
if(!foundCheckpoint) {
|
||||
LOG(warn, "No checkpoint found, parameters reloaded from last inference model");
|
||||
return false; // failed to restore
|
||||
}
|
||||
|
||||
auto items = io::loadItems(checkpointName);
|
||||
|
||||
// make sure all nodes see the same checkpoint data, may not be the case with distributed file systems
|
||||
// when there was a delay in updating the caches accross nodes. So here node 0 sends its data to all.
|
||||
// We still load them all from disk, but that serves more as a trick to allocate the correct memory.
|
||||
if(mpi_)
|
||||
for(auto& item : items)
|
||||
mpi_->bCast(item);
|
||||
std::vector<marian::io::Item> items;
|
||||
// make sure all nodes receive the same checkpoint data from the main process.
|
||||
if(mpi_) { // only the main process loads the checkpoint and the rest receives a copy
|
||||
if(isMainProcess())
|
||||
items = io::loadItems(checkpointName);
|
||||
mpi_->bCast(items);
|
||||
} else { // not doing MPI, so just load the checkpoint from disk
|
||||
items = io::loadItems(checkpointName);
|
||||
}
|
||||
|
||||
// @TODO: probably we want to have the list of DeviceIds as an attribute
|
||||
std::vector<Ptr<Backend>> backends;
|
||||
@ -351,7 +384,8 @@ bool GraphGroup::restoreFromCheckpoint(const std::string& modelFileName,
|
||||
// run a full forward pass over the paramters to allocate the parameters values in order (by parameter name).
|
||||
// Just doing graph->params()->allocateForward() is not sufficient.
|
||||
ABORT_IF(graph->params()->vals()->shape() != masterParameters.shape,
|
||||
"Graph parameter sizes and master copy parameter sizes in checkpoint do not match");
|
||||
"Graph parameter sizes and master copy parameter sizes in checkpoint do not match ({} != {})",
|
||||
graph->params()->vals()->shape(), masterParameters.shape);
|
||||
|
||||
// Convert type of io::Item to match graph parameter type.
|
||||
if(masterParameters.type != graph->params()->vals()->type())
|
||||
|
@ -478,24 +478,11 @@ public:
|
||||
state_->samplesDisp = 0;
|
||||
state_->wordsDisp = 0;
|
||||
}
|
||||
|
||||
// progress heartbeat for MS-internal Philly compute cluster
|
||||
// This environment variable exists when running on the cluster.
|
||||
using namespace std::chrono;
|
||||
if((!mpi_ || mpi_->myMPIRank() == 0) && getenv("PHILLY_JOB_ID")
|
||||
&& heartBeatTimer_.elapsed<std::chrono::minutes>() >= 30) {
|
||||
fprintf(stderr, "PROGRESS: %.2f%%\nEVALERR: %.7f%%\n",
|
||||
(double)calculateLogicalEpoch(),
|
||||
state_->costSum / (state_->costCount ? state_->costCount : 1));
|
||||
fflush(stderr);
|
||||
heartBeatTimer_.start();
|
||||
}
|
||||
}
|
||||
|
||||
void load(const std::string& name) {
|
||||
std::string nameYaml = name + ".progress.yml";
|
||||
if(filesystem::exists(nameYaml))
|
||||
state_->load(nameYaml);
|
||||
void loadFromString(const std::string yamlString) {
|
||||
if(!yamlString.empty())
|
||||
state_->loadFromString(yamlString);
|
||||
|
||||
if(options_->get<bool>("no-restore-corpus")) {
|
||||
state_->samplesEpoch = 0;
|
||||
@ -519,6 +506,19 @@ public:
|
||||
state_->newLoad();
|
||||
}
|
||||
|
||||
void load(const std::string& name) {
|
||||
std::string nameYaml = name + ".progress.yml";
|
||||
std::string yamlStr;
|
||||
if(mpi_->isMainProcess())
|
||||
if(filesystem::exists(nameYaml))
|
||||
yamlStr = io::InputFileStream(nameYaml).readToString();
|
||||
|
||||
if(mpi_)
|
||||
mpi_->bCast(yamlStr);
|
||||
|
||||
loadFromString(yamlStr);
|
||||
}
|
||||
|
||||
void save(const std::string& name) {
|
||||
// Save config options
|
||||
std::ofstream fout(name + ".yml");
|
||||
|
@ -1,11 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "common/definitions.h"
|
||||
#include "common/file_stream.h"
|
||||
#include "common/filesystem.h"
|
||||
#include "common/scheduling_parameter.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
namespace marian {
|
||||
@ -194,11 +195,8 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void load(const std::string& name) {
|
||||
if(!filesystem::exists(name))
|
||||
return;
|
||||
|
||||
YAML::Node config = YAML::LoadFile(name);
|
||||
void loadFromString(const std::string& yamlString) {
|
||||
YAML::Node config = YAML::Load(yamlString);
|
||||
|
||||
epochs = config["epochs"].as<size_t>();
|
||||
batches = config["batches"].as<size_t>();
|
||||
@ -242,6 +240,13 @@ public:
|
||||
seedCorpus = config["seed-corpus"].as<std::string>();
|
||||
}
|
||||
|
||||
void load(const std::string& name) {
|
||||
if(!filesystem::exists(name))
|
||||
return;
|
||||
|
||||
loadFromString(io::InputFileStream(name).readToString());
|
||||
}
|
||||
|
||||
void save(const std::string& name) const {
|
||||
std::ofstream fout(name);
|
||||
YAML::Node config;
|
||||
|
@ -165,15 +165,6 @@ public:
|
||||
// abort early to avoid potentially costly batching and translation before error message
|
||||
ABORT_IF(statFreq.unit != SchedulingUnit::updates, "Units other than 'u' are not supported for --stat-freq value {}", statFreq);
|
||||
|
||||
// Override display for progress heartbeat for MS-internal Philly compute cluster
|
||||
// otherwise this job may be killed prematurely if no log for 4 hrs
|
||||
if(getenv("PHILLY_JOB_ID")) { // this environment variable exists when running on the cluster
|
||||
if(statFreq.n == 0) {
|
||||
statFreq.n = 10000;
|
||||
statFreq.unit = SchedulingUnit::updates;
|
||||
}
|
||||
}
|
||||
|
||||
bool doNbest = options_->get<bool>("n-best");
|
||||
|
||||
bg.prepare();
|
||||
|
Loading…
Reference in New Issue
Block a user