Get ready to shrink the size of the shape in Broadcast

This commit is contained in:
Hieu Hoang 2018-01-20 16:26:17 +00:00
parent 434303d233
commit 04fe617903

View File

@ -128,20 +128,21 @@ __global__ void gBroadcast(Functor functor,
MatrixWrapper<float> out,
const MatrixWrapper<float> in1,
const MatrixWrapper<float> in2,
const VectorWrapper<unsigned> hypo2Batch)
const VectorWrapper<unsigned> hypo2Batch,
const Shape shape)
{
unsigned id = threadIdx.x + blockIdx.x * blockDim.x;
if (id < out.GetShape().size()) {
/*
if (id < shape.size()) {
///*
unsigned indices[SHAPE_SIZE];
out.GetShape().id2Indices(id, indices);
shape.id2Indices(id, indices);
unsigned srcId = indices[0];
unsigned stateIdx = indices[1];
unsigned beamIdx = indices[2];
//assert(0 == indices[3]);
*/
///*
//*/
/*
unsigned cols = in1.GetShape().dim(1);
unsigned srcSize = out.GetShape().dim(0);
@ -149,19 +150,19 @@ __global__ void gBroadcast(Functor functor,
unsigned stateIdx = id % cols;
unsigned beamIdx = row / srcSize;
unsigned srcId = row % srcSize;
//*/
*/
unsigned batchIdx = hypo2Batch[ beamIdx ];
assert(srcId < out.GetShape().dim(0));
assert(srcId < in1.GetShape().dim(0));
assert(beamIdx < in2.GetShape().dim(0));
assert(batchIdx < in1.GetShape().dim(3));
out[id] = functor(in1[(batchIdx * srcSize + srcId) * cols + stateIdx],
in2[beamIdx * cols + stateIdx]);
//out[id] = functor(in1[(batchIdx * srcSize + srcId) * cols + stateIdx],
// in2[beamIdx * cols + stateIdx]);
//out[id] = functor(in1(indices[0], indices[1], 0, batchIdx),
// in2(indices[2], indices[1], 0, 0));
//out(srcId, stateIdx, beamIdx) = functor(in1(srcId, stateIdx, 0, batchIdx),
// in2(beamIdx, stateIdx));
out(srcId, stateIdx, beamIdx) = functor(in1(srcId, stateIdx, 0, batchIdx),
in2(beamIdx, stateIdx));
}
}
@ -191,7 +192,8 @@ Matrix& Broadcast(Functor functor,
std::cerr << "srcSize=" << srcSize << " " << activeBatchMaxLength << std::endl;
//std::cerr << std::endl;
unsigned size = out.size();
Shape shape(srcSize, cols, sumOfBeamSizes, 1);
unsigned size = shape.size();
unsigned threads = std::min(MAX_THREADS, size);
unsigned blocks = (size / threads) + ((size % threads == 0) ? 0 : 1);
/*
@ -206,7 +208,7 @@ Matrix& Broadcast(Functor functor,
std::cerr << std::endl;
*/
gBroadcast<<<blocks, threads, 0, CudaStreamHandler::GetStream()>>>
(functor, out, in1, in2, hypo2Batch);
(functor, out, in1, in2, hypo2Batch, shape);
HANDLE_ERROR(cudaGetLastError());
PAUSE_TIMER("Broadcast");