From 77c0cac1f21f255e7f78bbf7b0d2b138afb3743a Mon Sep 17 00:00:00 2001
From: Marcin Junczys-Dowmunt <marcinjd@microsoft.com>
Date: Mon, 7 Jun 2021 09:14:39 -0700
Subject: [PATCH] broadcasting bdot

---
 src/graph/node_operators_binary.h  | 24 ++++++++--
 src/tensors/cpu/prod.cpp           | 71 +++++++++++++++++++++++-------
 src/tensors/gpu/prod.cpp           | 62 +++++++++++++++++++-------
 src/tests/units/operator_tests.cpp | 60 +++++++++++++++++++++++++
 4 files changed, 181 insertions(+), 36 deletions(-)

diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 91fc29da..bd52103a 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -529,11 +529,27 @@ public:
       shapeB.set(-1, b->shape()[-2]);
     }
 
-    Shape outShape = shapeA;
-    outShape.set(-1, shapeB[-1]);
     ABORT_IF(shapeA[-1] != shapeB[-2],
-             "Batched matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
-    return outShape;
+             "Batched matrix product requires inner dimensions to match in {}{} * {}{}",
+             std::string(shapeA), transA, std::string(shapeB), transB);
+
+    // create shapes for batch dimensions only
+    auto shapeBatchA = shapeA;
+    shapeBatchA.set(-1, 1);
+    shapeBatchA.set(-2, 1);
+
+    auto shapeBatchB = shapeB;
+    shapeBatchB.set(-1, 1);
+    shapeBatchB.set(-2, 1);
+
+    // broadcast batch dimensions
+    auto shapeOut = Shape::broadcast({shapeBatchA, shapeBatchB});
+
+    // set non-batch dimensions in output: rows come from (possibly transposed) A,
+    // columns from (possibly transposed) B, i.e. (m, k) x (k, n) -> (m, n)
+    shapeOut.set(-2, shapeA[-2]);
+    shapeOut.set(-1, shapeB[-1]);
+
+    return shapeOut;
   }
 
   NodeOps forwardOps() override {
diff --git a/src/tensors/cpu/prod.cpp b/src/tensors/cpu/prod.cpp
index 6e28bdd2..066867e4 100755
--- a/src/tensors/cpu/prod.cpp
+++ b/src/tensors/cpu/prod.cpp
@@ -93,31 +93,58 @@ void ProdBatched(marian::Tensor C,
 #if BLAS_FOUND
   float alpha = scalar;
 
-  size_t batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
-  size_t batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
+  // determine meta-shape of bdot operation. Essentially treat the last two dimensions as single elements
+  // such that (..., m, k) x (..., k, n) -> (..., m, n) where ... is a broadcastable shape as in element-wise kernels.
 
-  size_t m = A->shape()[-2];
-  size_t k = A->shape()[-1];
+  auto aShape = A->shape();
+  auto bShape = B->shape();
+
+  // make sure both shape have the same number of dimensions via broadcasting
+  size_t maxLength = std::max(aShape.size(), bShape.size());
+  if(aShape.size() != bShape.size()) {
+    Shape ones(std::vector<int>(maxLength, 1));
+    aShape = Shape::broadcast({aShape, ones});
+    bShape = Shape::broadcast({bShape, ones});
+  }
+
+  // Create meta-shapes without last 2 dimensions
+  Shape aShapeMeta, bShapeMeta, cShapeMeta;
+  aShapeMeta.resize(maxLength - 2);
+  bShapeMeta.resize(maxLength - 2);
+  for(size_t i = 0; i < maxLength - 2; ++i) {
+    aShapeMeta.set(i, aShape[i]);
+    bShapeMeta.set(i, bShape[i]);
+  }
+  cShapeMeta = Shape::broadcast({aShapeMeta, bShapeMeta});
+
+  size_t m = aShape[-2];
+  size_t k = aShape[-1];
   if(transA)
     std::swap(m, k);
 
-  size_t l = B->shape()[-2];
-  size_t n = B->shape()[-1];
+  size_t l = bShape[-2];
+  size_t n = bShape[-1];
   if(transB)
     std::swap(l, n);
 
-  size_t lda = A->shape()[-1];
-  size_t ldb = B->shape()[-1];
-  size_t ldc = B->shape()[-1];
+  size_t lda = aShape[-1];
+  size_t ldb = bShape[-1];
+  size_t ldc = bShape[-1];
 
   if(transB)
-    ldc = B->shape()[-2];
+    ldc = bShape[-2];
 
-  auto strideB = batchB == 1 ? 0 : n * k;
-  auto strideA = batchA == 1 ? 0 : m * k;
+  auto strideA = m * k;
+  auto strideB = n * k;
   auto strideC = n * m;
 
-  auto batchC = std::max(batchA, batchB);
+  auto batchC = cShapeMeta.elements();
+
+  // Convert to functional shapes to be able to map dimensions. @TODO merge this
+  functional::Shape aShapeMetaF = aShapeMeta;
+  functional::Shape bShapeMetaF = bShapeMeta;
+  functional::Shape cShapeMetaF = cShapeMeta;
+
 #if MKL_FOUND
   CBLAS_TRANSPOSE transA_forarr = CblasNoTrans;
   CBLAS_TRANSPOSE transB_forarr = CblasNoTrans;
@@ -156,9 +183,14 @@
 
   // This loop initializes the array pointers in the same way as the for loop
   // in the normal sgemm version a few lines below
+  functional::Array<int, functional::Shape::size()> dims;
   for(size_t i = 0; i < batchC; ++i) {
-    a_array[i] = A->data() + (i % batchA) * strideA;
-    b_array[i] = B->data() + (i % batchB) * strideB;
+    cShapeMetaF.dims(i, dims);
+    auto aIndex = aShapeMetaF.bindex(dims);
+    auto bIndex = bShapeMetaF.bindex(dims);
+
+    a_array[i] = A->data() + aIndex * strideA;
+    b_array[i] = B->data() + bIndex * strideB;
     c_array[i] = C->data() + i * strideC;
   }
   cblas_sgemm_batch (CblasRowMajor,
@@ -178,16 +210,21 @@
                      group_count,
                      &group_size[0]);
 #else
+  functional::Array<int, functional::Shape::size()> dims;
   for(size_t i = 0; i < batchC; ++i) {
+    cShapeMetaF.dims(i, dims);
+    auto aIndex = aShapeMetaF.bindex(dims);
+    auto bIndex = bShapeMetaF.bindex(dims);
+
     sgemm(transA,
           transB,
           (int)m,
           (int)n,
           (int)k,
           alpha,
-          A->data() + (i % batchA) * strideA,
+          A->data() + aIndex * strideA,
          (int)lda,
-          B->data() + (i % batchB) * strideB,
+          B->data() + bIndex * strideB,
           (int)ldb,
           beta,
           C->data() + i * strideC,
diff --git a/src/tensors/gpu/prod.cpp b/src/tensors/gpu/prod.cpp
index 4b49c704..3e35237f 100755
--- a/src/tensors/gpu/prod.cpp
+++ b/src/tensors/gpu/prod.cpp
@@ -347,25 +347,46 @@ void ProdBatchedTyped(marian::Tensor C,
   CUDA_CHECK(cudaSetDevice((int)C->getDeviceId().no));
   ComputeType alpha = scalar;
 
-  int batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
-  int batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
+  // determine meta-shape of bdot operation. Essentially treat the last two dimensions as single elements
+  // such that (..., m, k) x (..., k, n) -> (..., m, n) where ... is a broadcastable shape as in element-wise kernels.
 
-  int m = A->shape()[-2];
-  int k = A->shape()[-1];
+  auto aShape = A->shape();
+  auto bShape = B->shape();
+
+  // make sure both shape have the same number of dimensions via broadcasting
+  size_t maxLength = std::max(aShape.size(), bShape.size());
+  if(aShape.size() != bShape.size()) {
+    Shape ones(std::vector<int>(maxLength, 1));
+    aShape = Shape::broadcast({aShape, ones});
+    bShape = Shape::broadcast({bShape, ones});
+  }
+
+  // Create meta-shapes without last 2 dimensions
+  Shape aShapeMeta, bShapeMeta, cShapeMeta;
+  aShapeMeta.resize(maxLength - 2);
+  bShapeMeta.resize(maxLength - 2);
+  for(size_t i = 0; i < maxLength - 2; ++i) {
+    aShapeMeta.set(i, aShape[i]);
+    bShapeMeta.set(i, bShape[i]);
+  }
+  cShapeMeta = Shape::broadcast({aShapeMeta, bShapeMeta});
+
+  size_t m = aShape[-2];
+  size_t k = aShape[-1];
   if(transA)
     std::swap(m, k);
 
-  int l = B->shape()[-2];
-  int n = B->shape()[-1];
+  size_t l = bShape[-2];
+  size_t n = bShape[-1];
   if(transB)
     std::swap(l, n);
 
-  int lda = A->shape()[-1];
-  int ldb = B->shape()[-1];
-  int ldc = B->shape()[-1];
+  size_t lda = aShape[-1];
+  size_t ldb = bShape[-1];
+  size_t ldc = bShape[-1];
 
   if(transB)
-    ldc = B->shape()[-2];
+    ldc = bShape[-2];
 
   cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
   cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
@@ -374,18 +395,29 @@
   auto cublasHandle = backend->getCublasHandle();
   auto compute = backend->getCudaComputeCapability();
 
-  auto strideA = batchA == 1 ? 0 : m * k;
-  auto strideB = batchB == 1 ? 0 : n * k;
+  auto strideA = m * k;
+  auto strideB = n * k;
   auto strideC = n * m;
-  auto batchC = std::max(batchA, batchB);
+
+  auto batchC = cShapeMeta.elements();
+
+  // Convert to functional shapes to be able to map dimensions. @TODO merge this
+  functional::Shape aShapeMetaF = aShapeMeta;
+  functional::Shape bShapeMetaF = bShapeMeta;
+  functional::Shape cShapeMetaF = cShapeMeta;
 
   std::vector<const ComputeType*> aptr;
   std::vector<const ComputeType*> bptr;
   std::vector<ComputeType*> cptr;
+  functional::Array<int, functional::Shape::size()> dims;
   for(int i = 0; i < batchC; i++) {
-    aptr.push_back(A->data() + (i % batchA) * strideA);
-    bptr.push_back(B->data() + (i % batchB) * strideB);
+    cShapeMetaF.dims(i, dims);
+    auto aIndex = aShapeMetaF.bindex(dims);
+    auto bIndex = bShapeMetaF.bindex(dims);
+
+    aptr.push_back(A->data() + aIndex * strideA);
+    bptr.push_back(B->data() + bIndex * strideB);
     cptr.push_back(C->data() + i * strideC);
   }
diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp
index 1a18da99..f3b5fda3 100644
--- a/src/tests/units/operator_tests.cpp
+++ b/src/tests/units/operator_tests.cpp
@@ -615,6 +615,66 @@ void tests(DeviceType device, Type floatType = Type::float32) {
     CHECK(values2 == values);
   }
 
+  SECTION("bdot") {
+    graph->clear();
+    values.clear();
+
+    std::vector<float> vA({ 1, 2,
+                            3, 4,
+                            5, 6,
+                            7, 8});
+
+    std::vector<float> vB({ 1,  2,
+                            3,  4,
+                            5,  6,
+                            7,  8,
+                            9, 10,
+                           11, 12});
+
+    std::vector<float> vC({  7,  10,
+                            15,  22,
+                            19,  22,
+                            43,  50,
+                            31,  34,
+                            71,  78,
+                            23,  34,
+                            31,  46,
+                            67,  78,
+                            91, 106,
+                           111, 122,
+                           151, 166});
+
+    std::vector<float> vCt({  5,  11,
+                             11,  25,
+                             17,  23,
+                             39,  53,
+                             29,  35,
+                             67,  81,
+                             17,  39,
+                             23,  53,
+                             61,  83,
+                             83, 113,
+                            105, 127,
+                            143, 173});
+
+    auto A = graph->param("A", {2, 1, 2, 2}, inits::fromVector(vA));
+    auto B = graph->param("B", {1, 3, 2, 2}, inits::fromVector(vB));
+
+    auto C  = bdot(A, B, /*transA=*/false, /*transB=*/false);
+    auto Ct = bdot(A, B, /*transA=*/false, /*transB=*/true);
+
+    graph->forward();
+
+    CHECK(C->shape() == Shape({2, 3, 2, 2}));
+    CHECK(Ct->shape() == Shape({2, 3, 2, 2}));
+
+    C->val()->get(values);
+    CHECK(vC == values);
+
+    Ct->val()->get(values);
+    CHECK(vCt == values);
+  }
+
   SECTION("repeat") {
     graph->clear();
     values.clear();