mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merged PR 18505: RMSNorm on GPU
Support for RMSNorm as drop-in replace for LayerNorm from _Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization_. Enabled in Transformer model via `--transformer-postprocess dar` instead of `dan`.
This commit is contained in:
parent
a05124176d
commit
caddad90cd
@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Support for RMSNorm as drop-in replace for LayerNorm from `Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization`. Enabled in Transformer model via `--transformer-postprocess dar` instead of `dan`.
|
||||
- Extend suppression of unwanted output symbols, specifically "\n" from default vocabulary if generated by SentencePiece with byte-fallback. Deactivates with --allow-special
|
||||
- Allow for fine-grained CPU intrinsics overrides when BUILD_ARCH != native e.g. -DBUILD_ARCH=x86-64 -DCOMPILE_AVX512=off
|
||||
- Adds custom bias epilogue kernel.
|
||||
|
@ -208,8 +208,15 @@ void ExpressionGraph::backward(bool reset, float clipValue) {
|
||||
}
|
||||
|
||||
if(v->trainable() && v->marked_for_debug()) {
|
||||
LOG(info, "Debug Grad: {} op={}", v->debug_message(), v->type());
|
||||
LOG(info, v->grad()->debug());
|
||||
Logger log = spdlog::get("general");
|
||||
if(log) {
|
||||
LOG(info, "Debug Grad: {} op={}", v->debug_message(), v->type());
|
||||
LOG(info, v->grad()->debug());
|
||||
}
|
||||
else {
|
||||
std::cerr << "Debug Grad: " << v->debug_message() << " op=" << v->type() << std::endl;
|
||||
std::cerr << v->grad()->debug() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
if(v->trainable() && clipValue != 0) {
|
||||
|
@ -749,6 +749,18 @@ Expr layerNorm(Expr x,
|
||||
return Expression<LayerNormalizationOp>(nodes, eps);
|
||||
}
|
||||
|
||||
Expr rmsNorm(Expr x,
|
||||
Expr gamma,
|
||||
Expr beta /*= nullptr*/,
|
||||
float eps /*= 1e-9*/) {
|
||||
|
||||
// layerNorm accumulates in float, so small eps is fine
|
||||
std::vector<Expr> nodes = {x, gamma};
|
||||
if(beta)
|
||||
nodes.push_back(beta);
|
||||
return Expression<RMSNormalizationOp>(nodes, eps);
|
||||
}
|
||||
|
||||
Expr highway(Expr y, Expr x, Expr t) {
|
||||
std::vector<Expr> nodes = {y, x, t};
|
||||
return Expression<HighwayNodeOp>(nodes);
|
||||
|
@ -915,6 +915,18 @@ Expr weighted_average(Expr in, Expr weights, int ax = 0);
|
||||
*/
|
||||
Expr layerNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9);
|
||||
|
||||
/**
|
||||
* Applies RMS normalization over the last dimension.
|
||||
*
|
||||
* See: Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization.
|
||||
* In Advances in Neural Information Processing Systems 32. Vancouver, Canada.
|
||||
* @f[
|
||||
\frac{x}{\sqrt{\frac{1}{N}\sum x^2 + \mathrm{eps}}} \times \gamma + \beta
|
||||
* @f]
|
||||
* @see RMSNormalizationOp
|
||||
*/
|
||||
Expr rmsNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9);
|
||||
|
||||
/**
|
||||
* Highway transformation.
|
||||
* Computes the highway tranform on @p y and @p x as gated by @p t:
|
||||
|
@ -1369,6 +1369,64 @@ private:
|
||||
float eps_;
|
||||
};
|
||||
|
||||
// RMS norm along last axis
|
||||
struct RMSNormalizationOp : public NaryNodeOp {
|
||||
public:
|
||||
RMSNormalizationOp(const std::vector<Expr>& nodes, float eps = 1e-9)
|
||||
: NaryNodeOp(nodes), eps_(eps) {
|
||||
// @TODO: dimension check
|
||||
}
|
||||
|
||||
NodeOps forwardOps() override {
|
||||
return {NodeOp(
|
||||
RMSNormalization(val_,
|
||||
child(0)->val(),
|
||||
child(1)->val(),
|
||||
(children_.size() == 3) ? child(2)->val() : nullptr,
|
||||
eps_))};
|
||||
}
|
||||
|
||||
// @BUGBUG: backward has not been tested for broadcasting gamma/beta
|
||||
NodeOps backwardOps() override {
|
||||
return {NodeOp(
|
||||
RMSNormalizationGrad(
|
||||
graph()->allocator(),
|
||||
child(0)->grad(),
|
||||
child(1)->grad(),
|
||||
(children_.size() == 3) ? child(2)->grad() : nullptr,
|
||||
adj_,
|
||||
val_,
|
||||
child(0)->val(),
|
||||
child(1)->val(),
|
||||
(children_.size() == 3) ? child(2)->val() : nullptr,
|
||||
eps_))};
|
||||
}
|
||||
|
||||
const std::string type() override { return "rms_normalization"; }
|
||||
|
||||
virtual size_t hash() override {
|
||||
size_t seed = NaryNodeOp::hash();
|
||||
util::hash_combine(seed, eps_);
|
||||
return seed;
|
||||
}
|
||||
|
||||
virtual bool equal(Expr node) override {
|
||||
if(!NaryNodeOp::equal(node))
|
||||
return false;
|
||||
auto cnode = std::dynamic_pointer_cast<RMSNormalizationOp>(node);
|
||||
if(!cnode)
|
||||
return false;
|
||||
if(eps_ != cnode->eps_)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class SerializationHelpers; // @TODO: use the same name for this as SqrtNodeOp
|
||||
float eps_;
|
||||
};
|
||||
|
||||
|
||||
struct HighwayNodeOp : public NaryNodeOp {
|
||||
HighwayNodeOp(const std::vector<Expr>& nodes) : NaryNodeOp(nodes) {}
|
||||
|
||||
|
@ -212,4 +212,10 @@ static inline Expr layerNorm(Expr x, std::string prefix, std::string suffix = st
|
||||
return marian::layerNorm(x, scale, bias, 1e-6f);
|
||||
}
|
||||
|
||||
static inline Expr rmsNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
|
||||
int dimModel = x->shape()[-1];
|
||||
auto scale = x->graph()->param(prefix + "_rms_scale" + suffix, {1, dimModel}, inits::ones());
|
||||
return marian::rmsNorm(x, scale, nullptr, 1e-6f);
|
||||
}
|
||||
|
||||
} // namespace marian
|
||||
|
@ -176,6 +176,8 @@ public:
|
||||
// layer normalization
|
||||
else if (op == 'n')
|
||||
output = layerNorm(output, prefix, "_pre");
|
||||
else if (op == 'r')
|
||||
output = rmsNorm(output, prefix, "_pre");
|
||||
else
|
||||
ABORT("Unknown pre-processing operation '{}'", op);
|
||||
}
|
||||
@ -201,6 +203,8 @@ public:
|
||||
// layer normalization
|
||||
else if(op == 'n')
|
||||
output = layerNorm(output, prefix);
|
||||
else if(op == 'r')
|
||||
output = rmsNorm(output, prefix);
|
||||
else
|
||||
ABORT("Unknown pre-processing operation '{}'", op);
|
||||
}
|
||||
|
@ -977,7 +977,7 @@ float L2Norm(Tensor in, Ptr<Allocator> /*not used*/) {
|
||||
float sum = 0.f;
|
||||
size_t size = in->size();
|
||||
const float* data = in->data();
|
||||
#pragma omp parallel for simd reduction(+ : sum)
|
||||
#pragma omp parallel for simd reduction(+ : sum)
|
||||
for(size_t i = 0; i < size; ++i) {
|
||||
sum += data[i] * data[i];
|
||||
}
|
||||
@ -998,14 +998,14 @@ void Att(Tensor out_, Tensor va_, Tensor context_, Tensor state_) {
|
||||
int rows = m;
|
||||
int cols = k;
|
||||
|
||||
#pragma omp parallel for
|
||||
#pragma omp parallel for
|
||||
for(int j = 0; j < rows; ++j) {
|
||||
const float* vaRow = va;
|
||||
const float* ctxRow = ctx + (j % (b * t)) * cols;
|
||||
const float* stateRow = state + ((j / (b * t)) * b + j % b) * cols;
|
||||
|
||||
float sum = 0.f;
|
||||
#pragma omp simd reduction(+ : sum)
|
||||
#pragma omp simd reduction(+ : sum)
|
||||
for(int i = 0; i < cols; ++i) {
|
||||
float z = ctxRow[i] + stateRow[i];
|
||||
sum += std::tanh(z) * vaRow[i];
|
||||
@ -1035,7 +1035,7 @@ void AttBack(Tensor gVa_,
|
||||
size_t k = context_->shape()[-1];
|
||||
size_t n = context_->shape()[-2];
|
||||
|
||||
#pragma omp parallel for reduction(+ : gState[:n * k], gVa[:k])
|
||||
#pragma omp parallel for reduction(+ : gState[:n * k], gVa[:k])
|
||||
for(size_t j = 0; j < m; ++j) {
|
||||
float* gcRow = gContext + j * k;
|
||||
float* gsRow = gState + (j % n) * k;
|
||||
@ -1045,7 +1045,7 @@ void AttBack(Tensor gVa_,
|
||||
|
||||
float adj_j = adj[j];
|
||||
|
||||
#pragma omp simd
|
||||
#pragma omp simd
|
||||
for(size_t i = 0; i < k; ++i) {
|
||||
float z = cRow[i] + sRow[i];
|
||||
|
||||
@ -1070,20 +1070,20 @@ void LayerNormalizationImpl(float* out,
|
||||
float eps,
|
||||
int rows,
|
||||
int cols) {
|
||||
#pragma omp parallel for
|
||||
#pragma omp parallel for
|
||||
for(int j = 0; j < rows; ++j) {
|
||||
float* so = out + j * cols;
|
||||
const float* sp = in + j * cols;
|
||||
|
||||
float sum = 0.f;
|
||||
#pragma omp simd reduction(+ : sum)
|
||||
#pragma omp simd reduction(+ : sum)
|
||||
for(int i = 0; i < cols; ++i) {
|
||||
sum += sp[i];
|
||||
}
|
||||
|
||||
float mean = sum / cols;
|
||||
float sqSum = 0.f;
|
||||
#pragma omp simd reduction(+ : sqSum)
|
||||
#pragma omp simd reduction(+ : sqSum)
|
||||
for(int i = 0; i < cols; ++i) {
|
||||
float ex = sp[i] - mean;
|
||||
sqSum += ex * ex;
|
||||
@ -1091,7 +1091,7 @@ void LayerNormalizationImpl(float* out,
|
||||
|
||||
float sigma = std::sqrt(sqSum / cols + eps);
|
||||
|
||||
#pragma omp simd
|
||||
#pragma omp simd
|
||||
for(int i = 0; i < cols; ++i) {
|
||||
float t = alpha[alphaStride * i] * ((sp[i] - mean) / sigma);
|
||||
if(hasBeta)
|
||||
@ -1168,7 +1168,7 @@ void LayerNormalizationGrad(Tensor gradX_,
|
||||
size_t cols = y_->shape()[-1];
|
||||
|
||||
if(beta) {
|
||||
#pragma omp parallel for reduction(+ : gradGamma[:cols], gradBeta[:cols])
|
||||
#pragma omp parallel for reduction(+ : gradGamma[:cols], gradBeta[:cols])
|
||||
for(size_t j = 0; j < rows; ++j) {
|
||||
const float* xRow = x + j * cols;
|
||||
const float* yRow = y + j * cols;
|
||||
@ -1180,7 +1180,7 @@ void LayerNormalizationGrad(Tensor gradX_,
|
||||
float sum_adj_x = 0.f;
|
||||
float sum_sqr = 0.f;
|
||||
|
||||
#pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj)
|
||||
#pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj)
|
||||
for(size_t i = 0; i < cols; ++i) {
|
||||
sum_x += xRow[i];
|
||||
sum_adj_x += adjRow[i] * (yRow[i] - (beta ? beta[betaStride * i] : 0.f)) / gamma[gammaStride * i];
|
||||
@ -1188,14 +1188,14 @@ void LayerNormalizationGrad(Tensor gradX_,
|
||||
}
|
||||
|
||||
float mean = sum_x / cols;
|
||||
#pragma omp simd reduction(+ : sum_sqr)
|
||||
#pragma omp simd reduction(+ : sum_sqr)
|
||||
for(size_t i = 0; i < cols; ++i) {
|
||||
float ex = xRow[i] - mean;
|
||||
sum_sqr += ex * ex;
|
||||
}
|
||||
|
||||
float sigma = std::sqrt(sum_sqr / cols + eps);
|
||||
#pragma omp simd
|
||||
#pragma omp simd
|
||||
for(size_t i = 0; i < cols; ++i) {
|
||||
float grad_x = 0.f;
|
||||
float x_hat = (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i];
|
||||
@ -1209,8 +1209,8 @@ void LayerNormalizationGrad(Tensor gradX_,
|
||||
gradBeta[betaStride * i] += adjRow[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma omp parallel for reduction(+ : gradGamma[:cols])
|
||||
} else { // @TODO: this code duplication is really ugly, but required for omp to work correctly?
|
||||
#pragma omp parallel for reduction(+ : gradGamma[:cols])
|
||||
for(size_t j = 0; j < rows; ++j) {
|
||||
const float* xRow = x + j * cols;
|
||||
const float* yRow = y + j * cols;
|
||||
@ -1222,23 +1222,22 @@ void LayerNormalizationGrad(Tensor gradX_,
|
||||
float sum_adj_x = 0.f;
|
||||
float sum_sqr = 0.f;
|
||||
|
||||
#pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj)
|
||||
#pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj)
|
||||
for(size_t i = 0; i < cols; ++i) {
|
||||
sum_x += xRow[i];
|
||||
sum_adj_x += adjRow[i] * (yRow[i] - (beta ? beta[betaStride * i] : 0.f)) / gamma[gammaStride * i];
|
||||
// @TODO: beta is NULL here ^^
|
||||
sum_adj_x += adjRow[i] * yRow[i] / gamma[gammaStride * i];
|
||||
sum_adj += adjRow[i];
|
||||
}
|
||||
|
||||
float mean = sum_x / cols;
|
||||
#pragma omp simd reduction(+ : sum_sqr)
|
||||
#pragma omp simd reduction(+ : sum_sqr)
|
||||
for(size_t i = 0; i < cols; ++i) {
|
||||
float ex = xRow[i] - mean;
|
||||
sum_sqr += ex * ex;
|
||||
}
|
||||
|
||||
float sigma = std::sqrt(sum_sqr / cols + eps);
|
||||
#pragma omp simd
|
||||
#pragma omp simd
|
||||
for(size_t i = 0; i < cols; ++i) {
|
||||
float grad_x = 0.f;
|
||||
float x_hat = yRow[i] / gamma[gammaStride * i];
|
||||
@ -1255,6 +1254,163 @@ void LayerNormalizationGrad(Tensor gradX_,
|
||||
}
|
||||
MARIAN_FFAST_MATH_END
|
||||
|
||||
MARIAN_FFAST_MATH_BEGIN
|
||||
template <int alphaStride, int betaStride, bool hasBeta>
|
||||
void RMSNormalizationImpl(float* out,
|
||||
const float* in,
|
||||
const float* alpha,
|
||||
const float* beta,
|
||||
float eps,
|
||||
int rows,
|
||||
int cols) {
|
||||
#pragma omp parallel for
|
||||
for(int j = 0; j < rows; ++j) {
|
||||
float* so = out + j * cols;
|
||||
const float* sp = in + j * cols;
|
||||
|
||||
float sqSum = 0.f;
|
||||
#pragma omp simd reduction(+ : sqSum)
|
||||
for(int i = 0; i < cols; ++i) {
|
||||
sqSum += sp[i] * sp[i];
|
||||
}
|
||||
|
||||
float rms = std::sqrt(sqSum / cols + eps);
|
||||
|
||||
#pragma omp simd
|
||||
for(int i = 0; i < cols; ++i) {
|
||||
float t = alpha[alphaStride * i] * (sp[i] / rms);
|
||||
if(hasBeta)
|
||||
t += beta[betaStride * i];
|
||||
|
||||
so[i] = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
MARIAN_FFAST_MATH_END
|
||||
|
||||
template <int alphaStride>
|
||||
inline void RMSNormalizationDispatchBeta(float* out,
|
||||
const float* in,
|
||||
const float* alpha,
|
||||
Tensor beta,
|
||||
float eps,
|
||||
int rows,
|
||||
int cols) {
|
||||
if (beta) {
|
||||
if (beta->shape().back() > 1) {
|
||||
RMSNormalizationImpl<alphaStride, 1, true>(out, in, alpha, beta->data(), eps, rows, cols);
|
||||
} else {
|
||||
RMSNormalizationImpl<alphaStride, 0, true>(out, in, alpha, beta->data(), eps, rows, cols);
|
||||
}
|
||||
} else {
|
||||
RMSNormalizationImpl<alphaStride, 0, false>(out, in, alpha, nullptr, eps, rows, cols);
|
||||
}
|
||||
}
|
||||
|
||||
void RMSNormalization(Tensor out,
|
||||
Tensor in,
|
||||
Tensor gamma,
|
||||
Tensor beta,
|
||||
float eps) {
|
||||
const float* alpha = gamma->data();
|
||||
const int alphaStride = gamma->shape().back() > 1; // broadcasting for alpha and beta
|
||||
|
||||
int rows = in->shape().elements() / in->shape().back();
|
||||
int cols = in->shape().back();
|
||||
if (alphaStride == 0) {
|
||||
RMSNormalizationDispatchBeta<0>(out->data(), in->data(), alpha, beta, eps, rows, cols);
|
||||
} else {
|
||||
RMSNormalizationDispatchBeta<1>(out->data(), in->data(), alpha, beta, eps, rows, cols);
|
||||
}
|
||||
}
|
||||
|
||||
MARIAN_FFAST_MATH_BEGIN
|
||||
void RMSNormalizationGrad(Tensor gradX_,
|
||||
Tensor gradGamma_,
|
||||
Tensor gradBeta_,
|
||||
Tensor adj_,
|
||||
Tensor y_,
|
||||
Tensor x_,
|
||||
Tensor gamma_,
|
||||
Tensor beta_,
|
||||
float eps) {
|
||||
float* gradX = gradX_->data();
|
||||
float* gradGamma = gradGamma_->data();
|
||||
float* gradBeta = gradBeta_ ? gradBeta_->data() : nullptr;
|
||||
float* adj = adj_->data();
|
||||
float* x = x_->data();
|
||||
float* y = y_->data();
|
||||
float* gamma = gamma_->data();
|
||||
float* beta = beta_ ? beta_->data() : nullptr;
|
||||
// @TODO: The CPU implementation supports scalar gamma and beta. This is a left-over,
|
||||
// we should enable that in the GPU version as well.
|
||||
const int gammaStride = gamma_->shape().back() > 1; // broadcasting for alpha and beta. 0 means it's a scalar
|
||||
const int betaStride = beta_ && beta_->shape().back() > 1;
|
||||
|
||||
size_t rows = y_->shape().elements() / y_->shape()[-1];
|
||||
size_t cols = y_->shape()[-1];
|
||||
|
||||
if(beta) {
|
||||
#pragma omp parallel for reduction(+ : gradGamma[:cols], gradBeta[:cols])
|
||||
for(size_t j = 0; j < rows; ++j) {
|
||||
const float* xRow = x + j * cols;
|
||||
const float* yRow = y + j * cols;
|
||||
const float* adjRow = adj + j * cols;
|
||||
float* gradXRow = gradX + j * cols;
|
||||
|
||||
float sum_adj_r = 0.f;
|
||||
float sum_sqr = 0.f;
|
||||
|
||||
#pragma omp simd reduction(+ : sum_adj_r, sum_sqr)
|
||||
for(size_t i = 0; i < cols; ++i) {
|
||||
sum_adj_r += adjRow[i] * (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i];
|
||||
sum_sqr += xRow[i] * xRow[i];
|
||||
}
|
||||
|
||||
float rms = std::sqrt(sum_sqr / cols + eps);
|
||||
#pragma omp simd
|
||||
for(size_t i = 0; i < cols; ++i) {
|
||||
float rmsNorm = (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i];
|
||||
float gradNorm = cols * adjRow[i] - rmsNorm * sum_adj_r;
|
||||
gradNorm /= cols * rms;
|
||||
|
||||
gradXRow[i] += gamma[gammaStride * i] * gradNorm;
|
||||
gradGamma[gammaStride * i] += adjRow[i] * rmsNorm;
|
||||
gradBeta[betaStride * i] += adjRow[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma omp parallel for reduction(+ : gradGamma[:cols])
|
||||
for(size_t j = 0; j < rows; ++j) {
|
||||
const float* xRow = x + j * cols;
|
||||
const float* yRow = y + j * cols;
|
||||
const float* adjRow = adj + j * cols;
|
||||
float* gradXRow = gradX + j * cols;
|
||||
|
||||
float sum_adj_r = 0.f;
|
||||
float sum_sqr = 0.f;
|
||||
|
||||
#pragma omp simd reduction(+ : sum_adj_r, sum_sqr)
|
||||
for(size_t i = 0; i < cols; ++i) {
|
||||
sum_adj_r += yRow[i] / gamma[gammaStride * i];
|
||||
sum_sqr += xRow[i] * xRow[i];
|
||||
}
|
||||
|
||||
float rms = std::sqrt(sum_sqr / cols + eps);
|
||||
#pragma omp simd
|
||||
for(size_t i = 0; i < cols; ++i) {
|
||||
float rmsNorm = yRow[i] / gamma[gammaStride * i];
|
||||
float gradNorm = cols * adjRow[i] - rmsNorm * sum_adj_r;
|
||||
gradNorm /= cols * rms;
|
||||
|
||||
gradXRow[i] += gamma[gammaStride * i] * gradNorm;
|
||||
gradGamma[gammaStride * i] += adjRow[i] * rmsNorm;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
MARIAN_FFAST_MATH_END
|
||||
|
||||
void Shift(Tensor out_,
|
||||
Tensor in_,
|
||||
marian::Shape shift,
|
||||
|
@ -2303,6 +2303,273 @@ void LayerNormalizationGrad(Ptr<Allocator> allocator,
|
||||
allocator->free(tempOnesMemory);
|
||||
}
|
||||
|
||||
template <typename T, typename AccType = float>
|
||||
__global__ void gRMSNormalization(T* out,
|
||||
const T* in,
|
||||
const T* gamma,
|
||||
const T* beta,
|
||||
int rows,
|
||||
int cols,
|
||||
AccType eps = 1e-9) {
|
||||
extern __shared__ uint8_t _sharedBytes[];
|
||||
AccType* _shareAccType = (AccType*)_sharedBytes;
|
||||
|
||||
AccType N = cols;
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
if(j < rows) {
|
||||
T* yRow = out + j * cols;
|
||||
const T* xRow = in + j * cols;
|
||||
|
||||
AccType* _sqSum = _shareAccType;
|
||||
|
||||
_sqSum[threadIdx.x] = (AccType)0.0f;
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int id = tid + threadIdx.x;
|
||||
if(id < cols) {
|
||||
AccType xv = (AccType)xRow[id];
|
||||
_sqSum[threadIdx.x] += xv * xv;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
int len = blockDim.x;
|
||||
while(len != 1) {
|
||||
__syncthreads();
|
||||
int skip = (len + 1) >> 1;
|
||||
if(threadIdx.x < (len >> 1))
|
||||
_sqSum[threadIdx.x] += _sqSum[threadIdx.x + skip];
|
||||
len = (len + 1) >> 1;
|
||||
}
|
||||
__syncthreads();
|
||||
AccType rms = functional::Ops<AccType>::sqrt(_sqSum[0] / N + eps); // all AccType
|
||||
__syncthreads();
|
||||
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int id = tid + threadIdx.x;
|
||||
if(id < cols) {
|
||||
AccType gammav = (AccType)gamma[id];
|
||||
AccType xv = (AccType)xRow[id];
|
||||
AccType betav = beta ? (AccType)beta[id] : (AccType)0.f;
|
||||
AccType rmsNorm = xv / rms;
|
||||
AccType y = gammav * rmsNorm + betav;
|
||||
yRow[id] = (T)y;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
void RMSNormalization(Tensor out,
|
||||
Tensor in,
|
||||
Tensor gamma,
|
||||
Tensor beta,
|
||||
float eps) {
|
||||
cudaSetDevice(out->getDeviceId().no);
|
||||
|
||||
int rows = in->shape().elements() / in->shape().back();
|
||||
int cols = in->shape().back();
|
||||
|
||||
int blocks = std::min(MAX_BLOCKS, (int)rows);
|
||||
int threads = std::min(MAX_THREADS, (int)cols);
|
||||
int shared = threads * sizeof(float);
|
||||
|
||||
if(out->type() == Type::float32) {
|
||||
gRMSNormalization<float, float><<<blocks, threads, shared>>>(out->data<float>(),
|
||||
in->data<float>(),
|
||||
gamma->data<float>(),
|
||||
beta ? beta->data<float>() : nullptr,
|
||||
rows,
|
||||
cols,
|
||||
eps);
|
||||
#if COMPILE_FP16
|
||||
} else if (out->type() == Type::float16) {
|
||||
gRMSNormalization<half, float><<<blocks, threads, shared>>>(out->data<half>(),
|
||||
in->data<half>(),
|
||||
gamma->data<half>(),
|
||||
beta ? beta->data<half>() : nullptr,
|
||||
rows,
|
||||
cols,
|
||||
eps);
|
||||
#endif
|
||||
} else {
|
||||
ABORT("RMSNormalization not implemented for type {}", out->type());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename AccType = float>
|
||||
__global__ void gRMSNormalizationGrad(T* gradX,
|
||||
T* gradGamma,
|
||||
T* adj,
|
||||
T* y,
|
||||
T* x,
|
||||
T* gamma,
|
||||
T* beta,
|
||||
int rows,
|
||||
int cols,
|
||||
AccType eps = 1e-9) {
|
||||
extern __shared__ uint8_t sharedBytes[];
|
||||
AccType* shared = (AccType*)sharedBytes;
|
||||
|
||||
AccType N = cols;
|
||||
|
||||
for(int bid = 0; bid < rows; bid += gridDim.x) {
|
||||
int j = bid + blockIdx.x;
|
||||
if(j < rows) {
|
||||
AccType* sum_adj_r = shared; // sum of gradient coming in times layerNorm from value
|
||||
AccType* sum_sqr = shared + blockDim.x; // sum of x^2
|
||||
|
||||
const T* xRow = x + j * cols;
|
||||
const T* yRow = y + j * cols;
|
||||
const T* adjRow = adj + j * cols;
|
||||
|
||||
sum_adj_r[threadIdx.x] = (AccType)0.0f;
|
||||
sum_sqr[threadIdx.x] = (AccType)0.0f;
|
||||
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int id = tid + threadIdx.x;
|
||||
if(id < cols) {
|
||||
AccType xv = xRow[id];
|
||||
AccType yv = yRow[id];
|
||||
AccType betav = beta ? (AccType)beta[id] : (AccType)0.f;
|
||||
AccType gammav = (AccType)gamma[id];
|
||||
AccType adjv = adjRow[id];
|
||||
AccType rv = (yv - betav) / gammav; // go back to RMSNorm(x) from scaled and shifted version for accumulation
|
||||
|
||||
sum_adj_r[threadIdx.x] += adjv * rv;
|
||||
sum_sqr[threadIdx.x] += xv * xv;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
int len = blockDim.x;
|
||||
while(len != 1) {
|
||||
__syncthreads();
|
||||
int skip = (len + 1) >> 1;
|
||||
if(threadIdx.x < (len >> 1)) {
|
||||
sum_adj_r[threadIdx.x] += sum_adj_r[threadIdx.x + skip]; // Accumulates in AccType
|
||||
sum_sqr[threadIdx.x] += sum_sqr[threadIdx.x + skip]; // Accumulates in AccType
|
||||
}
|
||||
len = (len + 1) >> 1;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
AccType rms = functional::Ops<AccType>::sqrt(sum_sqr[0] / N + eps);
|
||||
__syncthreads();
|
||||
|
||||
// Jacobian of RMS norm
|
||||
// J = [ \frac{1}{N * rms} (N\delta_{ij} - RN_i RN_j) ]_{ij}
|
||||
// J * a = dC/dx_i = ( N a_i - RN_i \sum_j RN_j a_j ) / (N * rms)
|
||||
|
||||
for(int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
int id = tid + threadIdx.x;
|
||||
if(id < cols) {
|
||||
|
||||
AccType xv = xRow[id];
|
||||
AccType gammav = (AccType)gamma[id];
|
||||
AccType adjv = adjRow[id];
|
||||
AccType rmsNorm = xv / rms;
|
||||
|
||||
AccType gradNorm = N * adjv - rmsNorm * sum_adj_r[0];
|
||||
gradNorm /= N * rms;
|
||||
|
||||
AccType gradXv = gammav * gradNorm;
|
||||
|
||||
// Keep RMSN gradient between [-1000, 1000] for TensorOps, this currently used for making values fit into fp16. This wil also clip inf.
|
||||
// @TODO: to be fixed and removed.
|
||||
AccType sign = functional::Ops<AccType>::sgn(gradXv);
|
||||
AccType cutoff = (AccType)1000.f; // @TODO: expose this somehow as an option? or better: make obsolete.
|
||||
gradXv = functional::Ops<AccType>::abs(gradXv) > cutoff ? sign * cutoff : gradXv; // if gradXv is NaN the value return is NaN too because NaN > value is false.
|
||||
|
||||
// @TODO: frankly, this is embarrasing and should rather be removed or optional? It does help for low precision computation though. Maybe turn into option?
|
||||
gradXv = isnan(gradXv) ? 0.f : gradXv; // turn NaN into 0.
|
||||
|
||||
T* gradXRow = gradX + j * cols;
|
||||
gradXRow[id] += (T)(gradXv);
|
||||
|
||||
T* gradGammaRow = gradGamma + j * cols;
|
||||
// assignment is correct here as this gets summed up
|
||||
// in the next kernel via matrix product
|
||||
gradGammaRow[id] = (T)(adjv * rmsNorm);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
void RMSNormalizationGrad(Ptr<Allocator> allocator,
|
||||
Tensor gradX,
|
||||
Tensor gradGamma,
|
||||
Tensor gradBeta,
|
||||
Tensor adj,
|
||||
Tensor y,
|
||||
Tensor x,
|
||||
Tensor gamma,
|
||||
Tensor beta,
|
||||
float eps) {
|
||||
cudaSetDevice(adj->getDeviceId().no);
|
||||
int rows = y->shape().elements() / y->shape()[-1];
|
||||
int cols = y->shape()[-1];
|
||||
|
||||
int threads = std::min(MAX_THREADS, cols);
|
||||
int blocks = std::min(MAX_BLOCKS, rows);
|
||||
|
||||
auto tempGradGammaMemory = allocator->alloc(adj->memory()->size());
|
||||
Tensor tempGradGamma = TensorBase::New(tempGradGammaMemory, adj->shape(), adj->type(), adj->getBackend());
|
||||
tempGradGamma->set(0.f);
|
||||
|
||||
auto tempOnesMemory = allocator->alloc(rows * sizeOf(adj->type()));
|
||||
Tensor tempOnes = TensorBase::New(tempOnesMemory, Shape({1, rows}), adj->type(), adj->getBackend());
|
||||
tempOnes->set(1.f);
|
||||
|
||||
if(gradX->type() == Type::float32) {
|
||||
int shared = sizeof(float) * threads * 2;
|
||||
gRMSNormalizationGrad<float, float><<<blocks, threads, shared>>>(
|
||||
gradX->data<float>(),
|
||||
tempGradGamma->data<float>(),
|
||||
adj->data<float>(),
|
||||
y->data<float>(),
|
||||
x->data<float>(),
|
||||
gamma->data<float>(),
|
||||
(beta) ? beta->data<float>() : nullptr,
|
||||
rows,
|
||||
cols,
|
||||
eps);
|
||||
#if COMPILE_FP16
|
||||
} else if (gradX->type() == Type::float16) {
|
||||
// accumulate in float
|
||||
int shared = sizeof(float) * threads * 2;
|
||||
gRMSNormalizationGrad<half, float><<<blocks, threads, shared>>>(
|
||||
gradX->data<half>(),
|
||||
tempGradGamma->data<half>(),
|
||||
adj->data<half>(),
|
||||
y->data<half>(),
|
||||
x->data<half>(),
|
||||
gamma->data<half>(),
|
||||
(beta) ? beta->data<half>() : nullptr,
|
||||
rows,
|
||||
cols,
|
||||
eps);
|
||||
#endif
|
||||
} else {
|
||||
ABORT("RMSNormalizationGrad not implemented for type {}", gradX->type());
|
||||
}
|
||||
|
||||
// We use this go get rid of the atomicAdd and perform a reduce of the gradients afterwards.
|
||||
// This is much faster for fp16 which seems to have a broken atomicAdd implementation.
|
||||
// We reduce bias gradients with a matrix multiply, but use a 32-bit compute type.
|
||||
// This preserves precision with larger batches where all batch entries reduce into a single vector.
|
||||
// See also AffineNodeOp where we do the same for biases
|
||||
gpu::Prod(gradGamma, tempOnes, tempGradGamma, false, false, 1, 1, Type::float32); // beta set to one to add
|
||||
|
||||
if(gradBeta) // dC/dbeta = adj - inverse broadcasting (reduction)
|
||||
gpu::Prod(gradBeta, tempOnes, adj, false, false, 1, 1, Type::float32); // beta set to one to add
|
||||
|
||||
allocator->free(tempGradGammaMemory);
|
||||
allocator->free(tempOnesMemory);
|
||||
}
|
||||
|
||||
|
||||
template <bool add, typename T>
|
||||
__global__ void gShift(T* out,
|
||||
const T* in,
|
||||
|
@ -218,6 +218,55 @@ static inline void LayerNormalizationGrad(
|
||||
cpu::LayerNormalizationGrad(gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
DISPATCH5(RMSNormalization, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, float)
|
||||
|
||||
#ifdef CUDA_FOUND
|
||||
namespace gpu {
|
||||
void RMSNormalizationGrad(Ptr<Allocator> allocator,
|
||||
Tensor gradX,
|
||||
Tensor gradGamma,
|
||||
Tensor gradBeta,
|
||||
Tensor adj,
|
||||
Tensor y,
|
||||
Tensor x,
|
||||
Tensor gamma,
|
||||
Tensor beta,
|
||||
float eps);
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace cpu {
|
||||
void RMSNormalizationGrad(Tensor gradX,
|
||||
Tensor gradGamma,
|
||||
Tensor gradBeta,
|
||||
Tensor adj,
|
||||
Tensor y,
|
||||
Tensor x,
|
||||
Tensor gamma,
|
||||
Tensor beta,
|
||||
float eps);
|
||||
}
|
||||
|
||||
static inline void RMSNormalizationGrad(
|
||||
Ptr<Allocator> allocator,
|
||||
Tensor gradX,
|
||||
Tensor gradGamma,
|
||||
Tensor gradBeta,
|
||||
Tensor adj,
|
||||
Tensor y,
|
||||
Tensor x,
|
||||
Tensor gamma,
|
||||
Tensor beta,
|
||||
float eps) {
|
||||
#ifdef CUDA_FOUND
|
||||
if(gradX->getBackend()->getDeviceId().type == DeviceType::gpu)
|
||||
gpu::RMSNormalizationGrad(allocator, gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
|
||||
else
|
||||
#endif
|
||||
cpu::RMSNormalizationGrad(gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
|
||||
}
|
||||
|
||||
DISPATCH4(HighwayForward, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor)
|
||||
DISPATCH7(HighwayBackward, marian::Tensor, marian::Tensor, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor)
|
||||
|
||||
|
@ -300,6 +300,49 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
|
||||
}
|
||||
|
||||
SECTION("RMS normalization") {
|
||||
graph->clear();
|
||||
values.clear();
|
||||
|
||||
std::vector<T> init = {
|
||||
2.88794374, 4.67853451, 3.96257305, 3.28433037,
|
||||
0.37778997, 0.67662024, 4.24959183, 1.23910618,
|
||||
0.68929380, 2.00369596, 4.38251686, 1.75624943,
|
||||
4.96126175, 3.01947117, 4.72057724, 2.23017120
|
||||
};
|
||||
|
||||
auto a1 = graph->param("test1", {2, 2, 4}, inits::fromVector(init));
|
||||
auto a2 = graph->param("test2", {2, 2, 4}, inits::fromVector(init));
|
||||
auto gamma = graph->param("gamma", {1, 4}, inits::ones());
|
||||
|
||||
auto rms = rmsNorm(a1, gamma, nullptr, 1e-5f);
|
||||
auto rms2 = gamma * (a2 / sqrt(mean(a2 * a2, /*axis=*/-1) + 1e-5f));
|
||||
|
||||
auto top = sum(flatten(rms + rms2));
|
||||
|
||||
graph->forward();
|
||||
graph->backward();
|
||||
|
||||
CHECK(rms->shape() == Shape({2, 2, 4}));
|
||||
|
||||
std::vector<T> values2;
|
||||
|
||||
// compare values of rms and rms2 to make sure forward computation is correct
|
||||
rms->val()->get(values);
|
||||
rms2->val()->get(values2);
|
||||
|
||||
CHECK( std::equal(values.begin(), values.end(),
|
||||
values2.begin(), floatApprox) );
|
||||
|
||||
// compare adjoints of a1 and a2 (parameters) to makes sure gradient computation is correct
|
||||
a1->grad()->get(values);
|
||||
a2->grad()->get(values2);
|
||||
|
||||
CHECK( std::equal(values.begin(), values.end(),
|
||||
values2.begin(), floatApprox) );
|
||||
|
||||
}
|
||||
|
||||
SECTION("reductions") {
|
||||
graph->clear();
|
||||
values.clear();
|
||||
|
Loading…
Reference in New Issue
Block a user