mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 12:02:16 +03:00
get ready to shrink size of shape in Broadcast
This commit is contained in:
parent
434303d233
commit
04fe617903
@ -128,20 +128,21 @@ __global__ void gBroadcast(Functor functor,
|
||||
MatrixWrapper<float> out,
|
||||
const MatrixWrapper<float> in1,
|
||||
const MatrixWrapper<float> in2,
|
||||
const VectorWrapper<unsigned> hypo2Batch)
|
||||
const VectorWrapper<unsigned> hypo2Batch,
|
||||
const Shape shape)
|
||||
{
|
||||
unsigned id = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (id < out.GetShape().size()) {
|
||||
/*
|
||||
if (id < shape.size()) {
|
||||
///*
|
||||
unsigned indices[SHAPE_SIZE];
|
||||
out.GetShape().id2Indices(id, indices);
|
||||
shape.id2Indices(id, indices);
|
||||
|
||||
unsigned srcId = indices[0];
|
||||
unsigned stateIdx = indices[1];
|
||||
unsigned beamIdx = indices[2];
|
||||
//assert(0 == indices[3]);
|
||||
*/
|
||||
///*
|
||||
//*/
|
||||
/*
|
||||
unsigned cols = in1.GetShape().dim(1);
|
||||
unsigned srcSize = out.GetShape().dim(0);
|
||||
|
||||
@ -149,19 +150,19 @@ __global__ void gBroadcast(Functor functor,
|
||||
unsigned stateIdx = id % cols;
|
||||
unsigned beamIdx = row / srcSize;
|
||||
unsigned srcId = row % srcSize;
|
||||
//*/
|
||||
*/
|
||||
unsigned batchIdx = hypo2Batch[ beamIdx ];
|
||||
|
||||
assert(srcId < out.GetShape().dim(0));
|
||||
assert(srcId < in1.GetShape().dim(0));
|
||||
assert(beamIdx < in2.GetShape().dim(0));
|
||||
assert(batchIdx < in1.GetShape().dim(3));
|
||||
out[id] = functor(in1[(batchIdx * srcSize + srcId) * cols + stateIdx],
|
||||
in2[beamIdx * cols + stateIdx]);
|
||||
//out[id] = functor(in1[(batchIdx * srcSize + srcId) * cols + stateIdx],
|
||||
// in2[beamIdx * cols + stateIdx]);
|
||||
//out[id] = functor(in1(indices[0], indices[1], 0, batchIdx),
|
||||
// in2(indices[2], indices[1], 0, 0));
|
||||
//out(srcId, stateIdx, beamIdx) = functor(in1(srcId, stateIdx, 0, batchIdx),
|
||||
// in2(beamIdx, stateIdx));
|
||||
out(srcId, stateIdx, beamIdx) = functor(in1(srcId, stateIdx, 0, batchIdx),
|
||||
in2(beamIdx, stateIdx));
|
||||
}
|
||||
}
|
||||
|
||||
@ -191,7 +192,8 @@ Matrix& Broadcast(Functor functor,
|
||||
std::cerr << "srcSize=" << srcSize << " " << activeBatchMaxLength << std::endl;
|
||||
//std::cerr << std::endl;
|
||||
|
||||
unsigned size = out.size();
|
||||
Shape shape(srcSize, cols, sumOfBeamSizes, 1);
|
||||
unsigned size = shape.size();
|
||||
unsigned threads = std::min(MAX_THREADS, size);
|
||||
unsigned blocks = (size / threads) + ((size % threads == 0) ? 0 : 1);
|
||||
/*
|
||||
@ -206,7 +208,7 @@ Matrix& Broadcast(Functor functor,
|
||||
std::cerr << std::endl;
|
||||
*/
|
||||
gBroadcast<<<blocks, threads, 0, CudaStreamHandler::GetStream()>>>
|
||||
(functor, out, in1, in2, hypo2Batch);
|
||||
(functor, out, in1, in2, hypo2Batch, shape);
|
||||
HANDLE_ERROR(cudaGetLastError());
|
||||
|
||||
PAUSE_TIMER("Broadcast");
|
||||
|
Loading…
Reference in New Issue
Block a user