Merged PR 23429: Small fixes around fp16 training and batch fitting

This PR introduces small fixes around fp16 training and batch fitting:
* Multi-loss now casts partial losses to the type of the first loss before accumulation (previously this aborted due to a missing cast)
* Throw `ShapeSizeException` if the total expanded shape size exceeds the numeric capacity of `int` (2^31-1)
* During mini-batch fitting, catch `ShapeSizeException` and try another sizing hint; outside of mini-batch fitting the exception still aborts.
* A negative `--workspace -N` value allocates the workspace as the total available GPU memory minus N megabytes (see the worked example below).
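
For example (illustrative numbers, not taken from this PR): on a GPU that reports 16384 MB of total memory, `--workspace -4096` reserves roughly 16384 - 4096 = 12288 MB as workspace, leaving about 4 GB for model parameters, optimizer state and other allocations that live outside the workspace.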
Marcin Junczys-Dowmunt 2022-04-11 20:19:58 +00:00
parent 1e4e1014ed
commit 1a74358277
14 changed files with 92 additions and 22 deletions

View File

@@ -11,11 +11,15 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
### Added
### Fixed
- Multi-loss casts type to first loss-type before accumulation (aborted before due to missing cast)
- Throw `ShapeSizeException` if total expanded shape size exceeds numeric capacity of the maximum int value (2^31-1)
- During mini-batch-fitting, catch `ShapeSizeException` and use another sizing hint. Aborts outside mini-batch-fitting.
- Fix incorrect/missing gradient accumulation with delay > 1 or large effective batch size of biases of affine operations.
- Fixed case augmentation with multi-threaded reading.
- Scripts using PyYAML now use `safe_load`; see https://msg.pyyaml.org/load
### Changed
- Negative `--workspace -N` value allocates workspace as total available GPU memory minus N megabytes.
- Set default parameters for cost-scaling to 8.f 10000 1.f 8.f, i.e. when scaling, scale by 8 and do not try to automatically scale up or down. This seems most stable.
- Make guided-alignment faster via sparse memory layout, add alignment points for EOS, remove losses other than ce.
- Changed minimal C++ standard to C++-17

View File

@@ -1 +1 @@
v1.11.6
v1.11.7

View File

@@ -118,8 +118,8 @@ void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) {
->implicit_val("basic");
cli.add<std::vector<std::string>>("--config,-c",
"Configuration file(s). If multiple, later overrides earlier");
cli.add<size_t>("--workspace,-w",
"Preallocate arg MB of work space",
cli.add<int>("--workspace,-w",
"Preallocate arg MB of work space. Negative `--workspace -N` value allocates workspace as total available GPU memory minus N megabytes.",
defaultWorkspace);
cli.add<std::string>("--log",
"Log training process information to file given by arg");

View File

@@ -12,6 +12,26 @@
namespace marian {
class ShapeSizeException : public std::exception {
private:
char* message_;
public:
ShapeSizeException(size_t available, size_t asked) {
std::string mstr = "Expanded shape size " + std::to_string(asked)
+ " exceeds numeric capacity " + std::to_string(available);
message_ = new char[mstr.size() + 1];
std::copy(mstr.begin(), mstr.end(), message_);
message_[mstr.size()] = 0;
}
~ShapeSizeException() { delete[] message_; }
virtual const char* what() const noexcept override { return message_; }
};
struct Slice // Python-like slice/index descriptor
{
Slice(int b, int e, int s) : begin(b), end(e), stride(s) {}
@@ -110,10 +130,12 @@ public:
template<typename T = int> // using a template so that FactoredSegmenter, which uses this as well, can pass size_t
inline T elements() const {
T el = 1;
size_t el = 1;
for(auto s : shape_)
el *= (T)s;
return el;
el *= (size_t)s;
if(el > std::numeric_limits<T>::max())
throw ShapeSizeException(std::numeric_limits<T>::max(), el);
return (T)el;
}
inline void dims(int i, std::vector<int>& d) const {
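
To make the new check concrete, here is a minimal standalone sketch of the same overflow guard (illustrative only; `elementsOrThrow` is a made-up name, not the actual `marian::Shape` code): the element count is accumulated in `size_t` and only narrowed back to `int` if it still fits.

```cpp
#include <cstddef>
#include <limits>
#include <stdexcept>
#include <vector>

// Accumulate the product of all dimensions in size_t, then refuse to narrow it
// to int if the total exceeds std::numeric_limits<int>::max() (2^31-1).
int elementsOrThrow(const std::vector<int>& dims) {
  size_t el = 1;
  for(int d : dims)
    el *= (size_t)d;
  if(el > (size_t)std::numeric_limits<int>::max())
    throw std::runtime_error("expanded shape size exceeds numeric capacity of int");
  return (int)el;
}

// elementsOrThrow({4096, 4096, 127}) -> 2,130,706,432 elements, fits
// elementsOrThrow({4096, 4096, 129}) -> 2,164,260,864 elements, throws
```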

View File

@@ -84,7 +84,7 @@ public:
auto precison = options_->get<std::vector<std::string>>("precision", {"float32"});
graph->setDefaultElementType(typeFromString(precison[0])); // only use first type, used for parameter type in graph
graph->setDevice(device);
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graph->reserveWorkspaceMB(options_->get<int>("workspace"));
graphs_.push_back(graph);
}

View File

@@ -23,6 +23,23 @@ void ExpressionGraph::setDevice(DeviceId deviceId, Ptr<Device> device) {
}
}
void ExpressionGraph::reserveWorkspaceMB(int num) {
size_t bytes;
if(num > 0) {
bytes = (size_t)num * 1024 * 1024 - 1;
} else if (num < 0) {
ABORT_IF(getDeviceId().type == DeviceType::cpu, "Negative workspace not allowed on CPU device");
size_t globalMemorySize = backend_->getGlobalMemorySize(); // in bytes, only implemented for GPU backend
size_t notWorkspaceSize = (size_t)std::abs(num) * 1024 * 1024 - 1;
ABORT_IF(notWorkspaceSize >= globalMemorySize, "Negative workspace {} larger/equal total memory {}?", notWorkspaceSize, globalMemorySize);
bytes = globalMemorySize - notWorkspaceSize;
LOG(debug, "Reserving {} = {} - {} bytes as workspace", bytes, globalMemorySize, notWorkspaceSize);
} else {
ABORT("Allocating 0 bytes?");
}
tensors_->reserve(bytes);
}
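// Worked example (illustrative numbers, not part of this commit): with num = -500 on a
// device whose backend reports globalMemorySize = 8 GiB = 8,589,934,592 bytes:
//   notWorkspaceSize = 500 * 1024 * 1024 - 1 = 524,287,999 bytes
//   bytes            = 8,589,934,592 - 524,287,999 = 8,065,646,593 bytes (~7.5 GiB of workspace)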
Expr ExpressionGraph::add(Expr node) {
auto found = tensors_->findOrRemember(node);
if(found) {

View File

@@ -244,11 +244,10 @@ public:
* Preallocate workspace memory (MB) for the graph.
* Sets the size of the memory available for the forward and backward step of the training procedure.
* This does not include the model size and optimizer parameters that are allocated outside the workspace.
* If memory is negative (<0), the total available GPU memory is used with the absolute value subtracted.
* Negative workspace is not supported on CPU.
*/
void reserveWorkspaceMB(size_t num) {
size_t bytes = num * 1024 * 1024 - 1;
tensors_->reserve(bytes);
}
void reserveWorkspaceMB(int num);
/** Copy tensor objects from one graph to current graph */
void reuseWorkspace(Ptr<ExpressionGraph> graph) {
@@ -277,7 +276,7 @@ public:
tensors_->throwAtReallocation(true);
backprop();
tensors_->throwAtReallocation(false);
} catch(AllocationException&) {
} catch(const AllocationException&) {
tensors_->throwAtReallocation(false);
return false;
}

View File

@@ -53,9 +53,9 @@ static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> graph,
auto attentionAtAligned = cols(flatten(attention), alignmentIndices);
float epsilon = 1e-6f;
Expr alignmentLoss = -sum(alignmentValues * log(attentionAtAligned + epsilon));
Expr alignmentLoss = -sum(cast(alignmentValues * log(attentionAtAligned + epsilon), Type::float32));
size_t numLabels = alignmentIndices->shape().elements();
// Create label node, also weigh by scalar so labels and cost are in the same domain.
// Fractional label counts are OK. But only if combined as "sum".
// @TODO: It is ugly to check the multi-loss type here, but doing this right requires
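
The explicit cast to `Type::float32` above is related to the first bullet of this PR: under fp16 training, partial losses can end up with different element types, and they can only be accumulated once they share one. A minimal sketch of that idea (a hypothetical helper using marian-style names, not the actual `MultiRationalLoss` code):

```cpp
// Sketch: cast every partial loss to the element type of the first loss before
// summing, so fp16 and fp32 partial losses can be accumulated without aborting.
Expr accumulateLosses(const std::vector<Expr>& partialLosses) {
  Expr total = partialLosses.front();
  for(size_t i = 1; i < partialLosses.size(); ++i)
    total = total + cast(partialLosses[i], total->value_type()); // cast to first loss-type
  return total;
}
```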

View File

@@ -73,7 +73,7 @@ public:
auto precison = options_->get<std::vector<std::string>>("precision", {"float32"});
graph->setDefaultElementType(typeFromString(precison[0])); // only use first type, used for parameter type in graph
graph->setDevice(device);
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graph->reserveWorkspaceMB(options_->get<int>("workspace"));
graphs_.push_back(graph);
}

View File

@@ -29,6 +29,7 @@ public:
// for GPU only, calls cudaSetDevice, does nothing on CPU. Maybe change name.
virtual void setDevice() = 0;
virtual void synchronize() = 0;
virtual size_t getGlobalMemorySize() = 0;
// for CPU, sets to use optimized code for inference.
// for GPU, this is invalid. for gpu, isOptimized() function always returns false.

View File

@@ -20,6 +20,10 @@ public:
void setDevice() override {}
void synchronize() override {}
size_t getGlobalMemorySize() override {
ABORT("Not implemented on CPU");
}
// for CPU & inference only, sets to use optimized code for inference. Does nothing for GPU.
void setOptimized(bool optimize) override { optimized_ = optimize; }
bool isOptimized() override { return optimized_; }

View File

@@ -96,6 +96,12 @@ public:
CudaCompute getCudaComputeCapability() { return compute_; }
size_t getGlobalMemorySize() override {
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, (int)deviceId_.no));
return prop.totalGlobalMem;
}
private:
cublasHandle_t cublasHandle_{0}; // make sure it's 0, so it can be initialized lazily
cusparseHandle_t cusparseHandle_{0}; // as above

View File

@@ -82,7 +82,7 @@ void GraphGroup::initGraphsAndOpts() {
graph->setDevice(device);
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graph->reserveWorkspaceMB(options_->get<int>("workspace"));
graphs_.push_back(graph);
@@ -510,8 +510,18 @@ Ptr<data::BatchStats> GraphGroup::collectStats(Ptr<ExpressionGraph> graph,
lengths[j] = std::min(lengths[j], localMaxes[j]);
auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, maxBatch, options_);
auto loss = model->build(graph, batch);
fits = graph->fits();
// We check for a ShapeSizeException (happens if total shape size would exceed max int).
// If caught, we reduce the batch size. In any other context, this exception will cause
// an error and exit Marian.
try {
auto loss = model->build(graph, batch);
fits = graph->fits();
} catch(const ShapeSizeException& e) {
LOG(debug, "Exception for maxBatch size {}: {}", maxBatch, e.what());
fits = false;
}
if(fits)
maxBatch *= 2;
}
@@ -530,8 +540,15 @@ Ptr<data::BatchStats> GraphGroup::collectStats(Ptr<ExpressionGraph> graph,
do {
size_t current = (start + end) / 2;
auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, current, options_);
auto loss = model->build(graph, batch);
fits = graph->fits();
// Same as above.
try {
auto loss = model->build(graph, batch);
fits = graph->fits();
} catch(const ShapeSizeException& e) {
LOG(debug, "Exception for batch size {}: {}", current, e.what());
fits = false;
}
LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits);

View File

@@ -98,7 +98,7 @@ public:
graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
graph->getBackend()->setQuantizeRange(options_->get<float>("quantize-range"));
}
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graph->reserveWorkspaceMB(options_->get<int>("workspace"));
graphs_[id] = graph;
std::vector<Ptr<Scorer>> scorers;
@@ -311,7 +311,7 @@
graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
graph->getBackend()->setQuantizeRange(options_->get<float>("quantize-range"));
}
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graph->reserveWorkspaceMB(options_->get<int>("workspace"));
graphs_.push_back(graph);
auto scorers = createScorers(options_, model_items_);