Merged PR 21648: Allow for dynamic gradient scaling to fade out after N updates

Author: Marcin Junczys-Dowmunt
Date:   2021-12-06 23:20:44 +00:00
Parent: bbc673c50f
Commit: e8ea37cd5b
3 changed files with 20 additions and 4 deletions


@@ -562,7 +562,11 @@ void ProdBatchedLegacy(marian::Tensor C,
     ProdBatchedTypedLegacy<float, float>(C, allocator, A, B, transA, transB, beta, scalar);
 #if COMPILE_FP16
   } else if(C->type() == Type::float16) { // not a *.cu file
-    ProdBatchedTypedLegacy<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
+    // we use computeType=float here for fp16 training as this seems more stable and roughly as fast
+    ProdBatchedTypedLegacy<half, float>(C, allocator, A, B, transA, transB, beta, scalar);
+    // original for reference:
+    // ProdBatchedTypedLegacy<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
 #endif
   } else {
     ABORT("ProdBatchedLegacy not implemented for element type {}", C->type());

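A note on the fp16 change above: the two template arguments of ProdBatchedTypedLegacy separate the storage (element) type of the tensors from the compute (accumulation) type of the batched GEMM, so <half, float> keeps the data in fp16 but accumulates products in fp32. The reference loop below is a hypothetical sketch, not Marian's implementation; the names gemmBatchedRef, ElementType and ComputeType are invented for illustration only.

// Hypothetical sketch (not Marian code): a reference batched GEMM in which the element
// (storage) type and the compute (accumulation) type are separate template parameters,
// mirroring the ProdBatchedTypedLegacy<ElementType, ComputeType> call above. With
// ElementType = half and ComputeType = float, values stay in fp16 memory while the
// dot products are accumulated in fp32, which is generally the more stable choice.
#include <cstddef>
#include <vector>

template <typename ElementType, typename ComputeType>
void gemmBatchedRef(const std::vector<const ElementType*>& A, // batch of row-major m x k matrices
                    const std::vector<const ElementType*>& B, // batch of row-major k x n matrices
                    const std::vector<ElementType*>& C,       // batch of row-major m x n outputs
                    std::size_t m, std::size_t n, std::size_t k,
                    ComputeType alpha, ComputeType beta) {
  for(std::size_t b = 0; b < C.size(); ++b)
    for(std::size_t i = 0; i < m; ++i)
      for(std::size_t j = 0; j < n; ++j) {
        ComputeType acc = 0; // accumulate in ComputeType, not in ElementType
        for(std::size_t l = 0; l < k; ++l)
          acc += (ComputeType)A[b][i * k + l] * (ComputeType)B[b][l * n + j];
        C[b][i * n + j] = (ElementType)(alpha * acc + beta * (ComputeType)C[b][i * n + j]);
      }
}
// e.g. gemmBatchedRef<float, double>(...) on CPU, or conceptually <half, float> on GPU.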

@@ -31,11 +31,16 @@ GraphGroup::GraphGroup(Ptr<Options> options, Ptr<IMPIWrapper> mpi)
     if(vgc.size() > 0) dynamicGradientScalingFactor_ = std::stof(vgc[0]);
     if(vgc.size() > 1) dynamicGradientScalingUseLogs_ = vgc[1] == "log";
+    if(vgc.size() > 2) dynamicGradientScalingFadeout_ = std::stoul(vgc[2]);
     LOG_ONCE(info,
              "Re-scaling gradient to have average gradient norm if (log={}) gradient norm diverges from average by {} sigmas",
              dynamicGradientScalingUseLogs_,
              dynamicGradientScalingFactor_);
+    if(dynamicGradientScalingFadeout_ > 0)
+      LOG_ONCE(info,
+               "Dynamic gradient re-scaling will fade out linearly after {} updates",
+               dynamicGradientScalingFadeout_);
   }
   if(options_->get<bool>("check-gradient-nan")) {
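
For readers of the hunk above: vgc holds the values given to the dynamic gradient scaling option, where element 0 is the scaling factor in standard deviations, element 1 enables log-space norms when it equals "log", and the newly added element 2 is the number of updates after which the re-scaling fades out (0 keeps it active forever). The following standalone sketch shows the same parsing logic; the struct and function names are invented for illustration and are not part of Marian.

// Hypothetical sketch of the option parsing above; not the GraphGroup constructor itself.
#include <cstddef>
#include <string>
#include <vector>

struct DynamicGradientScalingConfig {
  float factor = 2.f;      // rescale when the norm diverges by this many sigmas
  bool useLogs = false;    // compare gradient norms in log space
  std::size_t fadeout = 0; // 0 = never fade out; otherwise fade out after this many updates
};

DynamicGradientScalingConfig parseDynamicGradientScaling(const std::vector<std::string>& vgc) {
  DynamicGradientScalingConfig cfg;
  if(vgc.size() > 0) cfg.factor  = std::stof(vgc[0]);
  if(vgc.size() > 1) cfg.useLogs = vgc[1] == "log";
  if(vgc.size() > 2) cfg.fadeout = std::stoul(vgc[2]);
  return cfg;
}

// Example: parseDynamicGradientScaling({"2", "log", "100000"}) rescales outliers beyond
// 2 sigmas of the running (log) norm average and fades the behavior out after 100,000 updates.
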
@@ -229,11 +234,17 @@ float GraphGroup::computeNormalizationFactor(float gNorm, size_t updateTrgWords)
     auto deltaTransform = gNormTransform - gNormAvgTransform; // compute the difference between the current transformer gradient norm and the running average.
     auto gNormStdTransform = std::sqrt(gNormVarTransform); // compute STD for the running average of (log) gradient norms.
+    float fadeoutMultiplier = 1.f;
+    if(dynamicGradientScalingFadeout_ > 0ul) // fade out linearly after that many updates @TODO: allow units other than updates
+      fadeoutMultiplier = (float)std::max(dynamicGradientScalingFadeout_, scheduler_->numberOfBatches()) / (float)dynamicGradientScalingFadeout_;
+    float dynamicGradientScalingFactorWithFadeout = dynamicGradientScalingFactor_ * fadeoutMultiplier; // if fadeoutMultiplier increases dynamic gradient scaling becomes less and less likely to happen over time.
     // delta of (log) gradient norm vs (log) gradient norm average is larger than N standard deviations
     // hence rescale gradient using the average.
-    if(scheduler_->numberOfBatches() >= window && deltaTransform > dynamicGradientScalingFactor_ * gNormStdTransform) {
-      LOG(debug, "log gradient norms: {} :: {:.4f} - {:.4f} = {:.4f} > {:.4f} * {:.4f}",
-          dynamicGradientScalingUseLogs_, gNormTransform, gNormAvgTransform, deltaTransform, dynamicGradientScalingFactor_, gNormStdTransform);
+    if(scheduler_->numberOfBatches() >= window && deltaTransform > dynamicGradientScalingFactorWithFadeout * gNormStdTransform) {
+      if(isMainProcess())
+        LOG(debug, "log gradient norms: {} :: {:.4f} - {:.4f} = {:.4f} > {:.4f} * {:.4f} - scaling gradient by {:.4f}",
+            dynamicGradientScalingUseLogs_, gNormTransform, gNormAvgTransform, deltaTransform, dynamicGradientScalingFactorWithFadeout, gNormStdTransform, gNormAvg / gNorm);
       normalizationFactor *= gNorm / gNormAvg; // since we later do gradient / normalizationFactor this divides by norm and multiplies by the average, rescaling to the average.
     }

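The fadeout in the hunk above works by inflating the outlier threshold rather than by switching the feature off: fadeoutMultiplier is max(fadeout, numberOfBatches) / fadeout, i.e. exactly 1 until the fadeout point and then growing linearly, so the test deltaTransform > factor * multiplier * sigma fires less and less often as training proceeds. Below is a tiny standalone sketch of that multiplier; the function name is invented for illustration and is not Marian's computeNormalizationFactor.

// Hypothetical sketch of the fadeout multiplier used above, not Marian code.
#include <algorithm>
#include <cstddef>
#include <cstdio>

float fadeoutMultiplier(std::size_t fadeoutAfter, std::size_t numBatches) {
  if(fadeoutAfter == 0) // 0 disables the fadeout; the threshold stays at factor * sigma
    return 1.f;
  return (float)std::max(fadeoutAfter, numBatches) / (float)fadeoutAfter;
}

int main() {
  // with a fadeout of 100,000 updates, the effective threshold doubles by update 200,000
  for(std::size_t batches : {50000ul, 100000ul, 150000ul, 200000ul})
    std::printf("update %zu -> multiplier %.2f\n", batches, fadeoutMultiplier(100000ul, batches));
  return 0;
}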

@@ -74,6 +74,7 @@ protected:
   bool dynamicGradientScaling_{false};
   float dynamicGradientScalingFactor_{2.f};
   bool dynamicGradientScalingUseLogs_{false};
+  size_t dynamicGradientScalingFadeout_{0ul};
   // determines the number of input streams (i.e. input files or fields in the TSV input) that need
   // to be included in the batch, i.e. without alignments and weights