Merged PR 18505: RMSNorm on GPU

Support for RMSNorm as a drop-in replacement for LayerNorm from _Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization_. Enabled in the Transformer model via `--transformer-postprocess dar` instead of `dan`.
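For reference, RMSNorm keeps LayerNorm's learned scale \(\gamma\) (and optional shift \(\beta\)) but drops the mean-centering, normalizing only by the root mean square of the activations:

\mathrm{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \mathrm{eps}}} \odot \gamma + \beta
\qquad
\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{N}\sum_i x_i^2 + \mathrm{eps}}} \odot \gamma + \beta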
This commit is contained in:
Martin Junczys-Dowmunt 2021-04-10 15:28:38 +00:00
parent a05124176d
commit caddad90cd
12 changed files with 638 additions and 23 deletions


@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
- Support for RMSNorm as a drop-in replacement for LayerNorm from `Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization`. Enabled in the Transformer model via `--transformer-postprocess dar` instead of `dan`.
- Extend suppression of unwanted output symbols, specifically "\n", from the default vocabulary if generated by SentencePiece with byte-fallback. Deactivate with `--allow-special`.
- Allow fine-grained CPU-intrinsics overrides when `BUILD_ARCH != native`, e.g. `-DBUILD_ARCH=x86-64 -DCOMPILE_AVX512=off`.
- Add custom bias epilogue kernel.


@@ -1 +1 @@
v1.10.16
v1.10.17


@@ -208,8 +208,15 @@ void ExpressionGraph::backward(bool reset, float clipValue) {
}
if(v->trainable() && v->marked_for_debug()) {
LOG(info, "Debug Grad: {} op={}", v->debug_message(), v->type());
LOG(info, v->grad()->debug());
Logger log = spdlog::get("general");
if(log) {
LOG(info, "Debug Grad: {} op={}", v->debug_message(), v->type());
LOG(info, v->grad()->debug());
}
else {
std::cerr << "Debug Grad: " << v->debug_message() << " op=" << v->type() << std::endl;
std::cerr << v->grad()->debug() << std::endl;
}
}
if(v->trainable() && clipValue != 0) {


@@ -749,6 +749,18 @@ Expr layerNorm(Expr x,
return Expression<LayerNormalizationOp>(nodes, eps);
}
Expr rmsNorm(Expr x,
Expr gamma,
Expr beta /*= nullptr*/,
float eps /*= 1e-9*/) {
// RMSNorm accumulates in float, so a small eps is fine
std::vector<Expr> nodes = {x, gamma};
if(beta)
nodes.push_back(beta);
return Expression<RMSNormalizationOp>(nodes, eps);
}
Expr highway(Expr y, Expr x, Expr t) {
std::vector<Expr> nodes = {y, x, t};
return Expression<HighwayNodeOp>(nodes);


@@ -915,6 +915,18 @@ Expr weighted_average(Expr in, Expr weights, int ax = 0);
*/
Expr layerNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9);
/**
* Applies RMS normalization over the last dimension.
*
* See: Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization.
* In Advances in Neural Information Processing Systems 32. Vancouver, Canada.
* @f[
\frac{x}{\sqrt{\frac{1}{N}\sum x^2 + \mathrm{eps}}} \times \gamma + \beta
* @f]
* @see RMSNormalizationOp
*/
Expr rmsNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9);
/**
* Highway transformation.
* Computes the highway transform on @p y and @p x as gated by @p t:


@@ -1369,6 +1369,64 @@ private:
float eps_;
};
// RMS norm along last axis
struct RMSNormalizationOp : public NaryNodeOp {
public:
RMSNormalizationOp(const std::vector<Expr>& nodes, float eps = 1e-9)
: NaryNodeOp(nodes), eps_(eps) {
// @TODO: dimension check
}
NodeOps forwardOps() override {
return {NodeOp(
RMSNormalization(val_,
child(0)->val(),
child(1)->val(),
(children_.size() == 3) ? child(2)->val() : nullptr,
eps_))};
}
// @BUGBUG: backward has not been tested for broadcasting gamma/beta
NodeOps backwardOps() override {
return {NodeOp(
RMSNormalizationGrad(
graph()->allocator(),
child(0)->grad(),
child(1)->grad(),
(children_.size() == 3) ? child(2)->grad() : nullptr,
adj_,
val_,
child(0)->val(),
child(1)->val(),
(children_.size() == 3) ? child(2)->val() : nullptr,
eps_))};
}
const std::string type() override { return "rms_normalization"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, eps_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<RMSNormalizationOp>(node);
if(!cnode)
return false;
if(eps_ != cnode->eps_)
return false;
return true;
}
private:
friend class SerializationHelpers; // @TODO: use the same name for this as SqrtNodeOp
float eps_;
};
struct HighwayNodeOp : public NaryNodeOp {
HighwayNodeOp(const std::vector<Expr>& nodes) : NaryNodeOp(nodes) {}


@@ -212,4 +212,10 @@ static inline Expr layerNorm(Expr x, std::string prefix, std::string suffix = st
return marian::layerNorm(x, scale, bias, 1e-6f);
}
static inline Expr rmsNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
int dimModel = x->shape()[-1];
auto scale = x->graph()->param(prefix + "_rms_scale" + suffix, {1, dimModel}, inits::ones());
return marian::rmsNorm(x, scale, nullptr, 1e-6f);
}
} // namespace marian


@@ -176,6 +176,8 @@ public:
// layer normalization
else if (op == 'n')
output = layerNorm(output, prefix, "_pre");
else if (op == 'r')
output = rmsNorm(output, prefix, "_pre");
else
ABORT("Unknown pre-processing operation '{}'", op);
}
@@ -201,6 +203,8 @@
// layer normalization
else if(op == 'n')
output = layerNorm(output, prefix);
else if(op == 'r')
output = rmsNorm(output, prefix);
else
ABORT("Unknown pre-processing operation '{}'", op);
}
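Each character of the pre/postprocess string selects one operation that is applied in sequence to the layer output, so this hunk adds 'r' next to the existing 'n'. A rough sketch of what `--transformer-postprocess dar` expands to, assuming Marian's usual convention that 'd' applies dropout and 'a' adds the residual input (variable names are illustrative only):
output = dropout(output, dropProb);   // 'd'
output = output + input;              // 'a' (residual connection)
output = rmsNorm(output, prefix);     // 'r' (this PR); 'n' would call layerNorm(output, prefix) instead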


@@ -977,7 +977,7 @@ float L2Norm(Tensor in, Ptr<Allocator> /*not used*/) {
float sum = 0.f;
size_t size = in->size();
const float* data = in->data();
#pragma omp parallel for simd reduction(+ : sum)
for(size_t i = 0; i < size; ++i) {
sum += data[i] * data[i];
}
@@ -998,14 +998,14 @@ void Att(Tensor out_, Tensor va_, Tensor context_, Tensor state_) {
int rows = m;
int cols = k;
#pragma omp parallel for
for(int j = 0; j < rows; ++j) {
const float* vaRow = va;
const float* ctxRow = ctx + (j % (b * t)) * cols;
const float* stateRow = state + ((j / (b * t)) * b + j % b) * cols;
float sum = 0.f;
#pragma omp simd reduction(+ : sum)
for(int i = 0; i < cols; ++i) {
float z = ctxRow[i] + stateRow[i];
sum += std::tanh(z) * vaRow[i];
@@ -1035,7 +1035,7 @@ void AttBack(Tensor gVa_,
size_t k = context_->shape()[-1];
size_t n = context_->shape()[-2];
#pragma omp parallel for reduction(+ : gState[:n * k], gVa[:k])
for(size_t j = 0; j < m; ++j) {
float* gcRow = gContext + j * k;
float* gsRow = gState + (j % n) * k;
@@ -1045,7 +1045,7 @@ void AttBack(Tensor gVa_,
float adj_j = adj[j];
#pragma omp simd
for(size_t i = 0; i < k; ++i) {
float z = cRow[i] + sRow[i];
@@ -1070,20 +1070,20 @@ void LayerNormalizationImpl(float* out,
float eps,
int rows,
int cols) {
#pragma omp parallel for
for(int j = 0; j < rows; ++j) {
float* so = out + j * cols;
const float* sp = in + j * cols;
float sum = 0.f;
#pragma omp simd reduction(+ : sum)
for(int i = 0; i < cols; ++i) {
sum += sp[i];
}
float mean = sum / cols;
float sqSum = 0.f;
#pragma omp simd reduction(+ : sqSum)
for(int i = 0; i < cols; ++i) {
float ex = sp[i] - mean;
sqSum += ex * ex;
@@ -1091,7 +1091,7 @@ void LayerNormalizationImpl(float* out,
float sigma = std::sqrt(sqSum / cols + eps);
#pragma omp simd
for(int i = 0; i < cols; ++i) {
float t = alpha[alphaStride * i] * ((sp[i] - mean) / sigma);
if(hasBeta)
@@ -1168,7 +1168,7 @@ void LayerNormalizationGrad(Tensor gradX_,
size_t cols = y_->shape()[-1];
if(beta) {
#pragma omp parallel for reduction(+ : gradGamma[:cols], gradBeta[:cols])
for(size_t j = 0; j < rows; ++j) {
const float* xRow = x + j * cols;
const float* yRow = y + j * cols;
@@ -1180,7 +1180,7 @@ void LayerNormalizationGrad(Tensor gradX_,
float sum_adj_x = 0.f;
float sum_sqr = 0.f;
#pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj)
for(size_t i = 0; i < cols; ++i) {
sum_x += xRow[i];
sum_adj_x += adjRow[i] * (yRow[i] - (beta ? beta[betaStride * i] : 0.f)) / gamma[gammaStride * i];
@@ -1188,14 +1188,14 @@ void LayerNormalizationGrad(Tensor gradX_,
}
float mean = sum_x / cols;
#pragma omp simd reduction(+ : sum_sqr)
for(size_t i = 0; i < cols; ++i) {
float ex = xRow[i] - mean;
sum_sqr += ex * ex;
}
float sigma = std::sqrt(sum_sqr / cols + eps);
#pragma omp simd
for(size_t i = 0; i < cols; ++i) {
float grad_x = 0.f;
float x_hat = (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i];
@@ -1209,8 +1209,8 @@ void LayerNormalizationGrad(Tensor gradX_,
gradBeta[betaStride * i] += adjRow[i];
}
}
} else {
#pragma omp parallel for reduction(+ : gradGamma[:cols])
} else { // @TODO: this code duplication is really ugly, but required for omp to work correctly?
#pragma omp parallel for reduction(+ : gradGamma[:cols])
for(size_t j = 0; j < rows; ++j) {
const float* xRow = x + j * cols;
const float* yRow = y + j * cols;
@@ -1222,23 +1222,22 @@ void LayerNormalizationGrad(Tensor gradX_,
float sum_adj_x = 0.f;
float sum_sqr = 0.f;
#pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj)
for(size_t i = 0; i < cols; ++i) {
sum_x += xRow[i];
sum_adj_x += adjRow[i] * (yRow[i] - (beta ? beta[betaStride * i] : 0.f)) / gamma[gammaStride * i];
// @TODO: beta is NULL here ^^
sum_adj_x += adjRow[i] * yRow[i] / gamma[gammaStride * i];
sum_adj += adjRow[i];
}
float mean = sum_x / cols;
#pragma omp simd reduction(+ : sum_sqr)
for(size_t i = 0; i < cols; ++i) {
float ex = xRow[i] - mean;
sum_sqr += ex * ex;
}
float sigma = std::sqrt(sum_sqr / cols + eps);
#pragma omp simd
for(size_t i = 0; i < cols; ++i) {
float grad_x = 0.f;
float x_hat = yRow[i] / gamma[gammaStride * i];
@@ -1255,6 +1254,163 @@ void LayerNormalizationGrad(Tensor gradX_,
}
MARIAN_FFAST_MATH_END
MARIAN_FFAST_MATH_BEGIN
template <int alphaStride, int betaStride, bool hasBeta>
void RMSNormalizationImpl(float* out,
const float* in,
const float* alpha,
const float* beta,
float eps,
int rows,
int cols) {
#pragma omp parallel for
for(int j = 0; j < rows; ++j) {
float* so = out + j * cols;
const float* sp = in + j * cols;
float sqSum = 0.f;
#pragma omp simd reduction(+ : sqSum)
for(int i = 0; i < cols; ++i) {
sqSum += sp[i] * sp[i];
}
float rms = std::sqrt(sqSum / cols + eps);
#pragma omp simd
for(int i = 0; i < cols; ++i) {
float t = alpha[alphaStride * i] * (sp[i] / rms);
if(hasBeta)
t += beta[betaStride * i];
so[i] = t;
}
}
}
MARIAN_FFAST_MATH_END
template <int alphaStride>
inline void RMSNormalizationDispatchBeta(float* out,
const float* in,
const float* alpha,
Tensor beta,
float eps,
int rows,
int cols) {
if (beta) {
if (beta->shape().back() > 1) {
RMSNormalizationImpl<alphaStride, 1, true>(out, in, alpha, beta->data(), eps, rows, cols);
} else {
RMSNormalizationImpl<alphaStride, 0, true>(out, in, alpha, beta->data(), eps, rows, cols);
}
} else {
RMSNormalizationImpl<alphaStride, 0, false>(out, in, alpha, nullptr, eps, rows, cols);
}
}
void RMSNormalization(Tensor out,
Tensor in,
Tensor gamma,
Tensor beta,
float eps) {
const float* alpha = gamma->data();
const int alphaStride = gamma->shape().back() > 1; // broadcasting for alpha and beta
int rows = in->shape().elements() / in->shape().back();
int cols = in->shape().back();
if (alphaStride == 0) {
RMSNormalizationDispatchBeta<0>(out->data(), in->data(), alpha, beta, eps, rows, cols);
} else {
RMSNormalizationDispatchBeta<1>(out->data(), in->data(), alpha, beta, eps, rows, cols);
}
}
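The compile-time strides exist only so that a scalar gamma or beta broadcasts over the row without a branch in the inner loop; roughly:
// alphaStride == 1: alpha[alphaStride * i] reads alpha[i]  -> one scale per column
// alphaStride == 0: alpha[alphaStride * i] reads alpha[0]  -> a single scalar scale broadcast over the row
// betaStride behaves the same way, and hasBeta == false removes the "+ beta[...]" term at compile time.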
MARIAN_FFAST_MATH_BEGIN
void RMSNormalizationGrad(Tensor gradX_,
Tensor gradGamma_,
Tensor gradBeta_,
Tensor adj_,
Tensor y_,
Tensor x_,
Tensor gamma_,
Tensor beta_,
float eps) {
float* gradX = gradX_->data();
float* gradGamma = gradGamma_->data();
float* gradBeta = gradBeta_ ? gradBeta_->data() : nullptr;
float* adj = adj_->data();
float* x = x_->data();
float* y = y_->data();
float* gamma = gamma_->data();
float* beta = beta_ ? beta_->data() : nullptr;
// @TODO: The CPU implementation supports scalar gamma and beta. This is a left-over,
// we should enable that in the GPU version as well.
const int gammaStride = gamma_->shape().back() > 1; // broadcasting for alpha and beta. 0 means it's a scalar
const int betaStride = beta_ && beta_->shape().back() > 1;
size_t rows = y_->shape().elements() / y_->shape()[-1];
size_t cols = y_->shape()[-1];
if(beta) {
#pragma omp parallel for reduction(+ : gradGamma[:cols], gradBeta[:cols])
for(size_t j = 0; j < rows; ++j) {
const float* xRow = x + j * cols;
const float* yRow = y + j * cols;
const float* adjRow = adj + j * cols;
float* gradXRow = gradX + j * cols;
float sum_adj_r = 0.f;
float sum_sqr = 0.f;
#pragma omp simd reduction(+ : sum_adj_r, sum_sqr)
for(size_t i = 0; i < cols; ++i) {
sum_adj_r += adjRow[i] * (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i];
sum_sqr += xRow[i] * xRow[i];
}
float rms = std::sqrt(sum_sqr / cols + eps);
#pragma omp simd
for(size_t i = 0; i < cols; ++i) {
float rmsNorm = (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i];
float gradNorm = cols * adjRow[i] - rmsNorm * sum_adj_r;
gradNorm /= cols * rms;
gradXRow[i] += gamma[gammaStride * i] * gradNorm;
gradGamma[gammaStride * i] += adjRow[i] * rmsNorm;
gradBeta[betaStride * i] += adjRow[i];
}
}
} else {
#pragma omp parallel for reduction(+ : gradGamma[:cols])
for(size_t j = 0; j < rows; ++j) {
const float* xRow = x + j * cols;
const float* yRow = y + j * cols;
const float* adjRow = adj + j * cols;
float* gradXRow = gradX + j * cols;
float sum_adj_r = 0.f;
float sum_sqr = 0.f;
#pragma omp simd reduction(+ : sum_adj_r, sum_sqr)
for(size_t i = 0; i < cols; ++i) {
sum_adj_r += yRow[i] / gamma[gammaStride * i];
sum_sqr += xRow[i] * xRow[i];
}
float rms = std::sqrt(sum_sqr / cols + eps);
#pragma omp simd
for(size_t i = 0; i < cols; ++i) {
float rmsNorm = yRow[i] / gamma[gammaStride * i];
float gradNorm = cols * adjRow[i] - rmsNorm * sum_adj_r;
gradNorm /= cols * rms;
gradXRow[i] += gamma[gammaStride * i] * gradNorm;
gradGamma[gammaStride * i] += adjRow[i] * rmsNorm;
}
}
}
}
MARIAN_FFAST_MATH_END
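Both branches above (and the CUDA kernel later in this diff) implement the same closed form, which follows from the Jacobian of the plain normalization \(\mathrm{RN}(x) = x / \mathrm{rms}\) with \(\mathrm{rms} = \sqrt{\tfrac{1}{N}\sum_i x_i^2 + \mathrm{eps}}\):

\frac{\partial\,\mathrm{RN}_i}{\partial x_j}
  = \frac{\delta_{ij}}{\mathrm{rms}} - \frac{x_i x_j}{N\,\mathrm{rms}^3}
  = \frac{1}{N\,\mathrm{rms}}\left(N\delta_{ij} - \mathrm{RN}_i\,\mathrm{RN}_j\right)
\qquad\Rightarrow\qquad
\left(J^\top a\right)_i = \frac{N a_i - \mathrm{RN}_i \sum_j \mathrm{RN}_j a_j}{N\,\mathrm{rms}}

In the loops above, \(a\) is adjRow, sum_adj_r accumulates \(\sum_j \mathrm{RN}_j a_j\), sum_sqr feeds rms, and the resulting gradNorm is scaled by gamma before being accumulated into gradX.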
void Shift(Tensor out_,
Tensor in_,
marian::Shape shift,


@@ -2303,6 +2303,273 @@ void LayerNormalizationGrad(Ptr<Allocator> allocator,
allocator->free(tempOnesMemory);
}
template <typename T, typename AccType = float>
__global__ void gRMSNormalization(T* out,
const T* in,
const T* gamma,
const T* beta,
int rows,
int cols,
AccType eps = 1e-9) {
extern __shared__ uint8_t _sharedBytes[];
AccType* _shareAccType = (AccType*)_sharedBytes;
AccType N = cols;
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
T* yRow = out + j * cols;
const T* xRow = in + j * cols;
AccType* _sqSum = _shareAccType;
_sqSum[threadIdx.x] = (AccType)0.0f;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
AccType xv = (AccType)xRow[id];
_sqSum[threadIdx.x] += xv * xv;
}
}
__syncthreads();
int len = blockDim.x;
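// Tree reduction over shared memory: each pass folds the upper half of the active
// range onto the lower half until _sqSum[0] holds the row's total sum of squares.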
while(len != 1) {
__syncthreads();
int skip = (len + 1) >> 1;
if(threadIdx.x < (len >> 1))
_sqSum[threadIdx.x] += _sqSum[threadIdx.x + skip];
len = (len + 1) >> 1;
}
__syncthreads();
AccType rms = functional::Ops<AccType>::sqrt(_sqSum[0] / N + eps); // all AccType
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
AccType gammav = (AccType)gamma[id];
AccType xv = (AccType)xRow[id];
AccType betav = beta ? (AccType)beta[id] : (AccType)0.f;
AccType rmsNorm = xv / rms;
AccType y = gammav * rmsNorm + betav;
yRow[id] = (T)y;
}
}
}
__syncthreads();
}
}
void RMSNormalization(Tensor out,
Tensor in,
Tensor gamma,
Tensor beta,
float eps) {
cudaSetDevice(out->getDeviceId().no);
int rows = in->shape().elements() / in->shape().back();
int cols = in->shape().back();
int blocks = std::min(MAX_BLOCKS, (int)rows);
int threads = std::min(MAX_THREADS, (int)cols);
int shared = threads * sizeof(float);
if(out->type() == Type::float32) {
gRMSNormalization<float, float><<<blocks, threads, shared>>>(out->data<float>(),
in->data<float>(),
gamma->data<float>(),
beta ? beta->data<float>() : nullptr,
rows,
cols,
eps);
#if COMPILE_FP16
} else if (out->type() == Type::float16) {
gRMSNormalization<half, float><<<blocks, threads, shared>>>(out->data<half>(),
in->data<half>(),
gamma->data<half>(),
beta ? beta->data<half>() : nullptr,
rows,
cols,
eps);
#endif
} else {
ABORT("RMSNormalization not implemented for type {}", out->type());
}
}
template <typename T, typename AccType = float>
__global__ void gRMSNormalizationGrad(T* gradX,
T* gradGamma,
T* adj,
T* y,
T* x,
T* gamma,
T* beta,
int rows,
int cols,
AccType eps = 1e-9) {
extern __shared__ uint8_t sharedBytes[];
AccType* shared = (AccType*)sharedBytes;
AccType N = cols;
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
AccType* sum_adj_r = shared; // sum of incoming gradient times RMSNorm(x), recovered from the output value
AccType* sum_sqr = shared + blockDim.x; // sum of x^2
const T* xRow = x + j * cols;
const T* yRow = y + j * cols;
const T* adjRow = adj + j * cols;
sum_adj_r[threadIdx.x] = (AccType)0.0f;
sum_sqr[threadIdx.x] = (AccType)0.0f;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
AccType xv = xRow[id];
AccType yv = yRow[id];
AccType betav = beta ? (AccType)beta[id] : (AccType)0.f;
AccType gammav = (AccType)gamma[id];
AccType adjv = adjRow[id];
AccType rv = (yv - betav) / gammav; // go back to RMSNorm(x) from scaled and shifted version for accumulation
sum_adj_r[threadIdx.x] += adjv * rv;
sum_sqr[threadIdx.x] += xv * xv;
}
}
__syncthreads();
int len = blockDim.x;
while(len != 1) {
__syncthreads();
int skip = (len + 1) >> 1;
if(threadIdx.x < (len >> 1)) {
sum_adj_r[threadIdx.x] += sum_adj_r[threadIdx.x + skip]; // Accumulates in AccType
sum_sqr[threadIdx.x] += sum_sqr[threadIdx.x + skip]; // Accumulates in AccType
}
len = (len + 1) >> 1;
}
__syncthreads();
AccType rms = functional::Ops<AccType>::sqrt(sum_sqr[0] / N + eps);
__syncthreads();
// Jacobian of RMS norm
// J = [ \frac{1}{N * rms} (N\delta_{ij} - RN_i RN_j) ]_{ij}
// J * a = dC/dx_i = ( N a_i - RN_i \sum_j RN_j a_j ) / (N * rms)
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
AccType xv = xRow[id];
AccType gammav = (AccType)gamma[id];
AccType adjv = adjRow[id];
AccType rmsNorm = xv / rms;
AccType gradNorm = N * adjv - rmsNorm * sum_adj_r[0];
gradNorm /= N * rms;
AccType gradXv = gammav * gradNorm;
// Keep the RMSNorm gradient in [-1000, 1000] for TensorOps; this is currently used to make values fit into fp16. This will also clip inf.
// @TODO: to be fixed and removed.
AccType sign = functional::Ops<AccType>::sgn(gradXv);
AccType cutoff = (AccType)1000.f; // @TODO: expose this somehow as an option? or better: make obsolete.
gradXv = functional::Ops<AccType>::abs(gradXv) > cutoff ? sign * cutoff : gradXv; // if gradXv is NaN the returned value is NaN too, because NaN > value is false.
// @TODO: frankly, this is embarrassing and should rather be removed or made optional. It does help for low-precision computation though. Maybe turn into an option?
gradXv = isnan(gradXv) ? 0.f : gradXv; // turn NaN into 0.
T* gradXRow = gradX + j * cols;
gradXRow[id] += (T)(gradXv);
T* gradGammaRow = gradGamma + j * cols;
// assignment is correct here as this gets summed up
// in the next kernel via matrix product
gradGammaRow[id] = (T)(adjv * rmsNorm);
}
}
}
__syncthreads();
}
}
void RMSNormalizationGrad(Ptr<Allocator> allocator,
Tensor gradX,
Tensor gradGamma,
Tensor gradBeta,
Tensor adj,
Tensor y,
Tensor x,
Tensor gamma,
Tensor beta,
float eps) {
cudaSetDevice(adj->getDeviceId().no);
int rows = y->shape().elements() / y->shape()[-1];
int cols = y->shape()[-1];
int threads = std::min(MAX_THREADS, cols);
int blocks = std::min(MAX_BLOCKS, rows);
auto tempGradGammaMemory = allocator->alloc(adj->memory()->size());
Tensor tempGradGamma = TensorBase::New(tempGradGammaMemory, adj->shape(), adj->type(), adj->getBackend());
tempGradGamma->set(0.f);
auto tempOnesMemory = allocator->alloc(rows * sizeOf(adj->type()));
Tensor tempOnes = TensorBase::New(tempOnesMemory, Shape({1, rows}), adj->type(), adj->getBackend());
tempOnes->set(1.f);
if(gradX->type() == Type::float32) {
int shared = sizeof(float) * threads * 2;
gRMSNormalizationGrad<float, float><<<blocks, threads, shared>>>(
gradX->data<float>(),
tempGradGamma->data<float>(),
adj->data<float>(),
y->data<float>(),
x->data<float>(),
gamma->data<float>(),
(beta) ? beta->data<float>() : nullptr,
rows,
cols,
eps);
#if COMPILE_FP16
} else if (gradX->type() == Type::float16) {
// accumulate in float
int shared = sizeof(float) * threads * 2;
gRMSNormalizationGrad<half, float><<<blocks, threads, shared>>>(
gradX->data<half>(),
tempGradGamma->data<half>(),
adj->data<half>(),
y->data<half>(),
x->data<half>(),
gamma->data<half>(),
(beta) ? beta->data<half>() : nullptr,
rows,
cols,
eps);
#endif
} else {
ABORT("RMSNormalizationGrad not implemented for type {}", gradX->type());
}
// We use this to get rid of the atomicAdd and perform a reduction of the gradients afterwards.
// This is much faster for fp16, which seems to have a broken atomicAdd implementation.
// We reduce the gamma/beta gradients with a matrix multiply, but use a 32-bit compute type.
// This preserves precision with larger batches where all batch entries reduce into a single vector.
// See also AffineNodeOp where we do the same for biases
gpu::Prod(gradGamma, tempOnes, tempGradGamma, false, false, 1, 1, Type::float32); // beta set to one to add
if(gradBeta) // dC/dbeta = adj - inverse broadcasting (reduction)
gpu::Prod(gradBeta, tempOnes, adj, false, false, 1, 1, Type::float32); // beta set to one to add
allocator->free(tempGradGammaMemory);
allocator->free(tempOnesMemory);
}
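In effect, the per-row gamma gradients written by the kernel are summed over rows by a single fp32 GEMM against a ones vector, and the beta gradient is reduced the same way directly from adj:

\nabla\gamma \mathrel{+}= \mathbf{1}_{1\times r}\, G_{r\times c}
\qquad
\nabla\beta \mathrel{+}= \mathbf{1}_{1\times r}\, A_{r\times c}

where \(G\) is tempGradGamma, \(A\) is adj, \(r\) is rows and \(c\) is cols; the accumulation comes from passing beta = 1 to gpu::Prod.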
template <bool add, typename T>
__global__ void gShift(T* out,
const T* in,


@@ -218,6 +218,55 @@ static inline void LayerNormalizationGrad(
cpu::LayerNormalizationGrad(gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
}
// clang-format off
DISPATCH5(RMSNormalization, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, float)
#ifdef CUDA_FOUND
namespace gpu {
void RMSNormalizationGrad(Ptr<Allocator> allocator,
Tensor gradX,
Tensor gradGamma,
Tensor gradBeta,
Tensor adj,
Tensor y,
Tensor x,
Tensor gamma,
Tensor beta,
float eps);
}
#endif
namespace cpu {
void RMSNormalizationGrad(Tensor gradX,
Tensor gradGamma,
Tensor gradBeta,
Tensor adj,
Tensor y,
Tensor x,
Tensor gamma,
Tensor beta,
float eps);
}
static inline void RMSNormalizationGrad(
Ptr<Allocator> allocator,
Tensor gradX,
Tensor gradGamma,
Tensor gradBeta,
Tensor adj,
Tensor y,
Tensor x,
Tensor gamma,
Tensor beta,
float eps) {
#ifdef CUDA_FOUND
if(gradX->getBackend()->getDeviceId().type == DeviceType::gpu)
gpu::RMSNormalizationGrad(allocator, gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
else
#endif
cpu::RMSNormalizationGrad(gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
}
DISPATCH4(HighwayForward, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor)
DISPATCH7(HighwayBackward, marian::Tensor, marian::Tensor, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor)


@@ -300,6 +300,49 @@ void tests(DeviceType device, Type floatType = Type::float32) {
}
SECTION("RMS normalization") {
graph->clear();
values.clear();
std::vector<T> init = {
2.88794374, 4.67853451, 3.96257305, 3.28433037,
0.37778997, 0.67662024, 4.24959183, 1.23910618,
0.68929380, 2.00369596, 4.38251686, 1.75624943,
4.96126175, 3.01947117, 4.72057724, 2.23017120
};
auto a1 = graph->param("test1", {2, 2, 4}, inits::fromVector(init));
auto a2 = graph->param("test2", {2, 2, 4}, inits::fromVector(init));
auto gamma = graph->param("gamma", {1, 4}, inits::ones());
auto rms = rmsNorm(a1, gamma, nullptr, 1e-5f);
auto rms2 = gamma * (a2 / sqrt(mean(a2 * a2, /*axis=*/-1) + 1e-5f));
auto top = sum(flatten(rms + rms2));
graph->forward();
graph->backward();
CHECK(rms->shape() == Shape({2, 2, 4}));
std::vector<T> values2;
// compare values of rms and rms2 to make sure forward computation is correct
rms->val()->get(values);
rms2->val()->get(values2);
CHECK( std::equal(values.begin(), values.end(),
values2.begin(), floatApprox) );
// compare adjoints of a1 and a2 (parameters) to make sure gradient computation is correct
a1->grad()->get(values);
a2->grad()->get(values2);
CHECK( std::equal(values.begin(), values.end(),
values2.begin(), floatApprox) );
}
SECTION("reductions") {
graph->clear();
values.clear();