mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-19 02:37:14 +03:00
Adds better Affine support for GPUs when using CUDA 11. Introduces a new bias addition kernel for CUDA < 11 (#778)
Co-authored-by: Marcin Junczys-Dowmunt <marcinjd@microsoft.com>
This commit is contained in:
parent
0223ce90b1
commit
fddd0e0661
@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Adds custom bias epilogue kernel.
|
||||
- Adds support for fusing ReLU and bias addition into GEMMs when using CUDA 11.
|
||||
- Better suppression of unwanted output symbols, specifically "\n" from SentencePiece with byte-fallback. Can be deactivated with --allow-special
|
||||
- Display decoder time statistics with marian-decoder --stat-freq 10 ...
|
||||
- Support for MS-internal binary shortlist
|
||||
|
@ -347,8 +347,20 @@ if(CUDA_FOUND)
|
||||
endif()
|
||||
message(STATUS "Found CUDA libraries: ${CUDA_LIBS}")
|
||||
else(USE_STATIC_LIBS)
|
||||
set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
||||
# We actually only need cublasLt here after cuda 11. Marian will work fine without it pre cuda 11. We want to force CMake to use the cublas
|
||||
# version that ships with CUDA 11 so we force the search to occur inside of the cuda toolkit directory.
|
||||
set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
||||
if ((CUDA_VERSION VERSION_EQUAL "11.0" OR CUDA_VERSION VERSION_GREATER "11.0"))
|
||||
find_library(CUDA_cublasLt_LIBRARY NAMES cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 NO_DEFAULT_PATH)
|
||||
if(NOT CUDA_cublasLt_LIBRARY)
|
||||
message(FATAL_ERROR "cuBLASLt library not found")
|
||||
endif()
|
||||
set(EXT_LIBS ${EXT_LIBS} ${CUDA_cublasLt_LIBRARY})
|
||||
set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_cublasLt_LIBRARY})
|
||||
endif()
|
||||
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
|
||||
message(STATUS "Found CUDA libraries: ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES}")
|
||||
message(STATUS "Found CUDA libraries: ${CUDA_LIBS}")
|
||||
endif(USE_STATIC_LIBS)
|
||||
|
||||
if(USE_CUDNN)
|
||||
|
@ -175,6 +175,7 @@ if(CUDA_FOUND)
|
||||
tensors/gpu/device.cu
|
||||
tensors/gpu/algorithm.cu
|
||||
tensors/gpu/prod.cpp
|
||||
tensors/gpu/prod.cu
|
||||
tensors/gpu/prod_sparse.cpp
|
||||
tensors/gpu/topk.cu
|
||||
tensors/gpu/element.cu
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include "graph/expression_operators.h"
|
||||
#include "common/definitions.h"
|
||||
#include "layers/constructors.h"
|
||||
|
||||
#include "graph/node_operators.h"
|
||||
@ -518,7 +519,7 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scale) {
|
||||
return Expression<DotBatchedNodeOp>(a, b, transA, transB, scale);
|
||||
}
|
||||
|
||||
static Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
|
||||
Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
|
||||
// general version, MKL, CBlas or CUDA
|
||||
|
||||
int rows = a->shape().elements() / a->shape()[-1];
|
||||
@ -577,6 +578,15 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
|
||||
}
|
||||
}
|
||||
|
||||
// Affine transformation of a and b with bias, followed by ReLU.
// When the graph is in inference mode and lives on a GPU device, a single fused
// AffineWithReluNodeOp is created. In every other case the result is composed
// from the regular affine() and relu() operators.
Expr affineWithRelu(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
  auto graph = a->graph();
  const bool useFusedNode = graph->isInference() && graph->getDeviceId().type == DeviceType::gpu;

  if(!useFusedNode)
    return relu(affine(a, b, bias, transA, transB, scale));

  return Expression<AffineWithReluNodeOp>(a, b, bias, transA, transB, scale);
}
|
||||
|
||||
// @TODO: Not a great place to check this
|
||||
#if CUDA_VERSION < 11000
|
||||
// multiply a CSR matrix A with a matrix B
|
||||
|
@ -488,11 +488,21 @@ Expr bdot(Expr a,
|
||||
*/
|
||||
Expr affine(Expr a,
|
||||
Expr b,
|
||||
Expr c,
|
||||
Expr bias,
|
||||
bool transA = false,
|
||||
bool transB = false,
|
||||
float scalar = 1.f);
|
||||
|
||||
/**
|
||||
* As above, but efficiently applies relu transformation to output. For inference only.
|
||||
*/
|
||||
Expr affineWithRelu(Expr a,
|
||||
Expr b,
|
||||
Expr bias,
|
||||
bool transA = false,
|
||||
bool transB = false,
|
||||
float scalar = 1.f);
|
||||
|
||||
/**
|
||||
* Computes the dot product of CSR-tensor @p A with @p B.
|
||||
*/
|
||||
|
@ -266,17 +266,18 @@ public:
|
||||
|
||||
NodeOps forwardOps() override {
|
||||
using namespace functional;
|
||||
|
||||
|
||||
return {
|
||||
NodeOp(
|
||||
Prod(val_,
|
||||
child(0)->val(),
|
||||
child(1)->val(),
|
||||
transA_,
|
||||
transB_,
|
||||
0.f,
|
||||
scalar_);
|
||||
Prod(val_, child(3)->val(), child(2)->val(), false, false, 1.f, 1.f))
|
||||
NodeOp(Affine(val_,
|
||||
graph()->allocator(),
|
||||
child(0)->val(),
|
||||
child(1)->val(),
|
||||
child(2)->val(),
|
||||
transA_,
|
||||
transB_,
|
||||
0.f,
|
||||
scalar_,
|
||||
/*doRelu=*/false))
|
||||
};
|
||||
}
|
||||
|
||||
@ -323,8 +324,7 @@ public:
|
||||
false,
|
||||
1.0,
|
||||
scalar_, computeTypeB)),
|
||||
NodeOp(Prod(
|
||||
child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
|
||||
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
|
||||
};
|
||||
|
||||
if(transA_ && !transB_)
|
||||
@ -343,8 +343,7 @@ public:
|
||||
false,
|
||||
1.0,
|
||||
scalar_, computeTypeB)),
|
||||
NodeOp(Prod(
|
||||
child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
|
||||
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
|
||||
};
|
||||
|
||||
if(transA_ && transB_)
|
||||
@ -363,8 +362,7 @@ public:
|
||||
true,
|
||||
1.0,
|
||||
scalar_, computeTypeB)),
|
||||
NodeOp(Prod(
|
||||
child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
|
||||
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
|
||||
};
|
||||
|
||||
return {
|
||||
@ -382,8 +380,7 @@ public:
|
||||
false,
|
||||
1.0,
|
||||
scalar_, computeTypeB)),
|
||||
NodeOp(Prod(
|
||||
child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
|
||||
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
|
||||
};
|
||||
}
|
||||
|
||||
@ -414,6 +411,97 @@ public:
|
||||
|
||||
};
|
||||
|
||||
class AffineWithReluNodeOp : public NaryNodeOp {
|
||||
private:
|
||||
friend class SerializationHelpers;
|
||||
bool transA_;
|
||||
bool transB_;
|
||||
float scalar_;
|
||||
|
||||
public:
|
||||
AffineWithReluNodeOp(Expr a,
|
||||
Expr b,
|
||||
Expr bias,
|
||||
bool transA,
|
||||
bool transB,
|
||||
float scalar)
|
||||
: NaryNodeOp({a, b, bias}, newShape(a, b, transA, transB)),
|
||||
transA_(transA),
|
||||
transB_(transB),
|
||||
scalar_(scalar) {
|
||||
ABORT_IF(!graph()->isInference() || graph()->getDeviceId().type != DeviceType::gpu,
|
||||
"AffineWithReluNodeOp currently only supported for inference on GPU");
|
||||
}
|
||||
|
||||
Shape newShape(Expr a, Expr b, bool transA, bool transB) {
|
||||
auto shapeA = a->shape();
|
||||
if(transA) {
|
||||
shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
|
||||
shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
|
||||
}
|
||||
|
||||
auto shapeB = b->shape();
|
||||
if(transB) {
|
||||
shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]);
|
||||
shapeB.set(shapeB.size() - 1, b->shape()[shapeB.size() - 2]);
|
||||
}
|
||||
|
||||
Shape outShape = shapeA;
|
||||
outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
|
||||
ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
|
||||
"Matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
|
||||
return outShape;
|
||||
}
|
||||
|
||||
NodeOps forwardOps() override {
|
||||
ABORT_IF(!graph()->isInference() || graph()->getDeviceId().type != DeviceType::gpu,
|
||||
"AffineWithReluNodeOp currently only supported for inference on GPU");
|
||||
|
||||
return {
|
||||
NodeOp(Affine(val_,
|
||||
graph()->allocator(),
|
||||
child(0)->val(),
|
||||
child(1)->val(),
|
||||
child(2)->val(),
|
||||
transA_,
|
||||
transB_,
|
||||
0.f,
|
||||
scalar_,
|
||||
/*doRelu=*/true))
|
||||
};
|
||||
}
|
||||
|
||||
NodeOps backwardOps() override {
|
||||
ABORT("AffineWithReluNodeOp cannot be used for training??");
|
||||
return {};
|
||||
}
|
||||
|
||||
const std::string type() override { return "affineWithRelu"; }
|
||||
|
||||
virtual size_t hash() override {
|
||||
size_t seed = NaryNodeOp::hash();
|
||||
util::hash_combine(seed, transA_);
|
||||
util::hash_combine(seed, transB_);
|
||||
util::hash_combine(seed, scalar_);
|
||||
return seed;
|
||||
}
|
||||
|
||||
virtual bool equal(Expr node) override {
|
||||
if(!NaryNodeOp::equal(node))
|
||||
return false;
|
||||
auto cnode = std::dynamic_pointer_cast<AffineWithReluNodeOp>(node);
|
||||
if(!cnode)
|
||||
return false;
|
||||
if(transA_ != cnode->transA_)
|
||||
return false;
|
||||
if(transB_ != cnode->transB_)
|
||||
return false;
|
||||
if(scalar_ != cnode->scalar_)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
class DotBatchedNodeOp : public NaryNodeOp {
|
||||
private:
|
||||
friend class SerializationHelpers;
|
||||
|
@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common/definitions.h"
|
||||
#include "graph/expression_operators.h"
|
||||
#include "marian.h"
|
||||
|
||||
#include "data/shortlist.h"
|
||||
@ -168,22 +170,37 @@ public:
|
||||
// --- a few layers with built-in parameters created on the fly, without proper object
|
||||
// @TODO: change to a proper layer object
|
||||
|
||||
static inline std::function<Expr(Expr)> activationByName(const std::string& actName) {
|
||||
if (actName == "relu")
|
||||
return (ActivationFunction*)relu;
|
||||
else if (actName == "swish")
|
||||
return (ActivationFunction*)swish;
|
||||
else if (actName == "gelu")
|
||||
return (ActivationFunction*)gelu;
|
||||
else if (actName == "") // return identity function if activation name is empty
|
||||
return [](Expr x) { return x; };
|
||||
ABORT("Invalid activation name '{}'", actName);
|
||||
}
|
||||
|
||||
// like affine() but with built-in parameters, activation, and dropout
|
||||
static inline Expr denseInline(Expr x,
|
||||
std::string prefix,
|
||||
std::string suffix,
|
||||
int outDim,
|
||||
Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(),
|
||||
const std::function<Expr(Expr)>& actFn = nullptr,
|
||||
std::string actName = "",
|
||||
float dropProb = 0.0f) {
|
||||
auto graph = x->graph();
|
||||
|
||||
auto W = graph->param(prefix + "_W" + suffix, {x->shape()[-1], outDim}, inits::glorotUniform());
|
||||
auto b = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros());
|
||||
|
||||
x = affine(x, W, b);
|
||||
if(actFn)
|
||||
x = actFn(x);
|
||||
if(actName == "relu") {
|
||||
x = affineWithRelu(x, W, b); // speed optimization for inference, @TODO: handle better in future layer framework
|
||||
} else {
|
||||
x = affine(x, W, b);
|
||||
x = activationByName(actName)(x);
|
||||
}
|
||||
x = dropout(x, dropProb); // @TODO: check for inference?
|
||||
return x;
|
||||
}
|
||||
|
@ -170,7 +170,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
|
||||
/*suffix=*/"1",
|
||||
ffnDim,
|
||||
inits::glorotUniform(),
|
||||
(ActivationFunction*)relu,
|
||||
"relu",
|
||||
ffnDropProb);
|
||||
f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
|
||||
// add & norm
|
||||
|
@ -396,18 +396,6 @@ public:
|
||||
opt<int>("transformer-heads"), /*cache=*/false);
|
||||
}
|
||||
|
||||
static inline
|
||||
std::function<Expr(Expr)> activationByName(const std::string& actName)
|
||||
{
|
||||
if (actName == "relu")
|
||||
return (ActivationFunction*)relu;
|
||||
else if (actName == "swish")
|
||||
return (ActivationFunction*)swish;
|
||||
else if (actName == "gelu")
|
||||
return (ActivationFunction*)gelu;
|
||||
ABORT("Invalid activation name '{}'", actName);
|
||||
}
|
||||
|
||||
Expr LayerFFN(std::string prefix, Expr input) const {
|
||||
int dimModel = input->shape()[-1];
|
||||
|
||||
@ -415,9 +403,9 @@ public:
|
||||
auto opsPre = opt<std::string>("transformer-preprocess");
|
||||
auto output = preProcess(prefix + "_ffn", opsPre, input, dropProb);
|
||||
|
||||
auto actName = opt<std::string>("transformer-ffn-activation");
|
||||
int dimFfn = opt<int>("transformer-dim-ffn");
|
||||
int depthFfn = opt<int>("transformer-ffn-depth");
|
||||
auto actFn = activationByName(opt<std::string>("transformer-ffn-activation"));
|
||||
float ffnDropProb
|
||||
= inference_ ? 0 : opt<float>("transformer-dropout-ffn");
|
||||
|
||||
@ -427,12 +415,11 @@ public:
|
||||
|
||||
// the stack of FF layers
|
||||
for(int i = 1; i < depthFfn; ++i)
|
||||
output = denseInline(output, prefix, /*suffix=*/std::to_string(i), dimFfn, initFn, actFn, ffnDropProb);
|
||||
output = denseInline(output, prefix, /*suffix=*/std::to_string(i), dimFfn, initFn, actName, ffnDropProb);
|
||||
output = denseInline(output, prefix, /*suffix=*/std::to_string(depthFfn), dimModel, initFn);
|
||||
|
||||
auto opsPost = opt<std::string>("transformer-postprocess");
|
||||
output
|
||||
= postProcess(prefix + "_ffn", opsPost, output, input, dropProb);
|
||||
output = postProcess(prefix + "_ffn", opsPost, output, input, dropProb);
|
||||
|
||||
return output;
|
||||
}
|
||||
@ -450,21 +437,21 @@ public:
|
||||
// FFN
|
||||
int dimAan = opt<int>("transformer-dim-aan");
|
||||
int depthAan = opt<int>("transformer-aan-depth");
|
||||
auto actFn = activationByName(opt<std::string>("transformer-aan-activation"));
|
||||
auto actName = opt<std::string>("transformer-aan-activation");
|
||||
float aanDropProb = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
|
||||
|
||||
auto initFn = inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f);
|
||||
|
||||
// the stack of AAN layers
|
||||
for(int i = 1; i < depthAan; ++i)
|
||||
y = denseInline(y, prefix, /*suffix=*/std::to_string(i), dimAan, initFn, actFn, aanDropProb);
|
||||
y = denseInline(y, prefix, /*suffix=*/std::to_string(i), dimAan, initFn, actName, aanDropProb);
|
||||
if(y->shape()[-1] != dimModel) // bring it back to the desired dimension if needed
|
||||
y = denseInline(y, prefix, std::to_string(depthAan), dimModel, initFn);
|
||||
|
||||
bool noGate = opt<bool>("transformer-aan-nogate");
|
||||
if(!noGate) {
|
||||
auto gi = denseInline(x, prefix, /*suffix=*/"i", dimModel, initFn, (ActivationFunction*)sigmoid);
|
||||
auto gf = denseInline(y, prefix, /*suffix=*/"f", dimModel, initFn, (ActivationFunction*)sigmoid);
|
||||
auto gi = denseInline(x, prefix, /*suffix=*/"i", dimModel, initFn, "sigmoid");
|
||||
auto gf = denseInline(y, prefix, /*suffix=*/"f", dimModel, initFn, "sigmoid");
|
||||
y = gi * x + gf * y;
|
||||
}
|
||||
|
||||
|
@ -212,6 +212,23 @@ void ProdWithBias(marian::Tensor C,
|
||||
cpu::integer::AddBias(C, bias);
|
||||
}
|
||||
|
||||
// CPU implementation of Affine: delegates the product and bias addition to
// ProdWithBias and, when reluPostprocess is set, applies an element-wise ReLU
// to C afterwards. The allocator argument is not used.
void Affine(marian::Tensor C,
            Ptr<Allocator> /*allocator*/,
            const marian::Tensor& A,
            const marian::Tensor& B,
            const marian::Tensor& bias,
            bool transA,
            bool transB,
            float beta,
            float scalar,
            bool reluPostprocess) {
  using namespace functional;

  ProdWithBias(C, A, B, bias, transA, transB, beta, scalar);

  if(!reluPostprocess)
    return;

  cpu::Element(_1 = ReLU(_1), C); // @TODO: also fuse with AddBias
}
|
||||
|
||||
|
||||
void CSRProd(marian::Tensor C,
|
||||
Ptr<Allocator> /*allocator*/,
|
||||
const marian::Tensor& S_values,
|
||||
|
@ -152,6 +152,30 @@
|
||||
cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9); \
|
||||
}
|
||||
|
||||
// Ten-argument dispatcher: declares gpu::Function and cpu::Function and defines an inline
// Function that calls the GPU version when arg1's backend reports DeviceType::gpu and the
// CPU version otherwise.
#define DISPATCH10(Function, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10) \
  namespace gpu {                                                                         \
  void Function(Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10);             \
  }                                                                                       \
  namespace cpu {                                                                         \
  void Function(Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10);             \
  }                                                                                       \
  static inline void Function(Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4, Arg5 arg5,      \
                              Arg6 arg6, Arg7 arg7, Arg8 arg8, Arg9 arg9, Arg10 arg10) {  \
    if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) {                       \
      gpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10);         \
    } else {                                                                              \
      cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10);         \
    }                                                                                     \
  }
