From 55ac6ae8fb11041c945b70bf27247c31e6665b49 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Wed, 6 Dec 2017 18:13:37 +0000 Subject: [PATCH] split out args for functor. Get ready for half2 --- src/amun/half/mblas/matrix_functions.h | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/amun/half/mblas/matrix_functions.h b/src/amun/half/mblas/matrix_functions.h index 06dfc585..8a7a3037 100644 --- a/src/amun/half/mblas/matrix_functions.h +++ b/src/amun/half/mblas/matrix_functions.h @@ -220,8 +220,13 @@ __global__ void gBroadcast(Functor functor, // in2Wrap[beamIdx * cols + stateIdx]); //outWrap[id] = functor(in1Wrap(indices[0], indices[1], 0, batchIdx), // in2Wrap(indices[2], indices[1], 0, 0)); - outWrap(srcId, stateIdx, beamIdx, 0) = functor(in1Wrap(srcId, stateIdx, 0, batchIdx), - in2Wrap(beamIdx, stateIdx, 0, 0)); + //outWrap(srcId, stateIdx, beamIdx, 0) = functor(in1Wrap(srcId, stateIdx, 0, batchIdx), + // in2Wrap(beamIdx, stateIdx, 0, 0)); + const half *in1 = &in1Wrap(srcId, stateIdx, 0, batchIdx); + const half *in2 = &in2Wrap(beamIdx, stateIdx, 0, 0); + half *out = &outWrap(srcId, stateIdx, beamIdx, 0); + *out = functor(*in1, *in2); + } }