mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 11929: Move around code to make later comparison with FP16 code easier
This does not introduce any new functionality, just moves code around, so that future PRs are easier to compare. Moving old GraphGroup code to training/deprecated. Once it is clear there is nothing in there that's worth saving, this will be deleted. Replace -Ofast with -O3 and make sure ffinite-math is turned off.
This commit is contained in:
parent
69d6f02711
commit
f1be95fce4
@ -167,9 +167,9 @@ else(MSVC)
|
||||
endif(CMAKE_COMPILER_IS_GNUCC)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "-std=c++11 -pthread ${CMAKE_GCC_FLAGS} -fPIC ${DISABLE_GLOBALLY} -march=${BUILD_ARCH} ${INTRINSICS}")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "-Ofast -m64 -funroll-loops -ffinite-math-only -g ${CMAKE_RDYNAMIC_FLAG}")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "-O3 -m64 -funroll-loops -g ${CMAKE_RDYNAMIC_FLAG}")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g ${CMAKE_RDYNAMIC_FLAG}")
|
||||
set(CMAKE_CXX_FLAGS_SLIM "-Ofast -m64 -funroll-loops -ffinite-math-only -DNDEBUG")
|
||||
set(CMAKE_CXX_FLAGS_SLIM "-O3 -m64 -funroll-loops -DNDEBUG")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELEASE}")
|
||||
set(CMAKE_CXX_FLAGS_PROFILE "${CMAKE_CXX_FLAGS_RELEASE} -pg")
|
||||
set(CMAKE_CXX_FLAGS_PROFGEN "${CMAKE_CXX_FLAGS_RELEASE} -fprofile-generate -fprofile-correction")
|
||||
@ -177,9 +177,9 @@ else(MSVC)
|
||||
|
||||
# these need to be set separately
|
||||
set(CMAKE_C_FLAGS "-pthread ${CMAKE_GCC_FLAGS} -fPIC ${DISABLE_GLOBALLY} -march=${BUILD_ARCH} ${INTRINSICS}")
|
||||
set(CMAKE_C_FLAGS_RELEASE "-O3 -m64 -funroll-loops -ffinite-math-only -g ${CMAKE_RDYNAMIC_FLAG}")
|
||||
set(CMAKE_C_FLAGS_RELEASE "-O3 -m64 -funroll-loops -g ${CMAKE_RDYNAMIC_FLAG}")
|
||||
set(CMAKE_C_FLAGS_DEBUG "-O0 -g ${CMAKE_RDYNAMIC_FLAG}")
|
||||
set(CMAKE_C_FLAGS_SLIM "-O3 -m64 -funroll-loops -ffinite-math-only -DNDEBUG")
|
||||
set(CMAKE_C_FLAGS_SLIM "-O3 -m64 -funroll-loops -DNDEBUG")
|
||||
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELEASE}")
|
||||
set(CMAKE_C_FLAGS_PROFILE "${CMAKE_C_FLAGS_RELEASE} -pg")
|
||||
set(CMAKE_C_FLAGS_PROFGEN "${CMAKE_C_FLAGS_RELEASE} -fprofile-generate -fprofile-correction")
|
||||
|
@ -85,11 +85,9 @@ add_library(marian STATIC
|
||||
translator/scorers.cpp
|
||||
|
||||
training/graph_group_async.cpp
|
||||
training/graph_group_async_drop.cpp
|
||||
training/graph_group_sync.cpp
|
||||
training/graph_group.cpp
|
||||
training/graph_group_singleton.cpp
|
||||
training/graph_group_multinode.cpp
|
||||
training/graph_group_multinode_sync.cpp
|
||||
training/validator.cpp
|
||||
training/communicator.cpp
|
||||
training/scheduler.cpp
|
||||
@ -145,8 +143,6 @@ cuda_add_library(marian_cuda
|
||||
tensors/gpu/cudnn_wrappers.cu
|
||||
translator/nth_element.cu
|
||||
translator/helpers.cu
|
||||
training/gradient_dropping/gpu/dropper.cu
|
||||
training/gradient_dropping/gpu/sparse_algorithm.cu
|
||||
STATIC)
|
||||
|
||||
target_compile_options(marian_cuda PUBLIC ${ALL_WARNINGS})
|
||||
|
@ -2,16 +2,10 @@
|
||||
#include "marian.h"
|
||||
|
||||
#include "training/graph_group_async.h"
|
||||
#include "training/graph_group_multinode_sync.h"
|
||||
#include "training/graph_group_singleton.h"
|
||||
#include "training/graph_group_sync.h"
|
||||
#include "training/training.h"
|
||||
|
||||
#ifdef CUDA_FOUND
|
||||
#include "training/graph_group_async_drop.h"
|
||||
#include "training/graph_group_multinode.h"
|
||||
#endif
|
||||
|
||||
#include "3rd_party/ExceptionWithCallStack.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
@ -27,18 +21,7 @@ int main(int argc, char** argv) {
|
||||
// MultiNodeGraphGroupSync.
|
||||
if(options->get<bool>("multi-node")) {
|
||||
LOG(warn, "[experimental] Using old multi-node training implementations that are not up-to-date");
|
||||
|
||||
if(options->get<bool>("sync-sgd")) {
|
||||
LOG(info, "Using multi-node synchronous training");
|
||||
New<Train<MultiNodeGraphGroupSync>>(options)->run();
|
||||
} else {
|
||||
#ifdef CUDA_FOUND
|
||||
LOG(info, "Using multi-node asynchronous training");
|
||||
New<Train<MultiNodeGraphGroup>>(options)->run();
|
||||
#else
|
||||
ABORT("Asynchronous multi-node training requires CUDA");
|
||||
#endif
|
||||
}
|
||||
ABORT("Old multi-node training code disabled");
|
||||
}
|
||||
// --sync-sgd always selects SyncGraphGroup
|
||||
//
|
||||
@ -46,7 +29,7 @@ int main(int argc, char** argv) {
|
||||
// processes x (single, multiple) GPUs per MPI process. This variant is presently up-to-date and
|
||||
// best supported.
|
||||
else if (options->get<bool>("sync-sgd")) {
|
||||
LOG(info, "Using synchronous training");
|
||||
LOG(info, "Using synchronous SGD");
|
||||
New<Train<SyncGraphGroup>>(options)->run();
|
||||
}
|
||||
else {
|
||||
@ -55,17 +38,8 @@ int main(int argc, char** argv) {
|
||||
LOG(info, "Using single-device training");
|
||||
New<Train<SingletonGraph>>(options)->run();
|
||||
} else {
|
||||
if(options->get<float>("grad-dropping-rate") > 0.0) {
|
||||
#ifdef CUDA_FOUND
|
||||
LOG(info, "Using asynchronous training with gradient dropping");
|
||||
New<Train<AsyncGraphGroupDrop>>(options)->run();
|
||||
#else
|
||||
ABORT("Asynchronous training with gradient dropping requires CUDA");
|
||||
#endif
|
||||
} else {
|
||||
LOG(info, "Using asynchronous training");
|
||||
New<Train<AsyncGraphGroup>>(options)->run();
|
||||
}
|
||||
LOG(info, "Using asynchronous training");
|
||||
New<Train<AsyncGraphGroup>>(options)->run();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -123,7 +123,7 @@ struct FbgemmPacked16PackNodeOp : public UnaryNodeOp {
|
||||
#endif // USE_FBGEMM
|
||||
}
|
||||
};
|
||||
;
|
||||
|
||||
// Pack a matrix (int8) into cache utilization efficient way (block format) together with quantization into int8
|
||||
// PackMatrix packMat_: the type of packed matrix - A or B matrix
|
||||
// marian::Type packType_: the type the input matrix is packed - packed8avx2 or packed8avx512
|
||||
@ -132,7 +132,6 @@ struct FbgemmPacked16PackNodeOp : public UnaryNodeOp {
|
||||
// int ncol_: the number of columns
|
||||
// uint64_t packsize_: the size of the packed matrix
|
||||
// (the size of int8 packed B from fbgemm:PackAWithQuantRowOffset + quantization scale, offset and zero point)
|
||||
|
||||
struct FbgemmPacked8PackNodeOp : public UnaryNodeOp {
|
||||
PackMatrix packMat_;
|
||||
marian::Type packType_;
|
||||
|
89
src/training/graph_group.cpp
Normal file
89
src/training/graph_group.cpp
Normal file
@ -0,0 +1,89 @@
|
||||
#include "training/graph_group.h"
|
||||
|
||||
namespace marian {
|
||||
|
||||
GraphGroup::GraphGroup(Ptr<Options> options) : options_(options), opt_(Optimizer(options)) {}
|
||||
|
||||
void GraphGroup::validate() {
|
||||
ABORT_IF(finalized_, "Training has already finished.");
|
||||
}
|
||||
|
||||
void GraphGroup::finalize() {
|
||||
finalized_ = true;
|
||||
}
|
||||
|
||||
Ptr<data::BatchStats> GraphGroup::collectStats(Ptr<ExpressionGraph> graph,
|
||||
Ptr<models::ICriterionFunction> model,
|
||||
const std::vector<Ptr<Vocab>>& vocabs,
|
||||
double multiplier) {
|
||||
auto stats = New<data::BatchStats>();
|
||||
|
||||
size_t numFiles = options_->get<std::vector<std::string>>("train-sets").size();
|
||||
|
||||
// Initialize first batch to step size
|
||||
size_t first = options_->get<size_t>("mini-batch-fit-step");
|
||||
|
||||
// Increase batch size and sentence length by this step size
|
||||
size_t step = options_->get<size_t>("mini-batch-fit-step");
|
||||
|
||||
size_t maxLength = options_->get<size_t>("max-length");
|
||||
maxLength = (size_t)(std::ceil(maxLength / (float)step) * step);
|
||||
|
||||
// this should be only one class label per line on input, hence restricting length to 1
|
||||
std::vector<size_t> localMaxes(numFiles, maxLength);
|
||||
auto inputTypes = options_->get<std::vector<std::string>>("input-types", {});
|
||||
for(int i = 0; i < inputTypes.size(); ++i)
|
||||
if(inputTypes[i] == "class")
|
||||
localMaxes[i] = 1;
|
||||
|
||||
size_t maxBatch = 512;
|
||||
bool fits = true;
|
||||
while(fits) {
|
||||
std::vector<size_t> lengths(numFiles, first);
|
||||
for(int j = 0; j < lengths.size(); ++j) // apply length restrictions
|
||||
lengths[j] = std::min(lengths[j], localMaxes[j]);
|
||||
|
||||
auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, maxBatch, options_);
|
||||
auto cost = model->build(graph, batch);
|
||||
fits = graph->fits();
|
||||
if(fits)
|
||||
maxBatch *= 2;
|
||||
}
|
||||
|
||||
// Do a binary search for maxmimum batch size that fits into given workspace memory
|
||||
// for a tested sentence length.
|
||||
for(size_t i = step; i <= maxLength; i += step) {
|
||||
size_t start = 1;
|
||||
size_t end = maxBatch;
|
||||
|
||||
std::vector<size_t> lengths(numFiles, i);
|
||||
for(int j = 0; j < lengths.size(); ++j) // apply length restrictions
|
||||
lengths[j] = std::min(lengths[j], localMaxes[j]);
|
||||
fits = true;
|
||||
|
||||
do {
|
||||
size_t current = (start + end) / 2;
|
||||
auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, current, options_);
|
||||
auto cost = model->build(graph, batch);
|
||||
fits = graph->fits();
|
||||
|
||||
LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits);
|
||||
|
||||
if(fits) {
|
||||
stats->add(batch, multiplier);
|
||||
start = current + 1;
|
||||
} else {
|
||||
end = current - 1;
|
||||
}
|
||||
} while(end - start > step);
|
||||
|
||||
maxBatch = start;
|
||||
}
|
||||
return stats;
|
||||
}
|
||||
|
||||
void GraphGroup::setTypicalTrgBatchWords(size_t typicalTrgBatchWords) { // needed for dynamic MB scaling
|
||||
typicalTrgBatchWords_ = typicalTrgBatchWords;
|
||||
}
|
||||
|
||||
}
|
@ -19,12 +19,14 @@ class GraphGroup {
|
||||
protected:
|
||||
Ptr<Options> options_;
|
||||
Ptr<OptimizerBase> opt_; // the optimizer
|
||||
|
||||
Ptr<Scheduler> scheduler_; // scheduler that keeps track of how much has been processed
|
||||
|
||||
bool finalized_{false}; // 'true' if training has completed (further updates are no longer allowed)
|
||||
size_t typicalTrgBatchWords_{ 0 }; // for dynamic batch sizing: typical batch size in words
|
||||
|
||||
public:
|
||||
GraphGroup(Ptr<Options> options) : options_(options), opt_(Optimizer(options)) {}
|
||||
GraphGroup(Ptr<Options> options);
|
||||
|
||||
virtual ~GraphGroup() {}
|
||||
|
||||
@ -34,13 +36,9 @@ public:
|
||||
|
||||
virtual void save(bool isFinal = false) = 0;
|
||||
|
||||
void validate() {
|
||||
ABORT_IF(finalized_, "Training has already finished.");
|
||||
}
|
||||
void validate();
|
||||
|
||||
virtual void finalize() {
|
||||
finalized_ = true;
|
||||
}
|
||||
virtual void finalize();
|
||||
|
||||
virtual void setScheduler(Ptr<Scheduler> scheduler) = 0;
|
||||
|
||||
@ -57,158 +55,9 @@ public:
|
||||
Ptr<data::BatchStats> collectStats(Ptr<ExpressionGraph> graph,
|
||||
Ptr<models::ICriterionFunction> model,
|
||||
const std::vector<Ptr<Vocab>>& vocabs,
|
||||
double multiplier = 1.) {
|
||||
auto stats = New<data::BatchStats>();
|
||||
double multiplier = 1.);
|
||||
|
||||
size_t numFiles = options_->get<std::vector<std::string>>("train-sets").size();
|
||||
|
||||
// Initialize first batch to step size
|
||||
size_t first = options_->get<size_t>("mini-batch-fit-step");
|
||||
|
||||
// Increase batch size and sentence length by this step size
|
||||
size_t step = options_->get<size_t>("mini-batch-fit-step");
|
||||
|
||||
size_t maxLength = options_->get<size_t>("max-length");
|
||||
maxLength = (size_t)(std::ceil(maxLength / (float)step) * step);
|
||||
|
||||
// this should be only one class label per line on input, hence restricting length to 1
|
||||
std::vector<size_t> localMaxes(numFiles, maxLength);
|
||||
auto inputTypes = options_->get<std::vector<std::string>>("input-types", {});
|
||||
for(int i = 0; i < inputTypes.size(); ++i)
|
||||
if(inputTypes[i] == "class")
|
||||
localMaxes[i] = 1;
|
||||
|
||||
size_t maxBatch = 512;
|
||||
bool fits = true;
|
||||
while(fits) {
|
||||
std::vector<size_t> lengths(numFiles, first);
|
||||
for(int j = 0; j < lengths.size(); ++j) // apply length restrictions
|
||||
lengths[j] = std::min(lengths[j], localMaxes[j]);
|
||||
|
||||
auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, maxBatch, options_);
|
||||
auto cost = model->build(graph, batch);
|
||||
fits = graph->fits();
|
||||
if(fits)
|
||||
maxBatch *= 2;
|
||||
}
|
||||
|
||||
// Do a binary search for maxmimum batch size that fits into given workspace memory
|
||||
// for a tested sentence length.
|
||||
for(size_t i = step; i <= maxLength; i += step) {
|
||||
size_t start = 1;
|
||||
size_t end = maxBatch;
|
||||
|
||||
std::vector<size_t> lengths(numFiles, i);
|
||||
for(int j = 0; j < lengths.size(); ++j) // apply length restrictions
|
||||
lengths[j] = std::min(lengths[j], localMaxes[j]);
|
||||
fits = true;
|
||||
|
||||
do {
|
||||
size_t current = (start + end) / 2;
|
||||
auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, current, options_);
|
||||
auto cost = model->build(graph, batch);
|
||||
fits = graph->fits();
|
||||
|
||||
LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits);
|
||||
|
||||
if(fits) {
|
||||
stats->add(batch, multiplier);
|
||||
start = current + 1;
|
||||
} else {
|
||||
end = current - 1;
|
||||
}
|
||||
} while(end - start > step);
|
||||
|
||||
maxBatch = start;
|
||||
}
|
||||
return stats;
|
||||
}
|
||||
|
||||
void setTypicalTrgBatchWords(size_t typicalTrgBatchWords) { // needed for dynamic MB scaling
|
||||
typicalTrgBatchWords_ = typicalTrgBatchWords;
|
||||
}
|
||||
void setTypicalTrgBatchWords(size_t typicalTrgBatchWords);
|
||||
};
|
||||
|
||||
/**
|
||||
* Base class for multi-node versions of GraphGroups.
|
||||
*/
|
||||
class MultiNodeGraphGroupBase : public GraphGroup {
|
||||
using Base = GraphGroup;
|
||||
|
||||
protected:
|
||||
Ptr<IMPIWrapper> mpi_; // all MPI-like communication goes through this
|
||||
|
||||
/** Devices (GPUs) on this node. */
|
||||
std::vector<size_t> devices_; // [num local GPUs]
|
||||
|
||||
/** Graph builders for clients (which run forward and backward passes). */
|
||||
std::vector<Ptr<models::ICriterionFunction>> clientBuilders_;
|
||||
|
||||
/** Graphs of clients. One entry per GPU on this node. */
|
||||
std::vector<Ptr<ExpressionGraph>> clientGraphs_; // [num local GPUs]
|
||||
|
||||
public:
|
||||
MultiNodeGraphGroupBase(Ptr<Options> options, Ptr<IMPIWrapper> mpi)
|
||||
: Base(options), mpi_(mpi) {
|
||||
|
||||
// Set up devices for this node
|
||||
std::vector<size_t> devices; // set of GPU device ids for this MPI process
|
||||
for (auto& d : Config::getDevices(options_))
|
||||
devices.push_back(d.no);
|
||||
loadDeviceConfig(devices); // set up numberClientsOfNodes_[] and devices_[]
|
||||
|
||||
// Create builders and graphs for clients; that is, for each GPU we use on this node.
|
||||
for (size_t i = 0; i < devices_.size(); i++) {
|
||||
clientGraphs_.push_back(New<ExpressionGraph>());
|
||||
clientGraphs_[i]->setDevice({ devices_[i], DeviceType::gpu });
|
||||
clientGraphs_[i]->reserveWorkspaceMB(options_->get<size_t>("workspace"));
|
||||
clientBuilders_.push_back(models::createCriterionFunctionFromOptions(options_, models::usage::training));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the GPU configuration of this node (i.e. which GPUs to use) and the
|
||||
* number of GPUs on the other nodes.
|
||||
*/
|
||||
// deviceConfig has this format
|
||||
// - for each node
|
||||
// - number of GPUs on that node
|
||||
// - GPU ids for that node
|
||||
// e.g. 0:0 1 1: 2 3 -> (2, (0, 1)) (2, (2,3))
|
||||
void loadDeviceConfig(std::vector<size_t> deviceConfig) {
|
||||
// parse device config array
|
||||
size_t index = 0; // cursor for next()
|
||||
auto next = [&]() { // helper function to get the next item
|
||||
ABORT_IF(index == deviceConfig.size(), "mal-formed device config array??");
|
||||
return deviceConfig[index++];
|
||||
};
|
||||
std::vector<std::vector<size_t>> allDevices(mpi_->numMPIProcesses());
|
||||
for (auto& devices : allDevices) {
|
||||
devices.resize(next());
|
||||
for (auto& device : devices)
|
||||
device = next();
|
||||
}
|
||||
ABORT_IF(index != deviceConfig.size(), "mal-formed device config array??");
|
||||
|
||||
// validate
|
||||
ABORT_IF(allDevices.front().size() == 0, "no devices specified??");
|
||||
for (auto& devices : allDevices) {
|
||||
ABORT_IF(devices.size() != allDevices.front().size(), "all MPI nodes must use the same number of devices");
|
||||
}
|
||||
|
||||
// get our own config
|
||||
devices_ = allDevices[mpi_->myMPIRank()];
|
||||
|
||||
// log
|
||||
LOG(info, "[mpi rank {}] device configuration", mpi_->myMPIRank());
|
||||
for (auto& device : devices_)
|
||||
LOG(info, "[mpi rank {}] - {}", mpi_->myMPIRank(), device);
|
||||
}
|
||||
|
||||
virtual void finalize() override {
|
||||
if (mpi_)
|
||||
finalizeMPI(std::move(mpi_));
|
||||
Base::finalize();
|
||||
}
|
||||
};
|
||||
} // namespace marian
|
||||
|
Loading…
Reference in New Issue
Block a user