|
||||
|
||||
#else
|
||||
|
||||
#define DISPATCH1(Function, Arg1) \
|
||||
@ -248,4 +272,22 @@
|
||||
cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9); \
|
||||
}
|
||||
|
||||
// Ten-argument dispatcher for builds without GPU support: declares cpu::Function and
// defines an inline Function that always forwards to it.
#define DISPATCH10(Function, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10) \
  namespace cpu {                                                                         \
  void Function(Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10);             \
  }                                                                                       \
  static inline void Function(Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4, Arg5 arg5,      \
                              Arg6 arg6, Arg7 arg7, Arg8 arg8, Arg9 arg9, Arg10 arg10) {  \
    cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10);           \
  }
|
||||
|
||||
#endif
|
||||
|
@ -11,10 +11,34 @@
|
||||
#include "tensors/gpu/cuda_helpers.h"
|
||||
// clang-format on
|
||||
|
||||
#if CUDA_VERSION >= 11000
|
||||
#include <cublasLt.h>
|
||||
#endif
|
||||
|
||||
namespace marian {
|
||||
|
||||
namespace gpu {
|
||||
|
||||
// It seems that the bias must be 8 byte aligned for the cublasLt epilogue to work. Therefore,
|
||||
// if the bias pointer is not 8 byte aligned, we do a normal matmul in cublasLt and invoke a
|
||||
// custom epilogue kernel.
|
||||
static constexpr int REQUIRED_BIAS_ALIGNMENT = 8;
|
||||
|
||||
// Used to set preferences for cublasLt to filter out algos if matrices do not meet the default
// 256 byte alignment.
// Returns the largest power of two (in bytes) that divides the address of ptr, capped at 256.
// A null pointer (address 0) is divisible by every power of two and therefore yields 256.
int getAlignmentUpTo256(const void *ptr) {
  constexpr uintptr_t maxAlignment = 256;

  const uintptr_t addr = reinterpret_cast<uintptr_t>(ptr);
  if(addr == 0)
    return static_cast<int>(maxAlignment);

  // addr & -addr isolates the lowest set bit, i.e. the largest power of two dividing addr.
  const uintptr_t lowestSetBit = addr & (~addr + 1);
  return static_cast<int>(std::min(maxAlignment, lowestSetBit));
}
|
||||
|
||||
// The explicit version of matmult like cublasGemmEx choose their math mode based on the algorithm that
|
||||
// has been passed into the function call and seem to ignore setMathMode. Here we query the used math mode
|
||||
// to choose the algorithm.
|
||||
@ -412,5 +436,198 @@ void ProdBatched(marian::Tensor C,
|
||||
}
|
||||
}
|
||||
|
||||
#if CUDA_VERSION >= 11000 // Earlier versions of cublasLT do not support bias addition for fp32 and fp16.
|
||||
|
||||
// Runs a single cublasLtMatmul on the given (column-major, as seen by cuBLAS) operands and
// writes the result into C. When the bias pointer satisfies REQUIRED_BIAS_ALIGNMENT the bias
// addition (and ReLU if do_relu is set) is fused into the matmul epilogue; otherwise a plain
// matmul is performed and the caller is expected to apply bias/ReLU itself.
//
// Descriptor setup failures are handled by CUBLAS_CHECK. The status of the matmul itself is
// returned to the caller; CUBLAS_STATUS_NOT_SUPPORTED is returned if the heuristic query
// yields no usable algorithm.
static cublasStatus_t cublasLtAffineHelper(cublasLtHandle_t ltHandle, cublasOperation_t transA, cublasOperation_t transB,
                                           cudaDataType matrixType,
                                           int m, int n, int k, const void *alpha, const void *A, int lda, const void *B,
                                           int ldb, const void *beta, void *C, int ldc, const void* bias,
                                           void* workspace, size_t workspaceSize, bool do_relu, cudaStream_t stream) {

  cublasLtMatmulDesc_t operationDesc = nullptr;
  cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr;
  cublasLtMatmulPreference_t preference = nullptr;

  int returnedResults = 0;
  cublasLtMatmulHeuristicResult_t heuristicResult = {};

  cublasLtEpilogue_t epilogue = do_relu ? CUBLASLT_EPILOGUE_RELU_BIAS : CUBLASLT_EPILOGUE_BIAS;
  cublasComputeType_t computeType = matrixType == CUDA_R_32F ? CUBLAS_COMPUTE_32F_FAST_16F : CUBLAS_COMPUTE_16F;

  // If the bias is not aligned, just matmul and invoke custom epilogue later.
  // cublas fails with a misalignment error if this condition is not true.
  if((uintptr_t)bias % REQUIRED_BIAS_ALIGNMENT != 0) {
    epilogue = CUBLASLT_EPILOGUE_DEFAULT;
  }

  CUBLAS_CHECK(cublasLtMatmulDescCreate(&operationDesc, computeType, matrixType));
  CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transA, sizeof(transA)));
  CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transB, sizeof(transB)));
  CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
  CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));

  CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&Adesc, matrixType, transA == CUBLAS_OP_N ? m : k, transA == CUBLAS_OP_N ? k : m, lda));
  CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&Bdesc, matrixType, transB == CUBLAS_OP_N ? k : n, transB == CUBLAS_OP_N ? n : k, ldb));
  CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&Cdesc, matrixType, m, n, ldc));

  // I think we need to do this since we can slice matrices...
  // The allocator always allocates on 256 byte boundaries but we have no guarantees about the alignment of a matrix slice so we filter out
  // algorithms that would not work with matrices not aligned to 256 bytes.
  int alignmentA = getAlignmentUpTo256(A);
  int alignmentB = getAlignmentUpTo256(B);
  int alignmentC = getAlignmentUpTo256(C);

  CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference));
  CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
  CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &alignmentA, sizeof(alignmentA)));
  CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &alignmentB, sizeof(alignmentB)));
  CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &alignmentC, sizeof(alignmentC)));
  // C is also the output matrix D, so the same alignment applies.
  CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &alignmentC, sizeof(alignmentC)));
  CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, preference, 1, &heuristicResult, &returnedResults));

  // The heuristic query can succeed without returning any algorithm. In that case
  // heuristicResult.algo is still zero-initialised and must not be passed to cublasLtMatmul.
  cublasStatus_t opStatus = CUBLAS_STATUS_NOT_SUPPORTED;
  if(returnedResults > 0) {
    opStatus = cublasLtMatmul(ltHandle, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc,
                              &heuristicResult.algo, workspace, workspaceSize, stream);
  }

  // Release descriptors in reverse order of creation regardless of the matmul outcome.
  if (preference) CUBLAS_CHECK(cublasLtMatmulPreferenceDestroy(preference));
  if (Cdesc) CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(Cdesc));
  if (Bdesc) CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(Bdesc));
  if (Adesc) CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(Adesc));
  if (operationDesc) CUBLAS_CHECK(cublasLtMatmulDescDestroy(operationDesc));

  return opStatus;
}
|
||||
|
||||
// half overload: forwards every argument to cublasLtAffineHelper with CUDA_R_16F as the
// matrix type and returns its status.
static cublasStatus_t cublasLtAffineTyped(cublasLtHandle_t ltHandle, cublasOperation_t transA, cublasOperation_t transB,
                                          int m, int n, int k, const half *alpha, const half *A, int lda, const half *B,
                                          int ldb, const half *beta, half *C, int ldc, const half* bias,
                                          half* workspace, size_t workspaceSizeBytes, bool do_relu, cudaStream_t stream) {
  const cudaDataType matrixType = CUDA_R_16F;
  return cublasLtAffineHelper(ltHandle, transA, transB, matrixType,
                              m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, bias,
                              workspace, workspaceSizeBytes, do_relu, stream);
}
|
||||
|
||||
// float overload: forwards every argument to cublasLtAffineHelper with CUDA_R_32F as the
// matrix type and returns its status.
static cublasStatus_t cublasLtAffineTyped(cublasLtHandle_t ltHandle, cublasOperation_t transA, cublasOperation_t transB,
                                          int m, int n, int k, const float *alpha, const float *A, int lda, const float *B,
                                          int ldb, const float *beta, float *C, int ldc, const float* bias,
                                          float* workspace, size_t workspaceSizeBytes,bool do_relu, cudaStream_t stream) {
  const cudaDataType matrixType = CUDA_R_32F;
  return cublasLtAffineHelper(ltHandle, transA, transB, matrixType,
                              m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, bias,
                              workspace, workspaceSizeBytes, do_relu, stream);
}
|
||||
|
||||
template <typename T>
|
||||
void affineTyped(marian::Tensor C, Ptr<Allocator> allocator, const marian::Tensor& A, const marian::Tensor& B, const marian::Tensor& bias,
|
||||
bool transA, bool transB, T beta, T scalar, bool do_relu) {
|
||||
|
||||
CUDA_CHECK(cudaSetDevice((int)C->getDeviceId().no));
|
||||
T alpha = scalar;
|
||||
|
||||
int m = A->shape().elements() / A->shape().back();
|
||||
int k = A->shape().back();
|
||||
if(transA)
|
||||
std::swap(m, k);
|
||||
|
||||
int l = B->shape().elements() / B->shape().back();
|
||||
int n = B->shape().back();
|
||||
if(transB)
|
||||
std::swap(l, n);
|
||||
|
||||
int lda = A->shape().back();
|
||||
int ldb = B->shape().back();
|
||||
int ldc = B->shape().back();
|
||||
|
||||
size_t bias_size = bias->shape().elements();
|
||||
ABORT_IF(n != bias_size, "The number of elements in the bias must match the number of columns in C");
|
||||
|
||||
if(transB)
|
||||
ldc = B->shape().elements() / B->shape().back();
|
||||
|
||||
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
|
||||
auto backend = std::static_pointer_cast<gpu::Backend>(C->getBackend());
|
||||
auto cublasHandle = backend->getCublasHandle();
|
||||
auto ltHandle = (cublasLtHandle_t)backend->getCublasHandle(); // A cublas handle encapsulates an lt handle
|
||||
|
||||
size_t numWorkSpaceElts = 8192; // Allows for cublasLt to perform split-K gemms. This is chosen to be at least
|
||||
// 16 KiB for float16 which is large enough to prevent alloc failed errors
|
||||
size_t workspaceSizeBytes = numWorkSpaceElts * sizeof(T);
|
||||
IPtr<MemoryPiece> workspace = allocator->alloc<T>(numWorkSpaceElts);
|
||||
|
||||
cudaStream_t stream = 0;
|
||||
CUBLAS_CHECK(cublasGetStream(cublasHandle, &stream));
|
||||
|
||||
|
||||
CUBLAS_CHECK(cublasLtAffineTyped(ltHandle,
|
||||
opB,
|
||||
opA,
|
||||
n,
|
||||
m,
|
||||
k,
|
||||
&alpha,
|
||||
B->data<T>(),
|
||||
ldb,
|
||||
A->data<T>(),
|
||||
lda,
|
||||
&beta,
|
||||
C->data<T>(),
|
||||
ldc,
|
||||
bias->data<T>(),
|
||||
workspace->data<T>(),
|
||||
workspaceSizeBytes,
|
||||
do_relu,
|
||||
stream));
|
||||
|
||||
allocator->free(workspace);
|
||||
}
|
||||
|
||||
// This version is needed so that Windows doesn't complain when compiling CUDA < 11. Otherwise, the ifdef could be inside of one
|
||||
// definition of Affine.
|
||||
void Affine(marian::Tensor C,
|
||||
Ptr<Allocator> allocator,
|
||||
const marian::Tensor& A,
|
||||
const marian::Tensor& B,
|
||||
const marian::Tensor& bias,
|
||||
bool transA, bool transB, float beta, float scalar, bool do_relu) {
|
||||
// There is a bug in CUDA 11 where the bias pointer needs to be 8 byte aligned. This bug will be fix in a subsequent release. For now,
|
||||
// we launch a custom epilogue if the bias does not meet the alignment requirement.
|
||||
if(C->type() == Type::float32) {
|
||||
affineTyped<float>(C, allocator, A, B, bias, transA, transB, beta, scalar, do_relu);
|
||||
if((uintptr_t)bias->data<float>() % REQUIRED_BIAS_ALIGNMENT != 0) {
|
||||
BiasAdd(C, bias, do_relu);
|
||||
}
|
||||
#if COMPILE_FP16
|
||||
} else if(C->type() == Type::float16) {
|
||||
affineTyped<half>(C, allocator, A, B, bias, transA, transB, __float2half(beta), __float2half(scalar), do_relu);
|
||||
if((uintptr_t)bias->data<half>() % REQUIRED_BIAS_ALIGNMENT != 0) {
|
||||
BiasAdd(C, bias, do_relu);
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
ABORT("Affine not implemented for type {}", C->type());
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// GPU Affine for CUDA < 11, where no fused cublasLt epilogue is used: computes the product
// with ProdTyped for the element type of C and then always applies bias (and optional ReLU)
// with the BiasAdd kernel. The allocator argument is not used. Aborts for unsupported
// element types, with the same message as the CUDA >= 11 variant.
void Affine(marian::Tensor C,
            Ptr<Allocator> /*allocator*/,
            const marian::Tensor& A,
            const marian::Tensor& B,
            const marian::Tensor& bias,
            bool transA, bool transB, float beta, float scalar, bool do_relu) {

  if(C->type() == Type::float32) {
    ProdTyped<float>(C, A, B, transA, transB, beta, scalar);
#if COMPILE_FP16
  } else if(C->type() == Type::float16) {
    ProdTyped<half>(C, A, B, transA, transB, __float2half(beta), __float2half(scalar));
#endif
  } else {
    ABORT("Affine not implemented for type {}", C->type());
  }
  BiasAdd(C, bias, do_relu);
}
|
||||
#endif
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace marian
|
||||
|
69
src/tensors/gpu/prod.cu
Normal file
69
src/tensors/gpu/prod.cu
Normal file
@ -0,0 +1,69 @@
|
||||
#include <stdint.h>
|
||||
#include "tensors/tensor.h"
|
||||
#include "tensors/gpu/cuda_helpers.h"
|
||||
#include "tensors/gpu/backend.h"
|
||||
|
||||
namespace marian {
|
||||
namespace gpu {
|
||||
|
||||
template <typename T, typename ActFunc>
|
||||
__global__ static void gBiasAddFused(T* tensor, T* bias, size_t tensor_size, size_t bias_size, ActFunc f) {
|
||||
const size_t row_start = blockIdx.x * bias_size;
|
||||
for(int bias_offset = threadIdx.x; bias_offset < bias_size; bias_offset+=blockDim.x) {
|
||||
size_t offset_into_tensor = row_start + bias_offset;
|
||||
if(offset_into_tensor < tensor_size) {
|
||||
T added_bias = tensor[offset_into_tensor] + bias[bias_offset];
|
||||
tensor[offset_into_tensor] = f(added_bias);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct identity {
|
||||
template <typename T>
|
||||
__device__ constexpr T&& operator() (T&& t) const noexcept {
|
||||
return std::forward<T>(t);
|
||||
}
|
||||
};
|
||||
|
||||
struct reluAct {
|
||||
template <typename T>
|
||||
__device__ T operator() (T t) const noexcept {
|
||||
return t > (T) 0? t : (T) 0;
|
||||
}
|
||||
};
|
||||
|
||||
// Adds bias to every row of C in place and, if do_relu is set, applies ReLU to the result.
// C is treated as a matrix with C->shape().back() columns; the bias must have exactly that
// many elements, otherwise the call aborts. One thread block is launched per row with at
// most MAX_THREADS threads. Supports float32 and, when compiled with COMPILE_FP16, float16.
void BiasAdd(marian::Tensor C, const marian::Tensor& bias, bool do_relu) {
  auto backend = std::static_pointer_cast<gpu::Backend>(C->getBackend());
  CUDA_CHECK(cudaSetDevice((int)backend->getDeviceId().no));

  size_t size = C->shape().elements();
  size_t bias_size = bias->shape().elements();

  int m = C->shape().elements() / C->shape().back();
  int n = C->shape().back();

  ABORT_IF((size_t)n != bias_size, "The number of elements in the bias must match the number of columns in C");

  int threads_per_block = std::min(MAX_THREADS, n);
  int blocks = m;

  if(C->type() == Type::float32) {
    if (do_relu)
      gBiasAddFused<<<blocks, threads_per_block>>>(C->data<float>(), bias->data<float>(), size, bias_size, reluAct());
    else
      gBiasAddFused<<<blocks, threads_per_block>>>(C->data<float>(), bias->data<float>(), size, bias_size, identity());

#if COMPILE_FP16
  } else if(C->type() == Type::float16) {
    if (do_relu)
      gBiasAddFused<<<blocks, threads_per_block>>>(C->data<half>(), bias->data<half>(), size, bias_size, reluAct());
    else
      gBiasAddFused<<<blocks, threads_per_block>>>(C->data<half>(), bias->data<half>(), size, bias_size, identity());
#endif
  } else {
    ABORT("BiasAdd not implemented for type {}", C->type());
  }
}
|
||||
|
||||
}
|
||||
}
|
@ -6,6 +6,21 @@
|
||||
namespace marian {
|
||||
namespace gpu {
|
||||
|
||||
// Adds bias to each row of C in place; with do_relu set, a ReLU is applied to the sums.
void BiasAdd(marian::Tensor C, const marian::Tensor& bias, bool do_relu = false);
|
||||
|
||||
// GPU affine transformation: matrix product of A and B (optionally transposed) written to C,
// with bias added and, if do_relu is set, a ReLU applied to the result.
void Affine(marian::Tensor C,
            Ptr<Allocator> allocator,
            const marian::Tensor& A,
            const marian::Tensor& B,
            const marian::Tensor& bias,
            bool transA,
            bool transB,
            float beta   = 0,
            float scalar = 1,
            bool do_relu = false);
|
||||
|
||||
void Prod(marian::Tensor C,
|
||||
const marian::Tensor& A,
|
||||
const marian::Tensor& B,
|
||||
|
@ -106,6 +106,8 @@ DISPATCH8(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bo
|
||||
DISPATCH8(ProdBatched, marian::Tensor, Ptr<Allocator>, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
|
||||
DISPATCH9(CSRProd, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float)
|
||||
|
||||
DISPATCH10(Affine, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float, bool)
|
||||
|
||||
DISPATCH2(Softmax, marian::Tensor, marian::Tensor)
|
||||
DISPATCH3(SoftmaxGrad, marian::Tensor, marian::Tensor, marian::Tensor)
|
||||
|
||||
|
@ -32,6 +32,8 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
|
||||
Config::seed = 1234;
|
||||
auto graph = New<ExpressionGraph>();
|
||||
|
||||
graph->setInference(true);
|
||||
graph->setDefaultElementType(floatType);
|
||||
graph->setDevice({0, device});
|
||||
graph->reserveWorkspaceMB(16);
|
||||
@ -539,15 +541,19 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
values.clear();
|
||||
|
||||
std::vector<T> vA({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
std::vector<T> vB({1, 2, 3, 4, 5, 6});
|
||||
std::vector<T> vAff({24, 30, 51, 66, 78, 102, 105, 138});
|
||||
std::vector<T> vB({1, -2, 3, 4, -5, 6});
|
||||
std::vector<T> vAff({-6, 26, -9, 50, -12, 74, -15, 98});
|
||||
std::vector<T> vAffRelu({0, 26, 0, 50, 0, 74, 0, 98});
|
||||
|
||||
auto A = graph->param("A", {4, 3}, inits::fromVector(vA));
|
||||
auto B = graph->param("B", {3, 2}, inits::fromVector(vB));
|
||||
auto C = graph->param("C", {4, 2}, inits::fromValue(2));
|
||||
auto bias = graph->param("C", {1, 2}, inits::fromValue(2));
|
||||
|
||||
auto aff1 = affine(A, B, C);
|
||||
auto aff2 = dot(A, B) + C;
|
||||
auto aff1 = affine(A, B, bias);
|
||||
auto aff2 = dot(A, B) + bias;
|
||||
|
||||
auto affRelu1 = affineWithRelu(A, B, bias);
|
||||
auto affRelu2 = relu(dot(A, B) + bias);
|
||||
|
||||
graph->forward();
|
||||
|
||||
@ -559,6 +565,11 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
CHECK(aff2->shape() == aff1->shape());
|
||||
aff2->val()->get(values2);
|
||||
CHECK(values2 == values);
|
||||
|
||||
affRelu1->val()->get(values);
|
||||
affRelu2->val()->get(values2);
|
||||
CHECK(values2 == vAffRelu);
|
||||
CHECK(values2 == values);
|
||||
}
|
||||
|
||||
SECTION("repeat") {
|
||||
|
Loading…
Reference in New Issue
Block a user