mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
broadcasting bdot
This commit is contained in:
parent
2c1b16f43e
commit
77c0cac1f2
@ -529,11 +529,27 @@ public:
|
||||
shapeB.set(-1, b->shape()[-2]);
|
||||
}
|
||||
|
||||
Shape outShape = shapeA;
|
||||
outShape.set(-1, shapeB[-1]);
|
||||
ABORT_IF(shapeA[-1] != shapeB[-2],
|
||||
"Batched matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
|
||||
return outShape;
|
||||
"Batched matrix product requires inner dimensions to match in {}{} * {}{}",
|
||||
std::string(shapeA), transA, std::string(shapeB), transB);
|
||||
|
||||
// create shapes for batch dimensions only
|
||||
auto shapeBatchA = shapeA;
|
||||
shapeBatchA.set(-1, 1);
|
||||
shapeBatchA.set(-2, 1);
|
||||
|
||||
auto shapeBatchB = shapeB;
|
||||
shapeBatchB.set(-1, 1);
|
||||
shapeBatchB.set(-2, 1);
|
||||
|
||||
// broadcast batch dimensions
|
||||
auto shapeOut = Shape::broadcast({shapeBatchA, shapeBatchB});
|
||||
|
||||
// set non-batch dimensions in output
|
||||
shapeOut.set(-1, shapeA[-2]);
|
||||
shapeOut.set(-2, shapeB[-1]);
|
||||
|
||||
return shapeOut;
|
||||
}
|
||||
|
||||
NodeOps forwardOps() override {
|
||||
|
@ -93,31 +93,58 @@ void ProdBatched(marian::Tensor C,
|
||||
#if BLAS_FOUND
|
||||
float alpha = scalar;
|
||||
|
||||
size_t batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
|
||||
size_t batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
|
||||
// determine meta-shape of bdot operation. Essentially treat the last two dimensions as single elements
|
||||
// such that (..., m, k) x (..., k, n) -> (..., m, n) where ... is a broadcastable shape as in element-wise kernels.
|
||||
|
||||
size_t m = A->shape()[-2];
|
||||
size_t k = A->shape()[-1];
|
||||
auto aShape = A->shape();
|
||||
auto bShape = B->shape();
|
||||
|
||||
// make sure both shape have the same number of dimensions via broadcasting
|
||||
size_t maxLength = std::max(aShape.size(), bShape.size());
|
||||
if(aShape.size() != bShape.size()) {
|
||||
Shape ones(std::vector<int>(maxLength, 1));
|
||||
aShape = Shape::broadcast({aShape, ones});
|
||||
bShape = Shape::broadcast({bShape, ones});
|
||||
}
|
||||
|
||||
// Create meta-shapes without last 2 dimensions
|
||||
Shape aShapeMeta, bShapeMeta, cShapeMeta;
|
||||
aShapeMeta.resize(maxLength - 2);
|
||||
bShapeMeta.resize(maxLength - 2);
|
||||
for(size_t i = 0; i < maxLength - 2; ++i) {
|
||||
aShapeMeta.set(i, aShape[i]);
|
||||
bShapeMeta.set(i, bShape[i]);
|
||||
}
|
||||
cShapeMeta = Shape::broadcast({aShapeMeta, bShapeMeta});
|
||||
|
||||
size_t m = aShape[-2];
|
||||
size_t k = aShape[-1];
|
||||
if(transA)
|
||||
std::swap(m, k);
|
||||
|
||||
size_t l = B->shape()[-2];
|
||||
size_t n = B->shape()[-1];
|
||||
size_t l = bShape[-2];
|
||||
size_t n = bShape[-1];
|
||||
if(transB)
|
||||
std::swap(l, n);
|
||||
|
||||
size_t lda = A->shape()[-1];
|
||||
size_t ldb = B->shape()[-1];
|
||||
size_t ldc = B->shape()[-1];
|
||||
size_t lda = aShape[-1];
|
||||
size_t ldb = bShape[-1];
|
||||
size_t ldc = bShape[-1];
|
||||
|
||||
if(transB)
|
||||
ldc = B->shape()[-2];
|
||||
ldc = bShape[-2];
|
||||
|
||||
auto strideB = batchB == 1 ? 0 : n * k;
|
||||
auto strideA = batchA == 1 ? 0 : m * k;
|
||||
auto strideA = m * k;
|
||||
auto strideB = n * k;
|
||||
auto strideC = n * m;
|
||||
|
||||
auto batchC = std::max(batchA, batchB);
|
||||
auto batchC = cShapeMeta.elements();
|
||||
|
||||
// Convert to functional shapes to be able to map dimensions. @TODO merge this
|
||||
functional::Shape aShapeMetaF = aShapeMeta;
|
||||
functional::Shape bShapeMetaF = bShapeMeta;
|
||||
functional::Shape cShapeMetaF = cShapeMeta;
|
||||
|
||||
#if MKL_FOUND
|
||||
CBLAS_TRANSPOSE transA_forarr = CblasNoTrans;
|
||||
CBLAS_TRANSPOSE transB_forarr = CblasNoTrans;
|
||||
@ -156,9 +183,14 @@ void ProdBatched(marian::Tensor C,
|
||||
|
||||
// This loop initializes the array pointers in the same way as the for loop
|
||||
// in the normal sgemm version a few lines below
|
||||
functional::Array<int, functional::Shape::size()> dims;
|
||||
for(size_t i = 0; i < batchC; ++i) {
|
||||
a_array[i] = A->data() + (i % batchA) * strideA;
|
||||
b_array[i] = B->data() + (i % batchB) * strideB;
|
||||
cShapeMetaF.dims(i, dims);
|
||||
auto aIndex = aShapeMetaF.bindex(dims);
|
||||
auto bIndex = bShapeMetaF.bindex(dims);
|
||||
|
||||
a_array[i] = A->data() + aIndex * strideA;
|
||||
b_array[i] = B->data() + bIndex * strideB;
|
||||
c_array[i] = C->data() + i * strideC;
|
||||
}
|
||||
cblas_sgemm_batch (CblasRowMajor,
|
||||
@ -178,16 +210,21 @@ void ProdBatched(marian::Tensor C,
|
||||
group_count,
|
||||
&group_size[0]);
|
||||
#else
|
||||
functional::Array<int, functional::Shape::size()> dims;
|
||||
for(size_t i = 0; i < batchC; ++i) {
|
||||
cShapeMetaF.dims(i, dims);
|
||||
auto aIndex = aShapeMetaF.bindex(dims);
|
||||
auto bIndex = bShapeMetaF.bindex(dims);
|
||||
|
||||
sgemm(transA,
|
||||
transB,
|
||||
(int)m,
|
||||
(int)n,
|
||||
(int)k,
|
||||
alpha,
|
||||
A->data() + (i % batchA) * strideA,
|
||||
A->data() + aIndex * strideA,
|
||||
(int)lda,
|
||||
B->data() + (i % batchB) * strideB,
|
||||
B->data() + bIndex * strideB,
|
||||
(int)ldb,
|
||||
beta,
|
||||
C->data() + i * strideC,
|
||||
|
@ -347,25 +347,46 @@ void ProdBatchedTyped(marian::Tensor C,
|
||||
CUDA_CHECK(cudaSetDevice((int)C->getDeviceId().no));
|
||||
ComputeType alpha = scalar;
|
||||
|
||||
int batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
|
||||
int batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
|
||||
// determine meta-shape of bdot operation. Essentially treat the last two dimensions as single elements
|
||||
// such that (..., m, k) x (..., k, n) -> (..., m, n) where ... is a broadcastable shape as in element-wise kernels.
|
||||
|
||||
int m = A->shape()[-2];
|
||||
int k = A->shape()[-1];
|
||||
auto aShape = A->shape();
|
||||
auto bShape = B->shape();
|
||||
|
||||
// make sure both shape have the same number of dimensions via broadcasting
|
||||
size_t maxLength = std::max(aShape.size(), bShape.size());
|
||||
if(aShape.size() != bShape.size()) {
|
||||
Shape ones(std::vector<int>(maxLength, 1));
|
||||
aShape = Shape::broadcast({aShape, ones});
|
||||
bShape = Shape::broadcast({bShape, ones});
|
||||
}
|
||||
|
||||
// Create meta-shapes without last 2 dimensions
|
||||
Shape aShapeMeta, bShapeMeta, cShapeMeta;
|
||||
aShapeMeta.resize(maxLength - 2);
|
||||
bShapeMeta.resize(maxLength - 2);
|
||||
for(size_t i = 0; i < maxLength - 2; ++i) {
|
||||
aShapeMeta.set(i, aShape[i]);
|
||||
bShapeMeta.set(i, bShape[i]);
|
||||
}
|
||||
cShapeMeta = Shape::broadcast({aShapeMeta, bShapeMeta});
|
||||
|
||||
size_t m = aShape[-2];
|
||||
size_t k = aShape[-1];
|
||||
if(transA)
|
||||
std::swap(m, k);
|
||||
|
||||
int l = B->shape()[-2];
|
||||
int n = B->shape()[-1];
|
||||
size_t l = bShape[-2];
|
||||
size_t n = bShape[-1];
|
||||
if(transB)
|
||||
std::swap(l, n);
|
||||
|
||||
int lda = A->shape()[-1];
|
||||
int ldb = B->shape()[-1];
|
||||
int ldc = B->shape()[-1];
|
||||
size_t lda = aShape[-1];
|
||||
size_t ldb = bShape[-1];
|
||||
size_t ldc = bShape[-1];
|
||||
|
||||
if(transB)
|
||||
ldc = B->shape()[-2];
|
||||
ldc = bShape[-2];
|
||||
|
||||
cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
@ -374,18 +395,29 @@ void ProdBatchedTyped(marian::Tensor C,
|
||||
auto cublasHandle = backend->getCublasHandle();
|
||||
auto compute = backend->getCudaComputeCapability();
|
||||
|
||||
auto strideA = batchA == 1 ? 0 : m * k;
|
||||
auto strideB = batchB == 1 ? 0 : n * k;
|
||||
auto strideA = m * k;
|
||||
auto strideB = n * k;
|
||||
auto strideC = n * m;
|
||||
auto batchC = std::max(batchA, batchB);
|
||||
|
||||
auto batchC = cShapeMeta.elements();
|
||||
|
||||
// Convert to functional shapes to be able to map dimensions. @TODO merge this
|
||||
functional::Shape aShapeMetaF = aShapeMeta;
|
||||
functional::Shape bShapeMetaF = bShapeMeta;
|
||||
functional::Shape cShapeMetaF = cShapeMeta;
|
||||
|
||||
std::vector<const ElementType*> aptr;
|
||||
std::vector<const ElementType*> bptr;
|
||||
std::vector<ElementType*> cptr;
|
||||
|
||||
functional::Array<int, functional::Shape::size()> dims;
|
||||
for(int i = 0; i < batchC; i++) {
|
||||
aptr.push_back(A->data<ElementType>() + (i % batchA) * strideA);
|
||||
bptr.push_back(B->data<ElementType>() + (i % batchB) * strideB);
|
||||
cShapeMetaF.dims(i, dims);
|
||||
auto aIndex = aShapeMetaF.bindex(dims);
|
||||
auto bIndex = bShapeMetaF.bindex(dims);
|
||||
|
||||
aptr.push_back(A->data<ElementType>() + aIndex * strideA);
|
||||
bptr.push_back(B->data<ElementType>() + bIndex * strideB);
|
||||
cptr.push_back(C->data<ElementType>() + i * strideC);
|
||||
}
|
||||
|
||||
|
@ -615,6 +615,66 @@ void tests(DeviceType device, Type floatType = Type::float32) {
|
||||
CHECK(values2 == values);
|
||||
}
|
||||
|
||||
SECTION("bdot") {
|
||||
graph->clear();
|
||||
values.clear();
|
||||
|
||||
std::vector<T> vA({ 1, 2,
|
||||
3, 4,
|
||||
5, 6,
|
||||
7, 8});
|
||||
|
||||
std::vector<T> vB({ 1, 2,
|
||||
3, 4,
|
||||
5, 6,
|
||||
7, 8,
|
||||
9, 10,
|
||||
11, 12});
|
||||
|
||||
std::vector<T> vC({ 7, 10,
|
||||
15, 22,
|
||||
19, 22,
|
||||
43, 50,
|
||||
31, 34,
|
||||
71, 78,
|
||||
23, 34,
|
||||
31, 46,
|
||||
67, 78,
|
||||
91, 106,
|
||||
111, 122,
|
||||
151, 166});
|
||||
|
||||
std::vector<T> vCt({ 5, 11,
|
||||
11, 25,
|
||||
17, 23,
|
||||
39, 53,
|
||||
29, 35,
|
||||
67, 81,
|
||||
17, 39,
|
||||
23, 53,
|
||||
61, 83,
|
||||
83, 113,
|
||||
105, 127,
|
||||
143, 173});
|
||||
|
||||
auto A = graph->param("A", {2, 1, 2, 2}, inits::fromVector(vA));
|
||||
auto B = graph->param("B", {1, 3, 2, 2}, inits::fromVector(vB));
|
||||
|
||||
auto C = bdot(A, B, /*transA=*/false, /*transB=*/false);
|
||||
auto Ct = bdot(A, B, /*transA=*/false, /*transB=*/true);
|
||||
|
||||
graph->forward();
|
||||
|
||||
CHECK(C->shape() == Shape({2, 3, 2, 2}));
|
||||
CHECK(Ct->shape() == Shape({2, 3, 2, 2}));
|
||||
|
||||
C->val()->get(values);
|
||||
CHECK(vC == values);
|
||||
|
||||
Ct->val()->get(values);
|
||||
CHECK(vCt == values);
|
||||
}
|
||||
|
||||
SECTION("repeat") {
|
||||
graph->clear();
|
||||
values.clear();
|
||||
|
Loading…
Reference in New Issue
Block a user