Merged PR 25686: Loading checkpoints from main node only via MPI

Enables loading of model checkpoints from main node only via MPI.

Until now the checkpoint needed to be present in the same location on all nodes. That could be done either by writing to a shared filesystem (problematic due to unreliable syncing) or by manually copying it to the same local location, e.g. /tmp on each node (while writing only happened to one main location).

Now, Marian can resume training from a checkpoint in a single location on the main node. The remaining nodes do not need access to it. E.g. a local /tmp on the main node can be used, and race conditions on shared storage are avoided.

Also avoids creating log files on more than one node. This is somewhat fragile, as it is implemented via an environment-variable lookup.
This commit is contained in:
Marcin Junczys-Dowmunt 2022-09-21 20:39:54 +00:00
parent 76964791ad
commit 7d2045a907
22 changed files with 232 additions and 119 deletions

View File

@ -9,8 +9,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
- `--force-decode` option for marian-decoder
- `--output-sampling` now works with ensembles (requires proper normalization via e.g `--weights 0.5 0.5`)
### Fixed
- Read/restore checkpoints from main process only when training with MPI
- Multi-loss casts type to first loss-type before accumulation (aborted before due to missing cast)
- Throw `ShapeSizeException` if total expanded shape size exceeds numeric capacity of the maximum int value (2^31-1)
- During mini-batch-fitting, catch `ShapeSizeException` and use another sizing hint. Aborts outside mini-batch-fitting.

View File

@ -1 +1 @@
v1.11.7
v1.11.9

View File

