Mirror of https://github.com/marian-nmt/marian.git, synced 2024-09-17 09:47:34 +03:00
Merged PR 21648: Allow for dynamic gradient scaling to fade out after N updates

parent bbc673c50f · commit e8ea37cd5b
@@ -562,7 +562,11 @@ void ProdBatchedLegacy(marian::Tensor C,
     ProdBatchedTypedLegacy<float, float>(C, allocator, A, B, transA, transB, beta, scalar);
 #if COMPILE_FP16
   } else if(C->type() == Type::float16) { // not a *.cu file
-    ProdBatchedTypedLegacy<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
+    // we use computeType=float here for fp16 training as this seems more stable and roughly as fast
+    ProdBatchedTypedLegacy<half, float>(C, allocator, A, B, transA, transB, beta, scalar);
+
+    // original for reference:
+    // ProdBatchedTypedLegacy<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
 #endif
   } else {
     ABORT("ProdBatchedLegacy not implemented for element type {}", C->type());
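Why the <half, float> variant matters: per the comment in the hunk, the second template argument selects the compute type, so fp16 operands are combined with fp32 accumulation. Below is a minimal sketch of that pattern written directly against cuBLAS; the function name is invented, the strides assume densely packed column-major batches, and the cuBLAS 11+ API is assumed. It illustrates the technique only and is not Marian's actual ProdBatchedTypedLegacy.

// Sketch: batched GEMM with fp16 storage but fp32 accumulation (assumes cuBLAS 11+).
#include <cublas_v2.h>
#include <cuda_fp16.h>

void prodBatchedHalfFloatCompute(cublasHandle_t handle,
                                 const half* A, const half* B, half* C,
                                 int m, int n, int k, int batchCount,
                                 float alpha, float beta) {
  // With computeType CUBLAS_COMPUTE_32F, alpha/beta are plain floats even though
  // A, B and C are stored as fp16; products are summed in fp32, which is the
  // "more stable and roughly as fast" behaviour the diff comment refers to.
  cublasGemmStridedBatchedEx(handle,
                             CUBLAS_OP_N, CUBLAS_OP_N,
                             m, n, k,
                             &alpha,
                             A, CUDA_R_16F, /*lda=*/m, /*strideA=*/(long long)m * k,
                             B, CUDA_R_16F, /*ldb=*/k, /*strideB=*/(long long)k * n,
                             &beta,
                             C, CUDA_R_16F, /*ldc=*/m, /*strideC=*/(long long)m * n,
                             batchCount,
                             CUBLAS_COMPUTE_32F,
                             CUBLAS_GEMM_DEFAULT);
}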
@@ -31,11 +31,16 @@ GraphGroup::GraphGroup(Ptr<Options> options, Ptr<IMPIWrapper> mpi)

     if(vgc.size() > 0) dynamicGradientScalingFactor_  = std::stof(vgc[0]);
     if(vgc.size() > 1) dynamicGradientScalingUseLogs_ = vgc[1] == "log";
+    if(vgc.size() > 2) dynamicGradientScalingFadeout_ = std::stoul(vgc[2]);

     LOG_ONCE(info,
              "Re-scaling gradient to have average gradient norm if (log={}) gradient norm diverges from average by {} sigmas",
              dynamicGradientScalingUseLogs_,
              dynamicGradientScalingFactor_);
+    if(dynamicGradientScalingFadeout_ > 0)
+      LOG_ONCE(info,
+               "Dynamic gradient re-scaling will fade out linearly after {} updates",
+               dynamicGradientScalingFadeout_);
   }

   if(options_->get<bool>("check-gradient-nan")) {
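To make the three positional values concrete: vgc above holds the already-split values of the dynamic-gradient-scaling option, e.g. {"2", "log", "100000"}. Here is a hedged, self-contained sketch of the same parsing; the struct and function names are invented for illustration, and only the stof/stoul logic mirrors the diff.

#include <string>
#include <vector>

struct DynamicGradientScalingConfig {
  float  factor  = 2.f;    // sigmas of divergence tolerated before re-scaling
  bool   useLogs = false;  // compare gradient norms in log space
  size_t fadeout = 0;      // 0 keeps the previous behaviour (no fade-out)
};

DynamicGradientScalingConfig parseDynamicGradientScaling(const std::vector<std::string>& vgc) {
  DynamicGradientScalingConfig c;
  if(vgc.size() > 0) c.factor  = std::stof(vgc[0]);   // e.g. "2"
  if(vgc.size() > 1) c.useLogs = vgc[1] == "log";     // e.g. "log"
  if(vgc.size() > 2) c.fadeout = std::stoul(vgc[2]);  // e.g. "100000" (updates)
  return c;
}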
@@ -229,11 +234,17 @@ float GraphGroup::computeNormalizationFactor(float gNorm, size_t updateTrgWords)
     auto deltaTransform    = gNormTransform - gNormAvgTransform; // compute the difference between the current transformed gradient norm and the running average.
     auto gNormStdTransform = std::sqrt(gNormVarTransform);       // compute STD for the running average of (log) gradient norms.

+    float fadeoutMultiplier = 1.f;
+    if(dynamicGradientScalingFadeout_ > 0ul) // fade out linearly after that many updates @TODO: allow units other than updates
+      fadeoutMultiplier = (float)std::max(dynamicGradientScalingFadeout_, scheduler_->numberOfBatches()) / (float)dynamicGradientScalingFadeout_;
+
+    float dynamicGradientScalingFactorWithFadeout = dynamicGradientScalingFactor_ * fadeoutMultiplier; // if fadeoutMultiplier increases dynamic gradient scaling becomes less and less likely to happen over time.
     // delta of (log) gradient norm vs (log) gradient norm average is larger than N standard deviations
     // hence rescale gradient using the average.
-    if(scheduler_->numberOfBatches() >= window && deltaTransform > dynamicGradientScalingFactor_ * gNormStdTransform) {
-      LOG(debug, "log gradient norms: {} :: {:.4f} - {:.4f} = {:.4f} > {:.4f} * {:.4f}",
-          dynamicGradientScalingUseLogs_, gNormTransform, gNormAvgTransform, deltaTransform, dynamicGradientScalingFactor_, gNormStdTransform);
+    if(scheduler_->numberOfBatches() >= window && deltaTransform > dynamicGradientScalingFactorWithFadeout * gNormStdTransform) {
+      if(isMainProcess())
+        LOG(debug, "log gradient norms: {} :: {:.4f} - {:.4f} = {:.4f} > {:.4f} * {:.4f} - scaling gradient by {:.4f}",
+            dynamicGradientScalingUseLogs_, gNormTransform, gNormAvgTransform, deltaTransform, dynamicGradientScalingFactorWithFadeout, gNormStdTransform, gNormAvg / gNorm);

       normalizationFactor *= gNorm / gNormAvg; // since we later do gradient / normalizationFactor this divides by norm and multiplies by the average, rescaling to the average.
     }
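Numerically, the fade-out works like this: up to dynamicGradientScalingFadeout_ updates the multiplier is pinned at 1 (unchanged behaviour); beyond it the multiplier grows linearly, widening the sigma threshold until re-scaling effectively never fires. A self-contained sketch with made-up constants:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const size_t fadeout = 10000;  // stands in for dynamicGradientScalingFadeout_
  const float  factor  = 2.f;    // stands in for dynamicGradientScalingFactor_

  for(size_t updates : {5000ul, 10000ul, 20000ul, 40000ul}) {
    // same arithmetic as the diff: max() pins the multiplier to 1 before fade-out
    float multiplier = (float)std::max(fadeout, updates) / (float)fadeout;
    std::printf("update %6zu: multiplier %.2f -> threshold %.2f sigmas\n",
                updates, multiplier, factor * multiplier);
  }
  // prints thresholds 2.00, 2.00, 4.00, 8.00: outliers must diverge ever further
  // from the running average before the gradient is re-scaled.
  return 0;
}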
@@ -74,6 +74,7 @@ protected:
   bool   dynamicGradientScaling_{false};
   float  dynamicGradientScalingFactor_{2.f};
   bool   dynamicGradientScalingUseLogs_{false};
+  size_t dynamicGradientScalingFadeout_{0ul};

   // determines the number of input streams (i.e. input files or fields in the TSV input) that need
   // to be included in the batch, i.e. without alignments and weights