broadcasting bdot

This commit is contained in:
Marcin Junczys-Dowmunt 2021-06-07 09:14:39 -07:00
parent 2c1b16f43e
commit 77c0cac1f2
4 changed files with 181 additions and 36 deletions

View File

@ -529,11 +529,27 @@ public:
shapeB.set(-1, b->shape()[-2]);
}
Shape outShape = shapeA;
outShape.set(-1, shapeB[-1]);
ABORT_IF(shapeA[-1] != shapeB[-2],
"Batched matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
return outShape;
"Batched matrix product requires inner dimensions to match in {}{} * {}{}",
std::string(shapeA), transA, std::string(shapeB), transB);
// create shapes for batch dimensions only
auto shapeBatchA = shapeA;
shapeBatchA.set(-1, 1);
shapeBatchA.set(-2, 1);
auto shapeBatchB = shapeB;
shapeBatchB.set(-1, 1);
shapeBatchB.set(-2, 1);
// broadcast batch dimensions
auto shapeOut = Shape::broadcast({shapeBatchA, shapeBatchB});
// set non-batch dimensions in output
shapeOut.set(-2, shapeA[-2]);
shapeOut.set(-1, shapeB[-1]);
return shapeOut;
}
NodeOps forwardOps() override {

View File

@ -93,31 +93,58 @@ void ProdBatched(marian::Tensor C,
#if BLAS_FOUND
float alpha = scalar;
size_t batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
size_t batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
// determine meta-shape of bdot operation. Essentially treat the last two dimensions as single elements
// such that (..., m, k) x (..., k, n) -> (..., m, n) where ... is a broadcastable shape as in element-wise kernels.
size_t m = A->shape()[-2];
size_t k = A->shape()[-1];
auto aShape = A->shape();
auto bShape = B->shape();
// make sure both shapes have the same number of dimensions via broadcasting
size_t maxLength = std::max(aShape.size(), bShape.size());
if(aShape.size() != bShape.size()) {
Shape ones(std::vector<int>(maxLength, 1));
aShape = Shape::broadcast({aShape, ones});
bShape = Shape::broadcast({bShape, ones});
}
// Create meta-shapes without last 2 dimensions
Shape aShapeMeta, bShapeMeta, cShapeMeta;
aShapeMeta.resize(maxLength - 2);
bShapeMeta.resize(maxLength - 2);
for(size_t i = 0; i < maxLength - 2; ++i) {
aShapeMeta.set(i, aShape[i]);
bShapeMeta.set(i, bShape[i]);
}
cShapeMeta = Shape::broadcast({aShapeMeta, bShapeMeta});
size_t m = aShape[-2];
size_t k = aShape[-1];
if(transA)
std::swap(m, k);
size_t l = B->shape()[-2];
size_t n = B->shape()[-1];
size_t l = bShape[-2];
size_t n = bShape[-1];
if(transB)
std::swap(l, n);
size_t lda = A->shape()[-1];
size_t ldb = B->shape()[-1];
size_t ldc = B->shape()[-1];
size_t lda = aShape[-1];
size_t ldb = bShape[-1];
size_t ldc = bShape[-1];
if(transB)
ldc = B->shape()[-2];
ldc = bShape[-2];
auto strideB = batchB == 1 ? 0 : n * k;
auto strideA = batchA == 1 ? 0 : m * k;
auto strideA = m * k;
auto strideB = n * k;
auto strideC = n * m;
auto batchC = std::max(batchA, batchB);
auto batchC = cShapeMeta.elements();
// Convert to functional shapes to be able to map dimensions. @TODO merge this
functional::Shape aShapeMetaF = aShapeMeta;
functional::Shape bShapeMetaF = bShapeMeta;
functional::Shape cShapeMetaF = cShapeMeta;
#if MKL_FOUND
CBLAS_TRANSPOSE transA_forarr = CblasNoTrans;
CBLAS_TRANSPOSE transB_forarr = CblasNoTrans;
@ -156,9 +183,14 @@ void ProdBatched(marian::Tensor C,
// This loop initializes the array pointers in the same way as the for loop
// in the normal sgemm version a few lines below
functional::Array<int, functional::Shape::size()> dims;
for(size_t i = 0; i < batchC; ++i) {
a_array[i] = A->data() + (i % batchA) * strideA;
b_array[i] = B->data() + (i % batchB) * strideB;
cShapeMetaF.dims(i, dims);
auto aIndex = aShapeMetaF.bindex(dims);
auto bIndex = bShapeMetaF.bindex(dims);
a_array[i] = A->data() + aIndex * strideA;
b_array[i] = B->data() + bIndex * strideB;
c_array[i] = C->data() + i * strideC;
}
cblas_sgemm_batch (CblasRowMajor,
@ -178,16 +210,21 @@ void ProdBatched(marian::Tensor C,
group_count,
&group_size[0]);
#else
functional::Array<int, functional::Shape::size()> dims;
for(size_t i = 0; i < batchC; ++i) {
cShapeMetaF.dims(i, dims);
auto aIndex = aShapeMetaF.bindex(dims);
auto bIndex = bShapeMetaF.bindex(dims);
sgemm(transA,
transB,
(int)m,
(int)n,
(int)k,
alpha,
A->data() + (i % batchA) * strideA,
A->data() + aIndex * strideA,
(int)lda,
B->data() + (i % batchB) * strideB,
B->data() + bIndex * strideB,
(int)ldb,
beta,
C->data() + i * strideC,

View File

@ -347,25 +347,46 @@ void ProdBatchedTyped(marian::Tensor C,
CUDA_CHECK(cudaSetDevice((int)C->getDeviceId().no));
ComputeType alpha = scalar;
int batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
int batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
// determine meta-shape of bdot operation. Essentially treat the last two dimensions as single elements
// such that (..., m, k) x (..., k, n) -> (..., m, n) where ... is a broadcastable shape as in element-wise kernels.
int m = A->shape()[-2];
int k = A->shape()[-1];
auto aShape = A->shape();
auto bShape = B->shape();
// make sure both shapes have the same number of dimensions via broadcasting
size_t maxLength = std::max(aShape.size(), bShape.size());
if(aShape.size() != bShape.size()) {
Shape ones(std::vector<int>(maxLength, 1));
aShape = Shape::broadcast({aShape, ones});
bShape = Shape::broadcast({bShape, ones});
}
// Create meta-shapes without last 2 dimensions
Shape aShapeMeta, bShapeMeta, cShapeMeta;
aShapeMeta.resize(maxLength - 2);
bShapeMeta.resize(maxLength - 2);
for(size_t i = 0; i < maxLength - 2; ++i) {
aShapeMeta.set(i, aShape[i]);
bShapeMeta.set(i, bShape[i]);
}
cShapeMeta = Shape::broadcast({aShapeMeta, bShapeMeta});
size_t m = aShape[-2];
size_t k = aShape[-1];
if(transA)
std::swap(m, k);
int l = B->shape()[-2];
int n = B->shape()[-1];
size_t l = bShape[-2];
size_t n = bShape[-1];
if(transB)
std::swap(l, n);
int lda = A->shape()[-1];
int ldb = B->shape()[-1];
int ldc = B->shape()[-1];
size_t lda = aShape[-1];
size_t ldb = bShape[-1];
size_t ldc = bShape[-1];
if(transB)
ldc = B->shape()[-2];
ldc = bShape[-2];
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
@ -374,18 +395,29 @@ void ProdBatchedTyped(marian::Tensor C,
auto cublasHandle = backend->getCublasHandle();
auto compute = backend->getCudaComputeCapability();
auto strideA = batchA == 1 ? 0 : m * k;
auto strideB = batchB == 1 ? 0 : n * k;
auto strideA = m * k;
auto strideB = n * k;
auto strideC = n * m;
auto batchC = std::max(batchA, batchB);
auto batchC = cShapeMeta.elements();
// Convert to functional shapes to be able to map dimensions. @TODO merge this
functional::Shape aShapeMetaF = aShapeMeta;
functional::Shape bShapeMetaF = bShapeMeta;
functional::Shape cShapeMetaF = cShapeMeta;
std::vector<const ElementType*> aptr;
std::vector<const ElementType*> bptr;
std::vector<ElementType*> cptr;
functional::Array<int, functional::Shape::size()> dims;
for(int i = 0; i < batchC; i++) {
aptr.push_back(A->data<ElementType>() + (i % batchA) * strideA);
bptr.push_back(B->data<ElementType>() + (i % batchB) * strideB);
cShapeMetaF.dims(i, dims);
auto aIndex = aShapeMetaF.bindex(dims);
auto bIndex = bShapeMetaF.bindex(dims);
aptr.push_back(A->data<ElementType>() + aIndex * strideA);
bptr.push_back(B->data<ElementType>() + bIndex * strideB);
cptr.push_back(C->data<ElementType>() + i * strideC);
}

View File

@ -615,6 +615,66 @@ void tests(DeviceType device, Type floatType = Type::float32) {
CHECK(values2 == values);
}
SECTION("bdot") {
graph->clear();
values.clear();
std::vector<T> vA({ 1, 2,
3, 4,
5, 6,
7, 8});
std::vector<T> vB({ 1, 2,
3, 4,
5, 6,
7, 8,
9, 10,
11, 12});
std::vector<T> vC({ 7, 10,
15, 22,
19, 22,
43, 50,
31, 34,
71, 78,
23, 34,
31, 46,
67, 78,
91, 106,
111, 122,
151, 166});
std::vector<T> vCt({ 5, 11,
11, 25,
17, 23,
39, 53,
29, 35,
67, 81,
17, 39,
23, 53,
61, 83,
83, 113,
105, 127,
143, 173});
auto A = graph->param("A", {2, 1, 2, 2}, inits::fromVector(vA));
auto B = graph->param("B", {1, 3, 2, 2}, inits::fromVector(vB));
auto C = bdot(A, B, /*transA=*/false, /*transB=*/false);
auto Ct = bdot(A, B, /*transA=*/false, /*transB=*/true);
graph->forward();
CHECK(C->shape() == Shape({2, 3, 2, 2}));
CHECK(Ct->shape() == Shape({2, 3, 2, 2}));
C->val()->get(values);
CHECK(vC == values);
Ct->val()->get(values);
CHECK(vCt == values);
}
SECTION("repeat") {
graph->clear();
values.clear();