@ -19,24 +19,6 @@ static inline std::string interpolateEnvVars(std::string str) {
return str;
}
#if 1
if(getenv("PHILLY_JOB_ID")) {
const char* cluster = getenv("PHILLY_CLUSTER");
const char* vc = getenv("PHILLY_VC");
// this environment variable exists when running on the cluster
if(cluster && vc) {
static const std::string s_gfsPrefix
= std::string("/gfs/") + cluster + "/" + vc + "/";
static const std::string s_hdfsPrefix
= std::string("/hdfs/") + cluster + "/" + vc + "/";
if(str.find(s_gfsPrefix) == 0)
str = std::string("/hdfs/") + vc + "/" + str.substr(s_gfsPrefix.size());
else if(str.find(s_hdfsPrefix) == 0)
str = std::string("/hdfs/") + vc + "/"
+ str.substr(s_hdfsPrefix.size());
}
}
#endif
for(;;) {
const auto pos = str.find("${");
if(pos == std::string::npos)

View File

@ -74,6 +74,12 @@ std::string InputFileStream::getFileName() const {
return file_.string();
}
// Drains the entire remaining contents of this stream into a single string.
std::string InputFileStream::readToString() const {
  std::stringstream buffer;
  buffer << this->rdbuf(); // stream the whole underlying buffer in one go
  return buffer.str();
}
// wrapper around std::getline() that handles Windows input files with extra CR
// chars at the line end
std::istream &getline(std::istream &in, std::string &line) {
@ -85,6 +91,7 @@ std::istream &getline(std::istream &in, std::string &line) {
line.pop_back();
return in;
}
///////////////////////////////////////////////////////////////////////////////////////////////
OutputFileStream::OutputFileStream(const std::string &file)
: std::ostream(NULL), file_(file) {
@ -119,7 +126,7 @@ TemporaryFile::TemporaryFile(const std::string &base, bool earlyUnlink)
NormalizeTempPrefix(baseTemp);
MakeTemp(baseTemp);
inSteam_ = UPtr<io::InputFileStream>(new io::InputFileStream(file_.string()));
inStream_ = UPtr<io::InputFileStream>(new io::InputFileStream(file_.string()));
if(unlink_) {
ABORT_IF(remove(file_.string().c_str()), "Error while deleting '{}'", file_.string());
}
@ -190,7 +197,7 @@ void TemporaryFile::MakeTemp(const std::string &base) {
}
UPtr<InputFileStream> TemporaryFile::getInputStream() {
return std::move(inSteam_);
return std::move(inStream_);
}
std::string TemporaryFile::getFileName() const {

View File

@ -46,6 +46,7 @@ public:
bool empty();
void setbufsize(size_t size);
std::string getFileName() const;
std::string readToString() const;
protected:
marian::filesystem::Path file_;
@ -92,7 +93,7 @@ public:
protected:
bool unlink_;
UPtr<InputFileStream> inSteam_;
UPtr<InputFileStream> inStream_;
void NormalizeTempPrefix(std::string& base) const;
void MakeTemp(const std::string& base);

View File

@ -1,5 +1,7 @@
#include "logging.h"
#include "common/config.h"
#include "common/utils.h"
#include "spdlog/sinks/null_sink.h"
#include "3rd_party/ExceptionWithCallStack.h"
#include <time.h>
@ -30,9 +32,14 @@ std::shared_ptr<spdlog::logger> createStderrLogger(const std::string& name,
if(!quiet)
sinks.push_back(stderr_sink);
for(auto&& file : files) {
auto file_sink = std::make_shared<spdlog::sinks::simple_file_sink_st>(file, true);
sinks.push_back(file_sink);
// @TODO: think how to solve this better than using OMPI_COMM_WORLD_RANK env variable
// only create output files if we are the main process or if MPI rank is not defined
int rank = marian::utils::getMPIRankEnv(); // this function looks up OMPI_COMM_WORLD_RANK env variable
if(rank == 0) {
for(auto&& file : files) {
auto file_sink = std::make_shared<spdlog::sinks::simple_file_sink_st>(file, true);
sinks.push_back(file_sink);
}
}
auto logger = std::make_shared<spdlog::logger>(name, begin(sinks), end(sinks));

View File

@ -180,7 +180,8 @@ std::string exec(const std::string& cmd, const std::vector<std::string>& args /*
std::pair<std::string, int> hostnameAndProcessId() { // helper to get hostname:pid
#ifdef _WIN32
std::string hostname = getenv("COMPUTERNAME");
const char* res = getenv("COMPUTERNAME");
std::string hostname = res ? std::string(res) : "";
auto processId = (int)GetCurrentProcessId();
#else
static std::string hostname = []() { // not sure if gethostname() is expensive. This way we call it only once.
@ -193,6 +194,15 @@ std::pair<std::string, int> hostnameAndProcessId() { // helper to get hostname:
return {hostname, processId};
}
// Returns the MPI rank of this process as published by OpenMPI via the
// OMPI_COMM_WORLD_RANK environment variable; returns 0 when the variable is
// not set (i.e. not running under MPI, or this is the sole/main process).
int getMPIRankEnv() {
  if(const char* rankStr = getenv("OMPI_COMM_WORLD_RANK"))
    return std::atoi(rankStr);
  return 0;
}
// format a long number with comma separators
std::string withCommas(size_t n) {
std::string res = std::to_string(n);

View File

@ -43,6 +43,9 @@ std::string exec(const std::string& cmd, const std::vector<std::string>& args =
std::pair<std::string, int> hostnameAndProcessId();
// returns MPI rank from environment variable if set, otherwise 0
int getMPIRankEnv();
std::string withCommas(size_t n);
bool beginsWith(const std::string& text, const std::string& prefix);
bool endsWith(const std::string& text, const std::string& suffix);

View File

@ -141,16 +141,6 @@ public:
output->Write((long)batch->getSentenceIds()[i],
sentVector);
}
// progress heartbeat for MS-internal Philly compute cluster
// otherwise this job may be killed prematurely if no log for 4 hrs
if (getenv("PHILLY_JOB_ID") // this environment variable exists when running on the cluster
&& id % 1000 == 0) // hard beat once every 1000 batches
{
auto progress = id / 10000.f; //fake progress for now, becomes >100 after 1M batches
fprintf(stderr, "PROGRESS: %.2f%%\n", progress);
fflush(stderr);
}
};
pool.enqueue(task, batchId++);

View File

@ -75,6 +75,10 @@ public:
return Logits(apply(graph, batch, inference_));
}
// Loading a serialized MNIST model is intentionally unsupported; log a
// critical error instead of silently ignoring the request.
void load(Ptr<ExpressionGraph> /*graph*/, const std::vector<io::Item>& /*items*/, bool) override {
LOG(critical, "Loading MNIST model is not supported");
}
// Same as above for the filename-based overload.
void load(Ptr<ExpressionGraph> /*graph*/, const std::string& /*name*/, bool) override {
LOG(critical, "Loading MNIST model is not supported");
}

View File

@ -217,6 +217,12 @@ public:
Ptr<IModel> getModel() { return model_; }
// Forwards loading of pre-parsed io::Items to the wrapped model.
void load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded) override {
model_->load(graph, items, markedReloaded);
}
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded = true) override {
@ -263,6 +269,12 @@ public:
Ptr<IModel> getModel() { return model_; }
virtual void load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markReloaded = true) override {
model_->load(graph, items, markReloaded);
}
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded = true) override {

View File

@ -152,6 +152,12 @@ public:
void push_back(Ptr<EncoderBase> encoder) { encoders_.push_back(encoder); }
void push_back(Ptr<ClassifierBase> classifier) { classifiers_.push_back(classifier); }
void load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded) override {
graph->load(items, markedReloaded && !opt<bool>("ignore-model-config", false));
}
void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded) override {

View File

@ -15,7 +15,8 @@ public:
virtual void load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded = true) = 0;
bool markedReloaded = true) override
= 0;
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,

View File

@ -25,6 +25,11 @@ class EncoderPoolerBase : public models::IModel {
public:
virtual ~EncoderPoolerBase() {}
virtual void load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded = true) override
= 0;
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded = true) override
@ -162,6 +167,12 @@ public:
void push_back(Ptr<EncoderBase> encoder) { encoders_.push_back(encoder); }
void push_back(Ptr<PoolerBase> pooler) { poolers_.push_back(pooler); }
void load(Ptr<ExpressionGraph> graph,
const std::vector<io::Item>& items,
bool markedReloaded) override {
graph->load(items, markedReloaded && !opt<bool>("ignore-model-config", false));
}
void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded) override {

View File

@ -2,6 +2,7 @@
#include <string>
#include "marian.h"
#include "common/io_item.h"
#include "layers/loss.h"
#include "layers/generic.h"
@ -24,6 +25,12 @@ public:
const std::string&,
bool markReloaded = true)
= 0;
virtual void load(Ptr<ExpressionGraph>,
const std::vector<io::Item>&,
bool markReloaded = true)
= 0;
virtual void save(Ptr<ExpressionGraph>,
const std::string&,
bool saveTranslatorConfig = false)
@ -47,6 +54,12 @@ public:
const std::string&,
bool markReloaded = true)
= 0;
virtual void load(Ptr<ExpressionGraph>,
const std::vector<io::Item>&,
bool markReloaded = true)
= 0;
virtual void save(Ptr<ExpressionGraph>,
const std::string&,
bool saveTranslatorConfig = false)

