Adds better Affine support for GPUs when using CUDA 11. Introduces a new bias addition kernel for CUDA < 11 (#778)

Co-authored-by: Marcin Junczys-Dowmunt <marcinjd@microsoft.com>
rhenry-nv authored on 2021-04-08 21:46:27 -07:00, committed by GitHub
parent 0223ce90b1
commit fddd0e0661
16 changed files with 551 additions and 51 deletions


@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased]
### Added
- Adds a custom bias epilogue kernel.
- Adds support for fusing ReLU and bias addition into GEMMs when using CUDA 11.
- Better suppression of unwanted output symbols, specifically "\n" from SentencePiece with byte-fallback. Can be deactivated with --allow-special
- Display decoder time statistics with marian-decoder --stat-freq 10 ...
- Support for MS-internal binary shortlist


@ -347,8 +347,20 @@ if(CUDA_FOUND)
endif()
message(STATUS "Found CUDA libraries: ${CUDA_LIBS}")
else(USE_STATIC_LIBS)
set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
# cublasLt is only needed for CUDA 11 and later; Marian works fine without it on earlier versions. We force the
# search to occur inside the CUDA toolkit directory so that CMake picks up the cuBLASLt version that ships with CUDA 11.
set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
if ((CUDA_VERSION VERSION_EQUAL "11.0" OR CUDA_VERSION VERSION_GREATER "11.0"))
find_library(CUDA_cublasLt_LIBRARY NAMES cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64 NO_DEFAULT_PATH)
if(NOT CUDA_cublasLt_LIBRARY)
message(FATAL_ERROR "cuBLASLt library not found")
endif()
set(EXT_LIBS ${EXT_LIBS} ${CUDA_cublasLt_LIBRARY})
set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_cublasLt_LIBRARY})
endif()
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
message(STATUS "Found CUDA libraries: ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES}")
message(STATUS "Found CUDA libraries: ${CUDA_LIBS}")
endif(USE_STATIC_LIBS)
if(USE_CUDNN)


@ -175,6 +175,7 @@ if(CUDA_FOUND)
tensors/gpu/device.cu
tensors/gpu/algorithm.cu
tensors/gpu/prod.cpp
tensors/gpu/prod.cu
tensors/gpu/prod_sparse.cpp
tensors/gpu/topk.cu
tensors/gpu/element.cu


@ -1,4 +1,5 @@
#include "graph/expression_operators.h"
#include "common/definitions.h"
#include "layers/constructors.h"
#include "graph/node_operators.h"
@ -518,7 +519,7 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scale) {
return Expression<DotBatchedNodeOp>(a, b, transA, transB, scale);
}
static Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
// general version, MKL, CBlas or CUDA
int rows = a->shape().elements() / a->shape()[-1];
@ -577,6 +578,15 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
}
}
Expr affineWithRelu(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
auto graph = a->graph();
if(graph->isInference() && graph->getDeviceId().type == DeviceType::gpu)
return Expression<AffineWithReluNodeOp>(a, b, bias, transA, transB, scale);
else
return relu(affine(a, b, bias, transA, transB, scale));
}
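(Illustration only, not from this commit: a minimal sketch of how the new operator is meant to be called, mirroring the unit test added at the end of this commit; the shapes, values, and the name fusedAffineReluExample are arbitrary.)

#include "marian.h"

using namespace marian;

// Fused affine + ReLU at inference time. On a GPU inference graph this creates a
// single AffineWithReluNodeOp; otherwise it falls back to relu(affine(...)).
void fusedAffineReluExample() {
  auto graph = New<ExpressionGraph>();
  graph->setInference(true);               // the fused path is inference-only
  graph->setDevice({0, DeviceType::gpu});
  graph->reserveWorkspaceMB(16);

  auto A    = graph->param("A",    {4, 3}, inits::fromValue(1.f));
  auto B    = graph->param("B",    {3, 2}, inits::fromValue(0.5f));
  auto bias = graph->param("bias", {1, 2}, inits::fromValue(-1.f));

  auto y = affineWithRelu(A, B, bias);     // equivalent to relu(affine(A, B, bias))
  graph->forward();
}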
// @TODO: Not a great place to check this
#if CUDA_VERSION < 11000
// multiply a CSR matrix A with a matrix B


@ -488,11 +488,21 @@ Expr bdot(Expr a,
*/
Expr affine(Expr a,
Expr b,
Expr c,
Expr bias,
bool transA = false,
bool transB = false,
float scalar = 1.f);
/**
* As above, but efficiently applies a ReLU transformation to the output. For inference only.
*/
Expr affineWithRelu(Expr a,
Expr b,
Expr bias,
bool transA = false,
bool transB = false,
float scalar = 1.f);
/**
* Computes the dot product of CSR-tensor @p A with @p B.
*/


@ -266,17 +266,18 @@ public:
NodeOps forwardOps() override {
using namespace functional;
return {
NodeOp(
Prod(val_,
child(0)->val(),
child(1)->val(),
transA_,
transB_,
0.f,
scalar_);
Prod(val_, child(3)->val(), child(2)->val(), false, false, 1.f, 1.f))
NodeOp(Affine(val_,
graph()->allocator(),
child(0)->val(),
child(1)->val(),
child(2)->val(),
transA_,
transB_,
0.f,
scalar_,
/*doRelu=*/false))
};
}
@ -323,8 +324,7 @@ public:
false,
1.0,
scalar_, computeTypeB)),
NodeOp(Prod(
child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
};
if(transA_ && !transB_)
@ -343,8 +343,7 @@ public:
false,
1.0,
scalar_, computeTypeB)),
NodeOp(Prod(
child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
};
if(transA_ && transB_)
@ -363,8 +362,7 @@ public:
true,
1.0,
scalar_, computeTypeB)),
NodeOp(Prod(
child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
};
return {
@ -382,8 +380,7 @@ public:
false,
1.0,
scalar_, computeTypeB)),
NodeOp(Prod(
child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
};
}
@ -414,6 +411,97 @@ public:
};
class AffineWithReluNodeOp : public NaryNodeOp {
private:
friend class SerializationHelpers;
bool transA_;
bool transB_;
float scalar_;
public:
AffineWithReluNodeOp(Expr a,
Expr b,
Expr bias,
bool transA,
bool transB,
float scalar)
: NaryNodeOp({a, b, bias}, newShape(a, b, transA, transB)),
transA_(transA),
transB_(transB),
scalar_(scalar) {
ABORT_IF(!graph()->isInference() || graph()->getDeviceId().type != DeviceType::gpu,
"AffineWithReluNodeOp currently only supported for inference on GPU");
}
Shape newShape(Expr a, Expr b, bool transA, bool transB) {
auto shapeA = a->shape();
if(transA) {
shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
}
auto shapeB = b->shape();
if(transB) {
shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]);
shapeB.set(shapeB.size() - 1, b->shape()[shapeB.size() - 2]);
}
Shape outShape = shapeA;
outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
"Matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
return outShape;
}
NodeOps forwardOps() override {
ABORT_IF(!graph()->isInference() || graph()->getDeviceId().type != DeviceType::gpu,
"AffineWithReluNodeOp currently only supported for inference on GPU");
return {
NodeOp(Affine(val_,
graph()->allocator(),
child(0)->val(),
child(1)->val(),
child(2)->val(),
transA_,
transB_,
0.f,
scalar_,
/*doRelu=*/true))
};
}
NodeOps backwardOps() override {
ABORT("AffineWithReluNodeOp cannot be used for training??");
return {};
}
const std::string type() override { return "affineWithRelu"; }
virtual size_t hash() override {
size_t seed = NaryNodeOp::hash();
util::hash_combine(seed, transA_);
util::hash_combine(seed, transB_);
util::hash_combine(seed, scalar_);
return seed;
}
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<AffineWithReluNodeOp>(node);
if(!cnode)
return false;
if(transA_ != cnode->transA_)
return false;
if(transB_ != cnode->transB_)
return false;
if(scalar_ != cnode->scalar_)
return false;
return true;
}
};
class DotBatchedNodeOp : public NaryNodeOp {
private:
friend class SerializationHelpers;


@ -1,5 +1,7 @@
#pragma once
#include "common/definitions.h"
#include "graph/expression_operators.h"
#include "marian.h"
#include "data/shortlist.h"
@ -168,22 +170,37 @@ public:
// --- a few layers with built-in parameters created on the fly, without proper object
// @TODO: change to a proper layer object
static inline std::function<Expr(Expr)> activationByName(const std::string& actName) {
if (actName == "relu")
return (ActivationFunction*)relu;
else if (actName == "swish")
return (ActivationFunction*)swish;
else if (actName == "gelu")
return (ActivationFunction*)gelu;
else if (actName == "") // return identity function if activation name is empty
return [](Expr x) { return x; };
ABORT("Invalid activation name '{}'", actName);
}
// like affine() but with built-in parameters, activation, and dropout
static inline Expr denseInline(Expr x,
std::string prefix,
std::string suffix,
int outDim,
Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(),
const std::function<Expr(Expr)>& actFn = nullptr,
std::string actName = "",
float dropProb = 0.0f) {
auto graph = x->graph();
auto W = graph->param(prefix + "_W" + suffix, {x->shape()[-1], outDim}, inits::glorotUniform());
auto b = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros());
x = affine(x, W, b);
if(actFn)
x = actFn(x);
if(actName == "relu") {
x = affineWithRelu(x, W, b); // speed optimization for inference, @TODO: handle better in future layer framework
} else {
x = affine(x, W, b);
x = activationByName(actName)(x);
}
x = dropout(x, dropProb); // @TODO: check for inference?
return x;
}


@ -170,7 +170,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
/*suffix=*/"1",
ffnDim,
inits::glorotUniform(),
(ActivationFunction*)relu,
"relu",
ffnDropProb);
f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
// add & norm


@ -396,18 +396,6 @@ public:
opt<int>("transformer-heads"), /*cache=*/false);
}
static inline
std::function<Expr(Expr)> activationByName(const std::string& actName)
{
if (actName == "relu")
return (ActivationFunction*)relu;
else if (actName == "swish")
return (ActivationFunction*)swish;
else if (actName == "gelu")
return (ActivationFunction*)gelu;
ABORT("Invalid activation name '{}'", actName);
}
Expr LayerFFN(std::string prefix, Expr input) const {
int dimModel = input->shape()[-1];
@ -415,9 +403,9 @@ public:
auto opsPre = opt<std::string>("transformer-preprocess");
auto output = preProcess(prefix + "_ffn", opsPre, input, dropProb);
auto actName = opt<std::string>("transformer-ffn-activation");
int dimFfn = opt<int>("transformer-dim-ffn");
int depthFfn = opt<int>("transformer-ffn-depth");
auto actFn = activationByName(opt<std::string>("transformer-ffn-activation"));
float ffnDropProb
= inference_ ? 0 : opt<float>("transformer-dropout-ffn");
@ -427,12 +415,11 @@ public:
// the stack of FF layers
for(int i = 1; i < depthFfn; ++i)
output = denseInline(output, prefix, /*suffix=*/std::to_string(i), dimFfn, initFn, actFn, ffnDropProb);
output = denseInline(output, prefix, /*suffix=*/std::to_string(i), dimFfn, initFn, actName, ffnDropProb);
output = denseInline(output, prefix, /*suffix=*/std::to_string(depthFfn), dimModel, initFn);
auto opsPost = opt<std::string>("transformer-postprocess");
output
= postProcess(prefix + "_ffn", opsPost, output, input, dropProb);
output = postProcess(prefix + "_ffn", opsPost, output, input, dropProb);
return output;
}
@ -450,21 +437,21 @@ public:
// FFN
int dimAan = opt<int>("transformer-dim-aan");
int depthAan = opt<int>("transformer-aan-depth");
auto actFn = activationByName(opt<std::string>("transformer-aan-activation"));
auto actName = opt<std::string>("transformer-aan-activation");
float aanDropProb = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
auto initFn = inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f);
// the stack of AAN layers
for(int i = 1; i < depthAan; ++i)
y = denseInline(y, prefix, /*suffix=*/std::to_string(i), dimAan, initFn, actFn, aanDropProb);
y = denseInline(y, prefix, /*suffix=*/std::to_string(i), dimAan, initFn, actName, aanDropProb);
if(y->shape()[-1] != dimModel) // bring it back to the desired dimension if needed
y = denseInline(y, prefix, std::to_string(depthAan), dimModel, initFn);
bool noGate = opt<bool>("transformer-aan-nogate");
if(!noGate) {
auto gi = denseInline(x, prefix, /*suffix=*/"i", dimModel, initFn, (ActivationFunction*)sigmoid);
auto gf = denseInline(y, prefix, /*suffix=*/"f", dimModel, initFn, (ActivationFunction*)sigmoid);
auto gi = denseInline(x, prefix, /*suffix=*/"i", dimModel, initFn, "sigmoid");
auto gf = denseInline(y, prefix, /*suffix=*/"f", dimModel, initFn, "sigmoid");
y = gi * x + gf * y;
}


@ -212,6 +212,23 @@ void ProdWithBias(marian::Tensor C,
cpu::integer::AddBias(C, bias);
}
void Affine(marian::Tensor C,
Ptr<Allocator> /*allocator*/,
const marian::Tensor& A,
const marian::Tensor& B,
const marian::Tensor& bias,
bool transA,
bool transB,
float beta,
float scalar,
bool reluPostprocess) {
using namespace functional;
ProdWithBias(C, A, B, bias, transA, transB, beta, scalar);
if(reluPostprocess)
cpu::Element(_1 = ReLU(_1), C); // @TODO: also fuse with AddBias
}
void CSRProd(marian::Tensor C,
Ptr<Allocator> /*allocator*/,
const marian::Tensor& S_values,


@ -152,6 +152,30 @@
cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9); \
}
#define DISPATCH10( \
Function, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10) \
namespace gpu { \
void Function(Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10); \
} \
namespace cpu { \
void Function(Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10); \
} \
static inline void Function(Arg1 arg1, \
Arg2 arg2, \
Arg3 arg3, \
Arg4 arg4, \
Arg5 arg5, \
Arg6 arg6, \
Arg7 arg7, \
Arg8 arg8, \
Arg9 arg9, \
Arg10 arg10) { \
if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \
gpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10); \
else \
cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10); \
}
#else
#define DISPATCH1(Function, Arg1) \
@ -248,4 +272,22 @@
cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9); \
}
#define DISPATCH10( \
Function, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10) \
namespace cpu { \
void Function(Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10); \
} \
static inline void Function(Arg1 arg1, \
Arg2 arg2, \
Arg3 arg3, \
Arg4 arg4, \
Arg5 arg5, \
Arg6 arg6, \
Arg7 arg7, \
Arg8 arg8, \
Arg9 arg9, \
Arg10 arg10) { \
cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10); \
}
#endif
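(Illustration only, not from this commit: what the GPU-enabled DISPATCH10 above generates for the Affine entry added to tensor_operators.h below; the macro itself uses arg1..arg10, the argument names here are descriptive stand-ins.)

namespace gpu { void Affine(marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float, bool); }
namespace cpu { void Affine(marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float, bool); }

// The generated wrapper routes on the device of the first tensor argument.
static inline void Affine(marian::Tensor C, Ptr<Allocator> allocator,
                          const marian::Tensor& A, const marian::Tensor& B, const marian::Tensor& bias,
                          bool transA, bool transB, float beta, float scalar, bool doRelu) {
  if(C->getBackend()->getDeviceId().type == DeviceType::gpu)
    gpu::Affine(C, allocator, A, B, bias, transA, transB, beta, scalar, doRelu);
  else
    cpu::Affine(C, allocator, A, B, bias, transA, transB, beta, scalar, doRelu);
}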


@ -11,10 +11,34 @@
#include "tensors/gpu/cuda_helpers.h"
// clang-format on
#if CUDA_VERSION >= 11000
#include <cublasLt.h>
#endif
namespace marian {
namespace gpu {
// It seems that the bias must be 8 byte aligned for the cublasLt epilogue to work. Therefore,
// if the bias pointer is not 8 byte aligned, we do a normal matmul in cublasLt and invoke a
// custom epilogue kernel.
static constexpr int REQUIRED_BIAS_ALIGNMENT = 8;
// Used to set preferences for cublasLt to filter out algos if matrices do not meet the default 256 byte alignment
int getAlignmentUpTo256(const void *ptr) {
uintptr_t addr = (uintptr_t)ptr;
int trailingZeros = 0;
for(int shiftAmt = 8, mask = 0xFF; shiftAmt > 0; shiftAmt /= 2, mask >>= shiftAmt) {
if ((addr & mask) == 0) {
trailingZeros += shiftAmt;
addr >>= shiftAmt;
}
}
return std::min(256, 1 << trailingZeros);
}
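(Illustration only, not from this commit: a stand-alone check of the alignment probe; the helper is restated so the snippet compiles on its own, and the addresses are made up.)

#include <algorithm>
#include <cassert>
#include <cstdint>

// Same trailing-zero probe as getAlignmentUpTo256 above, capped at 2^8 = 256 bytes.
static int alignmentUpTo256(uintptr_t addr) {
  int trailingZeros = 0;
  while(trailingZeros < 8 && (addr & 1) == 0) { ++trailingZeros; addr >>= 1; }
  return std::min(256, 1 << trailingZeros);
}

int main() {
  assert(alignmentUpTo256(0x100) == 256); // 256-byte aligned, e.g. a fresh allocator block
  assert(alignmentUpTo256(0x140) == 64);  // a slice offset by 0x40 bytes
  assert(alignmentUpTo256(0x0c8) == 8);   // only 8-byte aligned
  return 0;
}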
// The explicit matmul calls like cublasGemmEx choose their math mode based on the algorithm that
// has been passed into the function call and seem to ignore setMathMode. Here we query the math mode
// in use to choose the algorithm.
@ -412,5 +436,198 @@ void ProdBatched(marian::Tensor C,
}
}
#if CUDA_VERSION >= 11000 // Earlier versions of cublasLt do not support bias addition for fp32 and fp16.
static cublasStatus_t cublasLtAffineHelper(cublasLtHandle_t ltHandle, cublasOperation_t transA, cublasOperation_t transB,
cudaDataType matrixType,
int m, int n, int k, const void *alpha, const void *A, int lda, const void *B,
int ldb, const void *beta, void *C, int ldc, const void* bias,
void* workspace, size_t workspaceSize, bool do_relu, cudaStream_t stream) {
cublasLtMatmulDesc_t operationDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cublasLtMatmulPreference_t preference = NULL;
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = do_relu ? CUBLASLT_EPILOGUE_RELU_BIAS : CUBLASLT_EPILOGUE_BIAS;
cublasComputeType_t computeType = matrixType == CUDA_R_32F ? CUBLAS_COMPUTE_32F_FAST_16F : CUBLAS_COMPUTE_16F;
// If the bias is not aligned, just matmul and invoke custom epilogue later.
// cublas fails with a misalignment error if this condition is not true.
if((uintptr_t)bias % REQUIRED_BIAS_ALIGNMENT != 0) {
epilogue = CUBLASLT_EPILOGUE_DEFAULT;
}
CUBLAS_CHECK(cublasLtMatmulDescCreate(&operationDesc, computeType, matrixType));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transA, sizeof(transA)));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transB, sizeof(transB)));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&Adesc, matrixType, transA == CUBLAS_OP_N ? m : k, transA == CUBLAS_OP_N ? k : m, lda));
CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&Bdesc, matrixType, transB == CUBLAS_OP_N ? k : n, transB == CUBLAS_OP_N ? n : k, ldb));
CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&Cdesc, matrixType, m, n, ldc));
// I think we need to do this since we can slice matrices...
// The allocator always allocates on 256 byte boundaries but we have no guarantees about the alignment of a matrix slice so we filter out
// algorithms that would not work with matrices not aligned to 256 bytes.
int alignmentA = getAlignmentUpTo256(A);
int alignmentB = getAlignmentUpTo256(B);
int alignmentC = getAlignmentUpTo256(C);
CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference));
CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &alignmentA, sizeof(alignmentA)));
CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &alignmentB, sizeof(alignmentB)));
CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &alignmentC, sizeof(alignmentC)));
CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &alignmentC, sizeof(alignmentC)));
CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, preference, 1, &heuristicResult, &returnedResults));
cublasStatus_t opStatus = cublasLtMatmul(ltHandle, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc,
&heuristicResult.algo, workspace, workspaceSize, stream);
if (preference) CUBLAS_CHECK(cublasLtMatmulPreferenceDestroy(preference));
if (Cdesc) CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(Cdesc));
if (Bdesc) CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(Bdesc));
if (Adesc) CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(Adesc));
if (operationDesc) CUBLAS_CHECK(cublasLtMatmulDescDestroy(operationDesc));
return opStatus;
}
static cublasStatus_t cublasLtAffineTyped(cublasLtHandle_t ltHandle, cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k, const half *alpha, const half *A, int lda, const half *B,
int ldb, const half *beta, half *C, int ldc, const half* bias,
half* workspace, size_t workspaceSizeBytes, bool do_relu, cudaStream_t stream) {
return cublasLtAffineHelper(ltHandle, transA, transB, CUDA_R_16F, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, bias,
workspace, workspaceSizeBytes, do_relu, stream);
}
static cublasStatus_t cublasLtAffineTyped(cublasLtHandle_t ltHandle, cublasOperation_t transA, cublasOperation_t transB,
int m, int n, int k, const float *alpha, const float *A, int lda, const float *B,
int ldb, const float *beta, float *C, int ldc, const float* bias,
float* workspace, size_t workspaceSizeBytes, bool do_relu, cudaStream_t stream) {
return cublasLtAffineHelper(ltHandle, transA, transB, CUDA_R_32F, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, bias,
workspace, workspaceSizeBytes, do_relu, stream);
}
template <typename T>
void affineTyped(marian::Tensor C, Ptr<Allocator> allocator, const marian::Tensor& A, const marian::Tensor& B, const marian::Tensor& bias,
bool transA, bool transB, T beta, T scalar, bool do_relu) {
CUDA_CHECK(cudaSetDevice((int)C->getDeviceId().no));
T alpha = scalar;
int m = A->shape().elements() / A->shape().back();
int k = A->shape().back();
if(transA)
std::swap(m, k);
int l = B->shape().elements() / B->shape().back();
int n = B->shape().back();
if(transB)
std::swap(l, n);
int lda = A->shape().back();
int ldb = B->shape().back();
int ldc = B->shape().back();
size_t bias_size = bias->shape().elements();
ABORT_IF(n != bias_size, "The number of elements in the bias must match the number of columns in C");
if(transB)
ldc = B->shape().elements() / B->shape().back();
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
auto backend = std::static_pointer_cast<gpu::Backend>(C->getBackend());
auto cublasHandle = backend->getCublasHandle();
auto ltHandle = (cublasLtHandle_t)backend->getCublasHandle(); // A cublas handle encapsulates an lt handle
size_t numWorkSpaceElts = 8192; // Allows cublasLt to perform split-K GEMMs. This is chosen to be at least
                                // 16 KiB for float16, which is large enough to prevent "alloc failed" errors.
size_t workspaceSizeBytes = numWorkSpaceElts * sizeof(T);
IPtr<MemoryPiece> workspace = allocator->alloc<T>(numWorkSpaceElts);
cudaStream_t stream = 0;
CUBLAS_CHECK(cublasGetStream(cublasHandle, &stream));
CUBLAS_CHECK(cublasLtAffineTyped(ltHandle,
opB,
opA,
n,
m,
k,
&alpha,
B->data<T>(),
ldb,
A->data<T>(),
lda,
&beta,
C->data<T>(),
ldc,
bias->data<T>(),
workspace->data<T>(),
workspaceSizeBytes,
do_relu,
stream));
allocator->free(workspace);
}
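(Worked example, not part of this commit, of the dimension bookkeeping above. cuBLASLt assumes column-major storage, so Marian's row-major tensors are seen as their transposes; passing B before A with the dimensions reordered as n, m, k makes the library compute C^T = B^T * A^T, whose memory layout is exactly the row-major C = A * B.)

// Concrete case, no transposes: A: [4 x 3], B: [3 x 2], bias: [2]  ->  C: [4 x 2]
//   m = 4, k = 3, n = 2, lda = 3, ldb = 2, ldc = 2
// The call above then passes (opB, opA, n=2, m=4, k=3, &alpha, B, ldb=2, A, lda=3, &beta, C, ldc=2, bias, ...),
// i.e. a column-major product C^T (2 x 4) = B^T (2 x 3) * A^T (3 x 4), with the
// length-2 bias broadcast over the rows of the row-major C.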
// This separate version is needed so that Windows doesn't complain when compiling with CUDA < 11. Otherwise, the ifdef could
// live inside a single definition of Affine.
void Affine(marian::Tensor C,
Ptr<Allocator> allocator,
const marian::Tensor& A,
const marian::Tensor& B,
const marian::Tensor& bias,
bool transA, bool transB, float beta, float scalar, bool do_relu) {
// There is a bug in CUDA 11 where the bias pointer needs to be 8 byte aligned. This bug will be fixed in a subsequent release. For now,
// we launch a custom epilogue kernel if the bias does not meet the alignment requirement.
if(C->type() == Type::float32) {
affineTyped<float>(C, allocator, A, B, bias, transA, transB, beta, scalar, do_relu);
if((uintptr_t)bias->data<float>() % REQUIRED_BIAS_ALIGNMENT != 0) {
BiasAdd(C, bias, do_relu);
}
#if COMPILE_FP16
} else if(C->type() == Type::float16) {
affineTyped<half>(C, allocator, A, B, bias, transA, transB, __float2half(beta), __float2half(scalar), do_relu);
if((uintptr_t)bias->data<half>() % REQUIRED_BIAS_ALIGNMENT != 0) {
BiasAdd(C, bias, do_relu);
}
#endif
} else {
ABORT("Affine not implemented for type {}", C->type());
}
}
#else
void Affine(marian::Tensor C,
Ptr<Allocator> /*allocator*/,
const marian::Tensor& A,
const marian::Tensor& B,
const marian::Tensor& bias,
bool transA, bool transB, float beta, float scalar, bool do_relu) {
if(C->type() == Type::float32) {
ProdTyped<float>(C, A, B, transA, transB, beta, scalar);
#if COMPILE_FP16
} else if(C->type() == Type::float16) {
ProdTyped<half>(C, A, B, transA, transB, __float2half(beta), __float2half(scalar));
#endif
} else {
ABORT("Prod not implemented for type {}", C->type());
}
BiasAdd(C, bias, do_relu);
}
#endif
} // namespace gpu
} // namespace marian

src/tensors/gpu/prod.cu (new file, 69 lines)

@ -0,0 +1,69 @@
#include <stdint.h>
#include "tensors/tensor.h"
#include "tensors/gpu/cuda_helpers.h"
#include "tensors/gpu/backend.h"
namespace marian {
namespace gpu {
template <typename T, typename ActFunc>
__global__ static void gBiasAddFused(T* tensor, T* bias, size_t tensor_size, size_t bias_size, ActFunc f) {
const size_t row_start = blockIdx.x * bias_size;
for(int bias_offset = threadIdx.x; bias_offset < bias_size; bias_offset+=blockDim.x) {
size_t offset_into_tensor = row_start + bias_offset;
if(offset_into_tensor < tensor_size) {
T added_bias = tensor[offset_into_tensor] + bias[bias_offset];
tensor[offset_into_tensor] = f(added_bias);
}
}
}
struct identity {
template <typename T>
__device__ constexpr T&& operator() (T&& t) const noexcept {
return std::forward<T>(t);
}
};
struct reluAct {
template <typename T>
__device__ T operator() (T t) const noexcept {
return t > (T) 0? t : (T) 0;
}
};
void BiasAdd(marian::Tensor C, const marian::Tensor& bias, bool do_relu) {
auto backend = std::static_pointer_cast<gpu::Backend>(C->getBackend());
CUDA_CHECK(cudaSetDevice(backend->getDeviceId().no));
size_t size = C->shape().elements();
size_t bias_size = bias->shape().elements();
int m = C->shape().elements() / C->shape().back();
int n = C->shape().back();
ABORT_IF(n != bias_size, "The number of elements in the bias must match the number of columns in C");
int threads_per_block = std::min(MAX_THREADS, n);
int blocks = m;
if(C->type() == Type::float32) {
if (do_relu)
gBiasAddFused<<<blocks, threads_per_block>>>(C->data<float>(), bias->data<float>(), size, bias_size, reluAct());
else
gBiasAddFused<<<blocks, threads_per_block>>>(C->data<float>(), bias->data<float>(), size, bias_size, identity());
#if COMPILE_FP16
} else if(C->type() == Type::float16) {
if (do_relu)
gBiasAddFused<<<blocks, threads_per_block>>>(C->data<half>(), bias->data<half>(), size, bias_size, reluAct());
else
gBiasAddFused<<<blocks, threads_per_block>>>(C->data<half>(), bias->data<half>(), size, bias_size, identity());
#endif
} else {
ABORT("Prod not implemented for type {}", C->type());
}
}
}
}
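(Illustration only, not from this commit: a CPU reference of the arithmetic the fused kernel performs, handy when checking GPU output; the function name and the std::vector interface are assumptions.)

#include <algorithm>
#include <cstddef>
#include <vector>

// CPU reference for gBiasAddFused: add the bias to every row of a row-major
// [rows x cols] tensor and optionally apply ReLU, as the kernel above does per block.
void biasAddReference(std::vector<float>& tensor, const std::vector<float>& bias, bool doRelu) {
  const size_t biasSize = bias.size();          // = number of columns of the tensor
  for(size_t i = 0; i < tensor.size(); ++i) {
    float v = tensor[i] + bias[i % biasSize];   // column index = i % biasSize
    tensor[i] = doRelu ? std::max(v, 0.f) : v;
  }
}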


@ -6,6 +6,21 @@
namespace marian {
namespace gpu {
void BiasAdd(marian::Tensor C,
const marian::Tensor& bias,
bool do_relu = false);
void Affine(marian::Tensor C,
Ptr<Allocator> allocator,
const marian::Tensor& A,
const marian::Tensor& B,
const marian::Tensor& bias,
bool transA,
bool transB,
float beta = 0,
float scalar = 1,
bool do_relu = false);
void Prod(marian::Tensor C,
const marian::Tensor& A,
const marian::Tensor& B,


@ -106,6 +106,8 @@ DISPATCH8(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bo
DISPATCH8(ProdBatched, marian::Tensor, Ptr<Allocator>, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
DISPATCH9(CSRProd, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float)
DISPATCH10(Affine, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float, bool)
DISPATCH2(Softmax, marian::Tensor, marian::Tensor)
DISPATCH3(SoftmaxGrad, marian::Tensor, marian::Tensor, marian::Tensor)


@ -32,6 +32,8 @@ void tests(DeviceType device, Type floatType = Type::float32) {
Config::seed = 1234;
auto graph = New<ExpressionGraph>();
graph->setInference(true);
graph->setDefaultElementType(floatType);
graph->setDevice({0, device});
graph->reserveWorkspaceMB(16);
@ -539,15 +541,19 @@ void tests(DeviceType device, Type floatType = Type::float32) {
values.clear();
std::vector<T> vA({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
std::vector<T> vB({1, 2, 3, 4, 5, 6});
std::vector<T> vAff({24, 30, 51, 66, 78, 102, 105, 138});
std::vector<T> vB({1, -2, 3, 4, -5, 6});
std::vector<T> vAff({-6, 26, -9, 50, -12, 74, -15, 98});
std::vector<T> vAffRelu({0, 26, 0, 50, 0, 74, 0, 98});
auto A = graph->param("A", {4, 3}, inits::fromVector(vA));
auto B = graph->param("B", {3, 2}, inits::fromVector(vB));
auto C = graph->param("C", {4, 2}, inits::fromValue(2));
auto bias = graph->param("C", {1, 2}, inits::fromValue(2));
auto aff1 = affine(A, B, C);
auto aff2 = dot(A, B) + C;
auto aff1 = affine(A, B, bias);
auto aff2 = dot(A, B) + bias;
auto affRelu1 = affineWithRelu(A, B, bias);
auto affRelu2 = relu(dot(A, B) + bias);
graph->forward();
@ -559,6 +565,11 @@ void tests(DeviceType device, Type floatType = Type::float32) {
CHECK(aff2->shape() == aff1->shape());
aff2->val()->get(values2);
CHECK(values2 == values);
affRelu1->val()->get(values);
affRelu2->val()->get(values2);
CHECK(values2 == vAffRelu);
CHECK(values2 == values);
}
SECTION("repeat") {