mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-27 10:33:14 +03:00
don't use shape for now but leave code in
This commit is contained in:
parent
d88cbf7f82
commit
1549edefa8
@ -128,12 +128,12 @@ __global__ void gBroadcast(Functor functor,
|
||||
MatrixWrapper<float> out,
|
||||
const MatrixWrapper<float> in1,
|
||||
const MatrixWrapper<float> in2,
|
||||
const VectorWrapper<unsigned> hypo2Batch,
|
||||
const Shape shape)
|
||||
const VectorWrapper<unsigned> hypo2Batch)
|
||||
{
|
||||
const Shape &shape = out.GetShape();
|
||||
unsigned id = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (id < shape.size()) {
|
||||
///*
|
||||
/*
|
||||
unsigned indices[SHAPE_SIZE];
|
||||
shape.id2Indices(id, indices);
|
||||
|
||||
@ -141,8 +141,8 @@ __global__ void gBroadcast(Functor functor,
|
||||
unsigned stateIdx = indices[1];
|
||||
unsigned beamIdx = indices[2];
|
||||
//assert(0 == indices[3]);
|
||||
//*/
|
||||
/*
|
||||
*/
|
||||
///*
|
||||
unsigned cols = in1.GetShape().dim(1);
|
||||
unsigned srcSize = out.GetShape().dim(0);
|
||||
|
||||
@ -150,19 +150,19 @@ __global__ void gBroadcast(Functor functor,
|
||||
unsigned stateIdx = id % cols;
|
||||
unsigned beamIdx = row / srcSize;
|
||||
unsigned srcId = row % srcSize;
|
||||
*/
|
||||
//*/
|
||||
unsigned batchIdx = hypo2Batch[ beamIdx ];
|
||||
|
||||
assert(srcId < out.GetShape().dim(0));
|
||||
assert(srcId < in1.GetShape().dim(0));
|
||||
assert(beamIdx < in2.GetShape().dim(0));
|
||||
assert(batchIdx < in1.GetShape().dim(3));
|
||||
//out[id] = functor(in1[(batchIdx * srcSize + srcId) * cols + stateIdx],
|
||||
// in2[beamIdx * cols + stateIdx]);
|
||||
out[id] = functor(in1[(batchIdx * srcSize + srcId) * cols + stateIdx],
|
||||
in2[beamIdx * cols + stateIdx]);
|
||||
//out[id] = functor(in1(indices[0], indices[1], 0, batchIdx),
|
||||
// in2(indices[2], indices[1], 0, 0));
|
||||
out(srcId, stateIdx, beamIdx) = functor(in1(srcId, stateIdx, 0, batchIdx),
|
||||
in2(beamIdx, stateIdx));
|
||||
//out(srcId, stateIdx, beamIdx) = functor(in1(srcId, stateIdx, 0, batchIdx),
|
||||
// in2(beamIdx, stateIdx));
|
||||
}
|
||||
}
|
||||
|
||||
@ -188,12 +188,13 @@ Matrix& Broadcast(Functor functor,
|
||||
std::cerr << "in1=" << in1.Debug(0) << std::endl;
|
||||
std::cerr << "in2=" << in2.Debug(0) << std::endl;
|
||||
std::cerr << "hypo2Batch=" << hypo2Batch.Debug(0) << std::endl;
|
||||
*/
|
||||
std::cerr << "srcSize=" << srcSize << " " << activeBatchMaxLength << std::endl;
|
||||
//std::cerr << std::endl;
|
||||
std::cerr << std::endl;
|
||||
|
||||
Shape shape(activeBatchMaxLength, cols, sumOfBeamSizes, 1);
|
||||
unsigned size = shape.size();
|
||||
*/
|
||||
|
||||
unsigned size = out.size();
|
||||
unsigned threads = std::min(MAX_THREADS, size);
|
||||
unsigned blocks = (size / threads) + ((size % threads == 0) ? 0 : 1);
|
||||
/*
|
||||
@ -208,7 +209,7 @@ Matrix& Broadcast(Functor functor,
|
||||
std::cerr << std::endl;
|
||||
*/
|
||||
gBroadcast<<<blocks, threads, 0, CudaStreamHandler::GetStream()>>>
|
||||
(functor, out, in1, in2, hypo2Batch, shape);
|
||||
(functor, out, in1, in2, hypo2Batch);
|
||||
HANDLE_ERROR(cudaGetLastError());
|
||||
|
||||
PAUSE_TIMER("Broadcast");
|
||||
|
Loading…
Reference in New Issue
Block a user