mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-27 10:33:14 +03:00
move VectorWrapper to its own file
This commit is contained in:
parent
8cf62a5741
commit
c6f66d4dc0
@ -1595,6 +1595,11 @@
|
||||
<type>1</type>
|
||||
<locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/mblas/vector.h</locationURI>
|
||||
</link>
|
||||
<link>
|
||||
<name>src/amun/gpu/mblas/vector_wrapper.h</name>
|
||||
<type>1</type>
|
||||
<locationURI>PARENT-3-PROJECT_LOC/src/amun/gpu/mblas/vector_wrapper.h</locationURI>
|
||||
</link>
|
||||
<link>
|
||||
<name>src/amun/3rd_party/blaze/config/Assertion.h</name>
|
||||
<type>1</type>
|
||||
|
@ -447,7 +447,7 @@ __global__ void gSoftMax(MatrixWrapper<float> out,
|
||||
int origSrcPos = threadIdx.x;
|
||||
|
||||
while (hypoInd < numHypos) {
|
||||
MatrixWrapper<float> _max(_share, shareSize, 1, 1, 1);
|
||||
VectorWrapper<float> _max(_share, shareSize);
|
||||
_max[origSrcPos] = out(hypoInd, origSrcPos, 0, 0);
|
||||
for (int tid = 0; tid < maxLength; tid += blockDim.x) {
|
||||
int srcPos = tid + origSrcPos;
|
||||
@ -550,7 +550,7 @@ __global__ void gLogSoftMax(MatrixWrapper<float> out, uint shareSize)
|
||||
|
||||
while (rowIdx < rows) {
|
||||
//float* _max = _share;
|
||||
MatrixWrapper<float> _max(_share, shareSize, 1, 1, 1);
|
||||
VectorWrapper<float> _max(_share, shareSize);
|
||||
|
||||
_max[threadIdx.x] = out(rowIdx, threadIdx.x, 0, 0);
|
||||
for (int tid = 0; tid < cols; tid += blockDim.x) {
|
||||
@ -1019,7 +1019,7 @@ void NBestAndMax(VectorWrapper<NthOutBatch> nBestCandidatesWrap,
|
||||
{
|
||||
extern __shared__ char _sharePtr[];
|
||||
|
||||
MatrixWrapper<float> maxMatrix((float*)_sharePtr, blockDim.x, 1, 1, 1);
|
||||
VectorWrapper<float> maxVec((float*)_sharePtr, blockDim.x);
|
||||
|
||||
void *ptrOffset = _sharePtr + sizeof(float) * blockDim.x;
|
||||
MatrixWrapper<NthOutBatch> nBestMatrix((NthOutBatch*)ptrOffset, blockDim.x, maxBeamSize, 1, 1);
|
||||
@ -1113,7 +1113,7 @@ void SumAndLogSoftMax(VectorWrapper<NthOutBatch> nBestCandidatesWrap,
|
||||
//assert(nBestCandidatesWrap.dim(0) == rows);
|
||||
|
||||
//float* _sum = _share;// + blockDim.x;
|
||||
MatrixWrapper<float> _sum(_share, blockDim.x, 1, 1, 1);
|
||||
VectorWrapper<float> _sum(_share, blockDim.x);
|
||||
|
||||
// calc sum
|
||||
_sum[threadIdx.x] = 0.0f;
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include "gpu/mblas/matrix_wrapper.h"
|
||||
#include "gpu/mblas/handles.h"
|
||||
#include "gpu/mblas/nth_element_kernels.h"
|
||||
#include "gpu/mblas/vector_wrapper.h"
|
||||
|
||||
namespace amunmt {
|
||||
namespace GPU {
|
||||
@ -236,7 +237,7 @@ __global__ void gBroadcastVecColumn(Functor functor,
|
||||
size_t rows = outWrap.dim(0);
|
||||
size_t cols = outWrap.dim(1);
|
||||
|
||||
MatrixWrapper<float> sdata(sdataOrig, rows, 1, 1, 1);
|
||||
VectorWrapper<float> sdata(sdataOrig, rows);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < rows; ++i)
|
||||
|
@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
#include "matrix.h"
|
||||
#include "gpu/mblas/vector.h"
|
||||
|
||||
namespace amunmt {
|
||||
namespace GPU {
|
||||
@ -247,66 +246,6 @@ inline void testidToMatrixInd()
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
class VectorWrapper
|
||||
{
|
||||
public:
|
||||
VectorWrapper(const Vector<T> &vec)
|
||||
{
|
||||
size_ = vec.size();
|
||||
data_ = nullptr;
|
||||
dataConst_ = vec.data();
|
||||
}
|
||||
|
||||
VectorWrapper(Vector<T> &vec)
|
||||
{
|
||||
size_ = vec.size();
|
||||
data_ = vec.data();
|
||||
dataConst_ = vec.data();
|
||||
}
|
||||
|
||||
__device__ __host__
|
||||
uint size() const
|
||||
{
|
||||
return size_;
|
||||
}
|
||||
|
||||
__device__
|
||||
T* data()
|
||||
{
|
||||
assert(data_);
|
||||
return data_;
|
||||
}
|
||||
|
||||
__device__
|
||||
const T* data() const
|
||||
{
|
||||
assert(dataConst_);
|
||||
return dataConst_;
|
||||
}
|
||||
|
||||
__device__
|
||||
const T &operator[](uint i) const
|
||||
{
|
||||
assert(i < size());
|
||||
return data()[i];
|
||||
}
|
||||
|
||||
__device__
|
||||
T &operator[](uint i)
|
||||
{
|
||||
assert(i < size());
|
||||
return data()[i];
|
||||
}
|
||||
|
||||
protected:
|
||||
uint size_;
|
||||
|
||||
T *data_;
|
||||
const T *dataConst_;
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <iostream>
|
||||
#include "common/utils.h"
|
||||
#include "matrix_wrapper.h"
|
||||
#include "vector_wrapper.h"
|
||||
#include "nth_element.h"
|
||||
#include "matrix_functions.h"
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "matrix_wrapper.h"
|
||||
#include "vector_wrapper.h"
|
||||
|
||||
namespace amunmt {
|
||||
namespace GPU {
|
||||
|
82
src/amun/gpu/mblas/vector_wrapper.h
Normal file
82
src/amun/gpu/mblas/vector_wrapper.h
Normal file
@ -0,0 +1,82 @@
|
||||
#pragma once
|
||||
#include "matrix.h"
|
||||
#include "gpu/mblas/vector.h"
|
||||
|
||||
namespace amunmt {
|
||||
namespace GPU {
|
||||
namespace mblas {
|
||||
|
||||
|
||||
template <typename T>
|
||||
class VectorWrapper
|
||||
{
|
||||
public:
|
||||
VectorWrapper(const Vector<T> &vec)
|
||||
{
|
||||
size_ = vec.size();
|
||||
data_ = nullptr;
|
||||
dataConst_ = vec.data();
|
||||
}
|
||||
|
||||
VectorWrapper(Vector<T> &vec)
|
||||
{
|
||||
size_ = vec.size();
|
||||
data_ = vec.data();
|
||||
dataConst_ = vec.data();
|
||||
}
|
||||
|
||||
__device__
|
||||
VectorWrapper(T *ptr, uint size)
|
||||
{
|
||||
size_ = size;
|
||||
data_ = ptr;
|
||||
dataConst_ = ptr;
|
||||
}
|
||||
|
||||
__device__ __host__
|
||||
uint size() const
|
||||
{
|
||||
return size_;
|
||||
}
|
||||
|
||||
__device__
|
||||
T* data()
|
||||
{
|
||||
assert(data_);
|
||||
return data_;
|
||||
}
|
||||
|
||||
__device__
|
||||
const T* data() const
|
||||
{
|
||||
assert(dataConst_);
|
||||
return dataConst_;
|
||||
}
|
||||
|
||||
__device__
|
||||
const T &operator[](uint i) const
|
||||
{
|
||||
assert(i < size());
|
||||
return data()[i];
|
||||
}
|
||||
|
||||
__device__
|
||||
T &operator[](uint i)
|
||||
{
|
||||
assert(i < size());
|
||||
return data()[i];
|
||||
}
|
||||
|
||||
protected:
|
||||
uint size_;
|
||||
|
||||
T *data_;
|
||||
const T *dataConst_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user