View File

@ -201,16 +201,6 @@ public:
}
}
}
// progress heartbeat for MS-internal Philly compute cluster
// otherwise this job may be killed prematurely if no log for 4 hrs
if (getenv("PHILLY_JOB_ID") // this environment variable exists when running on the cluster
&& id % 1000 == 0) // hard beat once every 1000 batches
{
auto progress = id / 10000.f; //fake progress for now, becomes >100 after 1M batches
fprintf(stdout, "PROGRESS: %.2f%%\n", progress);
fflush(stdout);
}
};
pool.enqueue(task, batchId++);

View File

@ -134,11 +134,12 @@ public:
// get the limit for int count
size_t limit = (size_t)std::numeric_limits<int>::max();
size_t remaining = count, offset = 0;
size_t remaining = count;
size_t offset = 0;
// while there are elements that we have not sent yet, loop until all has been sent in chunks of at most `limit`.
while(remaining > 0) {
int intCount = (int)std::min(remaining, limit);
int intCount = (int)std::min(remaining, limit);
HANDLE_MPI_ERROR(MPI_Bcast((char*)buf + offset * (size_t)datatypeSize, intCount, datatype, (int)rootRank, comm));
offset += (size_t)intCount;
remaining -= (size_t)intCount;
@ -193,6 +194,49 @@ public:
virtual void finalize() override {
HANDLE_MPI_ERROR(MPI_Finalize());
}
// Broadcasts a single io::Item from rootRank to all other MPI processes.
// The transfer is a fixed sequence of bCast calls (lengths first, then the
// data); every process must execute them in exactly this order, so do not
// reorder these statements.
virtual void bCast(io::Item& item, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const override {
if(isMainProcess())
ABORT_IF(item.bytes.empty(), "Broadcasting empty item via MPI should not happen. Please report.");
// broadcast payload length, resize receivers' buffers, then the bytes
unsigned long long bytesLen = item.bytes.size();
bCast(&bytesLen, 1, getDataType(&bytesLen), rootRank, comm);
item.bytes.resize(bytesLen); // no-op on the sender, allocates on receivers
bCast(item.bytes.data(), bytesLen, getDataType(item.bytes.data()), rootRank, comm);
// broadcast the shape the same way (length, then dimensions)
unsigned long long shapeLen = item.shape.size();
bCast(&shapeLen, 1, getDataType(&shapeLen), rootRank, comm);
item.shape.resize(shapeLen);
bCast(item.shape.data(), shapeLen, getDataType(item.shape.data()), rootRank, comm);
// broadcast the item name via the string overload
bCast(item.name, rootRank, comm);
// broadcast the element type; widened to size_t so a known MPI datatype applies
size_t type = (size_t)item.type;
bCast(&type, 1, getDataType(&type), rootRank, comm);
item.type = (Type)type;
}
// Broadcasts a vector of io::Items from rootRank to all other MPI processes:
// first the item count (so receivers can resize), then each item in order.
virtual void bCast(std::vector<io::Item>& items, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const override {
size_t numItems = 0;
if(isMainProcess())
numItems = items.size(); // only the sender knows the true count
bCast(&numItems, 1, getDataType(&numItems), rootRank, comm);
items.resize(numItems); // no-op on the sender, allocates on receivers
for(auto& item : items)
bCast(item, rootRank, comm);
}
// Broadcasts a string from rootRank to all other MPI processes: length first,
// then the raw characters. Uses the non-const str.data() overload (C++17).
virtual void bCast(std::string& str, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const override {
size_t length = 0;
if(isMainProcess())
length = str.size(); // only the sender knows the true length
bCast(&length, 1, getDataType(&length), rootRank, comm);
str.resize(length); // no-op on the sender, allocates on receivers
bCast(str.data(), length, getDataType(str.data()), rootRank, comm);
}
};
#endif
@ -232,6 +276,19 @@ public:
// to only accept one parameter, and remove this error check can be removed.
ABORT_IF(sendbuf != recvbuf, "FakeMPIWrapper::allReduce() only implemented for in-place operation"); // otherwise it's not a no-op, we must copy data
}
// No-op broadcast implementations for the fake (single-process) MPI wrapper:
// with only one process there is nobody to broadcast to, so the data is left
// untouched. The (void) casts portably silence unused-parameter warnings on
// GCC/Clang as well as MSVC (bare `item;` statements only work for MSVC and
// trigger -Wunused-value elsewhere).
virtual void bCast(io::Item& item, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const override {
(void)item; (void)rootRank; (void)comm;
}
virtual void bCast(std::vector<io::Item>& items, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const override {
(void)items; (void)rootRank; (void)comm;
}
virtual void bCast(std::string& str, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const override {
(void)str; (void)rootRank; (void)comm;
}
#pragma warning(pop)
virtual void finalize() override { }
};

View File

@ -68,7 +68,7 @@ public:
#if MPI_FOUND
#else
enum MPI_Comm { MPI_COMM_WORLD };
enum MPI_Datatype { MPI_FLOAT, MPI_UNSIGNED_LONG_LONG, MPI_UNSIGNED_LONG, MPI_BYTE, MPI_INT };
enum MPI_Datatype { MPI_FLOAT, MPI_UNSIGNED_LONG_LONG, MPI_UNSIGNED_LONG, MPI_BYTE, MPI_INT, MPI_CXX_BOOL };
enum MPI_Op { MPI_SUM };
struct MPI_Status { int MPI_SOURCE; };
#define MPI_ANY_SOURCE ((size_t)-2)
@ -88,30 +88,16 @@ struct/*interface*/ IMPIWrapper {
static const size_t RECV_ANY_SOURCE = (size_t)MPI_ANY_SOURCE;
static MPI_Datatype getDataType(const char*) { return MPI_BYTE; }
static MPI_Datatype getDataType(const bool*) { return MPI_CXX_BOOL; }
static MPI_Datatype getDataType(const int*) { return MPI_INT; }
static MPI_Datatype getDataType(const float*) { return MPI_FLOAT; }
static MPI_Datatype getDataType(const unsigned long*) { return MPI_UNSIGNED_LONG; }
static MPI_Datatype getDataType(const unsigned long long*) { return MPI_UNSIGNED_LONG_LONG; }
void bCast(io::Item& item, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) {
ABORT_IF(item.bytes.empty(), "Broadcasting empty item via MPI??");
unsigned long long bytesLen = item.bytes.size();
bCast(&bytesLen, 1, getDataType(&bytesLen), rootRank, comm);
item.bytes.resize(bytesLen);
bCast(item.bytes.data(), item.bytes.size(), getDataType(item.bytes.data()), rootRank, comm);
unsigned long long shapeLen = item.shape.size();
bCast(&shapeLen, 1, getDataType(&shapeLen), rootRank, comm);
bCast(item.shape.data(), item.shape.size(), getDataType(item.shape.data()), rootRank, comm);
size_t type = (size_t)item.type;
bCast(&type, 1, getDataType(&type), rootRank, comm);
item.type = (Type)type;
}
virtual void bCast(io::Item& item, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const = 0;
virtual void bCast(std::vector<io::Item>& items, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const = 0;
virtual void bCast(std::string& str, size_t rootRank = 0, MPI_Comm comm = MPI_COMM_WORLD) const = 0;
std::string idStr() const;
};

View File

@ -283,25 +283,51 @@ void GraphGroup::load(const OptimizerBase::ScatterStateFunc& scatterFn) {
*/
if(!options_->get<bool>("no-reload")) {
std::string modelFileName = options_->get<std::string>("model");
bool foundModel = false;
if(filesystem::exists(modelFileName)) {
// these are structures that get fill in the main process and then broadcasted to other MPI
std::vector<io::Item> items;
bool markReloaded = true;
if(isMainProcess()) {
if(filesystem::exists(modelFileName)) {
LOG(info, "Loading model from {}", modelFileName);
foundModel = true;
items = io::loadItems(modelFileName);
markReloaded = true;
} else if(options_->hasAndNotEmpty("pretrained-model")) {
std::string pretrainedModelFileName = options_->get<std::string>("pretrained-model");
LOG(info, "[training] Initializing model weights with pre-trained model {}", pretrainedModelFileName);
foundModel = true;
items = io::loadItems(pretrainedModelFileName);
markReloaded = false;
}
}
// if a model file exists, the main process will find it and propagate this information to other MPI nodes
if(mpi_)
mpi_->bCast(&foundModel, 1, mpi_->getDataType(&foundModel));
if(foundModel) {
// continue with checkpoint loading
if(mpi_) {
// broadcast model information to other processes
mpi_->bCast(items);
mpi_->bCast(&markReloaded, 1, mpi_->getDataType(&markReloaded));
}
// handles MPI
if(scheduler_)
scheduler_->load(modelFileName);
// we just load it N times from disk (it'll be in disk cache after the first)
// this also allocates memory correctly when calling forward() inside restoreFromCheckPoint
size_t i = 0;
for(auto graph : graphs_)
models_[i++]->load(graph, modelFileName);
models_[i++]->load(graph, items, markReloaded);
// try to restore everything from checkpoint now
restoreFromCheckpoint(modelFileName, scatterFn);
} else if(options_->hasAndNotEmpty("pretrained-model")) {
std::string nameInit = options_->get<std::string>("pretrained-model");
LOG(info, "[training] Initializing model weights with pre-trained model {}", nameInit);
size_t i = 0;
for(auto graph : graphs_)
models_[i++]->load(graph, nameInit, false);
}
}
}
@ -316,19 +342,26 @@ bool GraphGroup::restoreFromCheckpoint(const std::string& modelFileName,
std::string checkpointName = modelFileName + ".optimizer.npz"; // @TODO: change to .checkpoint.npz, would break backwards compat
if(!filesystem::exists(checkpointName)) {
// if a checkpoint exists, the main process will find it and propagate this information to other MPI nodes
bool foundCheckpoint = filesystem::exists(checkpointName);
if(mpi_)
mpi_->bCast(&foundCheckpoint, 1, mpi_->getDataType(&foundCheckpoint));
// all nodes will either continue or exit
if(!foundCheckpoint) {
LOG(warn, "No checkpoint found, parameters reloaded from last inference model");
return false; // failed to restore
}
auto items = io::loadItems(checkpointName);
// make sure all nodes see the same checkpoint data, may not be the case with distributed file systems
// when there was a delay in updating the caches accross nodes. So here node 0 sends its data to all.
// We still load them all from disk, but that serves more as a trick to allocate the correct memory.
if(mpi_)
for(auto& item : items)
mpi_->bCast(item);
std::vector<marian::io::Item> items;
// make sure all nodes receive the same checkpoint data from the main process.
if(mpi_) { // only the main process loads the checkpoint and the rest receives a copy
if(isMainProcess())
items = io::loadItems(checkpointName);
mpi_->bCast(items);
} else { // not doing MPI, so just load the checkpoint from disk
items = io::loadItems(checkpointName);
}
// @TODO: probably we want to have the list of DeviceIds as an attribute
std::vector<Ptr<Backend>> backends;
@ -351,7 +384,8 @@ bool GraphGroup::restoreFromCheckpoint(const std::string& modelFileName,
// run a full forward pass over the paramters to allocate the parameters values in order (by parameter name).
// Just doing graph->params()->allocateForward() is not sufficient.
ABORT_IF(graph->params()->vals()->shape() != masterParameters.shape,
"Graph parameter sizes and master copy parameter sizes in checkpoint do not match");
"Graph parameter sizes and master copy parameter sizes in checkpoint do not match ({} != {})",
graph->params()->vals()->shape(), masterParameters.shape);
// Convert type of io::Item to match graph parameter type.
if(masterParameters.type != graph->params()->vals()->type())

View File

@ -478,24 +478,11 @@ public:
state_->samplesDisp = 0;
state_->wordsDisp = 0;
}
// progress heartbeat for MS-internal Philly compute cluster
// This environment variable exists when running on the cluster.
using namespace std::chrono;
if((!mpi_ || mpi_->myMPIRank() == 0) && getenv("PHILLY_JOB_ID")
&& heartBeatTimer_.elapsed<std::chrono::minutes>() >= 30) {
fprintf(stderr, "PROGRESS: %.2f%%\nEVALERR: %.7f%%\n",
(double)calculateLogicalEpoch(),
state_->costSum / (state_->costCount ? state_->costCount : 1));
fflush(stderr);
heartBeatTimer_.start();
}
}
void load(const std::string& name) {
std::string nameYaml = name + ".progress.yml";
if(filesystem::exists(nameYaml))
state_->load(nameYaml);
void loadFromString(const std::string yamlString) {
if(!yamlString.empty())
state_->loadFromString(yamlString);
if(options_->get<bool>("no-restore-corpus")) {
state_->samplesEpoch = 0;
@ -519,6 +506,19 @@ public:
state_->newLoad();
}
// Restores scheduler progress from `<name>.progress.yml`. Only the main
// process reads the file; the YAML contents are then broadcast to all other
// MPI processes so every node restores identical state.
void load(const std::string& name) {
std::string nameYaml = name + ".progress.yml";
std::string yamlStr;
// When running without MPI, mpi_ is null and this process is trivially the
// main one. BUG FIX: check mpi_ before dereferencing it — the original
// called mpi_->isMainProcess() unconditionally, crashing non-MPI runs
// (the `if(mpi_)` guard below shows mpi_ may legitimately be null).
if(!mpi_ || mpi_->isMainProcess())
if(filesystem::exists(nameYaml))
yamlStr = io::InputFileStream(nameYaml).readToString();
// propagate the (possibly empty) YAML string to the other processes
if(mpi_)
mpi_->bCast(yamlStr);
loadFromString(yamlStr);
}
void save(const std::string& name) {
// Save config options
std::ofstream fout(name + ".yml");

View File

@ -1,11 +1,12 @@
#pragma once
#include "common/definitions.h"
#include "common/file_stream.h"
#include "common/filesystem.h"
#include "common/scheduling_parameter.h"
#include "common/utils.h"
#include <fstream>
#include <sstream>
#include <vector>
namespace marian {
@ -194,11 +195,8 @@ public:
}
}
void load(const std::string& name) {
if(!filesystem::exists(name))
return;
YAML::Node config = YAML::LoadFile(name);
void loadFromString(const std::string& yamlString) {
YAML::Node config = YAML::Load(yamlString);
epochs = config["epochs"].as<size_t>();
batches = config["batches"].as<size_t>();
@ -242,6 +240,13 @@ public:
seedCorpus = config["seed-corpus"].as<std::string>();
}
void load(const std::string& name) {
if(!filesystem::exists(name))
return;
loadFromString(io::InputFileStream(name).readToString());
}
void save(const std::string& name) const {
std::ofstream fout(name);
YAML::Node config;

View File

@ -165,15 +165,6 @@ public:
// abort early to avoid potentially costly batching and translation before error message
ABORT_IF(statFreq.unit != SchedulingUnit::updates, "Units other than 'u' are not supported for --stat-freq value {}", statFreq);
// Override display for progress heartbeat for MS-internal Philly compute cluster
// otherwise this job may be killed prematurely if no log for 4 hrs
if(getenv("PHILLY_JOB_ID")) { // this environment variable exists when running on the cluster
if(statFreq.n == 0) {
statFreq.n = 10000;
statFreq.unit = SchedulingUnit::updates;
}
}
bool doNbest = options_->get<bool>("n-best");
bg.prepare();