merged dot_csr() change

Frank Seide 2019-01-15 16:45:49 -08:00
commit 2e331b254d
11 changed files with 159 additions and 130 deletions

View File

@@ -104,7 +104,7 @@ if(CUDA_FOUND)
cuda_add_library(marian_cuda
tensors/gpu/device.cu
tensors/gpu/algorithm.cu
tensors/gpu/prod.cu
tensors/gpu/prod.cpp
tensors/gpu/element.cu
tensors/gpu/add.cu
tensors/gpu/tensor_operators.cu

View File

@@ -421,7 +421,13 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
// A[i,j] is at A_values[A_offsets[i]+k], where k is position of j in A_indices[A_offsets[i]:A_offsets[i+1]]
// @TODO: Define a proper sparse tensor type.
Expr csr_dot(const Shape& A_shape, Expr A_values, Expr A_indices, Expr A_offsets, Expr B, bool transA /*= false*/) {
return Expression<CSRDotNodeOp>(A_shape, A_values, A_indices, A_offsets, B, transA);
return Expression<CSRDotNodeOp>(A_shape, A_values, A_indices, A_offsets, B, transA, /*swapOperands=*/false);
}
// multiply a matrix A with a CSR matrix B
// @TODO: Define a proper sparse tensor type.
Expr dot_csr(Expr A, const Shape& B_shape, Expr B_values, Expr B_indices, Expr B_offsets, bool transB /*= false*/) {
return Expression<CSRDotNodeOp>(B_shape, B_values, B_indices, B_offsets, A, transB, /*swapOperands=*/true);
}
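
As a reading aid for the indexing comment above, here is a minimal sketch of the CSR lookup it describes, using plain std::vector rather than Marian's expression types (illustrative only, not Marian API):

#include <cstddef>
#include <vector>

// Minimal CSR lookup matching the layout comment above: A[i,j] sits at
// values[offsets[i]+k] where indices[offsets[i]+k] == j; absent entries are 0.
float csrAt(const std::vector<float>& values,
            const std::vector<size_t>& indices,
            const std::vector<size_t>& offsets,
            size_t i, size_t j) {
  for (size_t p = offsets[i]; p < offsets[i + 1]; ++p)
    if (indices[p] == j)
      return values[p];
  return 0.f; // structural zero
}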
// swap the last two axes

View File

@@ -113,6 +113,7 @@ Expr affine(Expr a,
float scalar = 1.f);
Expr csr_dot(const Shape& A_shape, Expr Avalues, Expr Aindices, Expr Aoffsets, Expr B, bool transA = false);
Expr dot_csr(Expr A, const Shape& B_shape, Expr B_values, Expr B_indices, Expr B_offsets, bool transB = false);
Expr transpose(Expr a);
Expr transpose(Expr a, const std::vector<int>& axes);

View File

@@ -42,7 +42,7 @@ public:
Shape outShape = shapeA;
outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
"Matrix product requires dimensions to match");
"Matrix product requires inner dimensions to match");
return outShape;
}
@@ -165,7 +165,7 @@ public:
Shape outShape = shapeA;
outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
"Matrix product requires dimensions to match");
"Matrix product requires inner dimensions to match");
return outShape;
}
@@ -309,7 +309,7 @@ public:
Shape outShape = shapeA;
outShape.set(-1, shapeB[-1]);
ABORT_IF(shapeA[-1] != shapeB[-2],
"Batched matrix product requires dimensions to match");
"Batched matrix product requires inner dimensions to match");
return outShape;
}
@@ -409,49 +409,52 @@ public:
const std::string color() override { return "orange"; }
};
// Note: To reduce code duplication, we use the same NodeOp for C = op(S) x D and C = D x op(S).
// Set swapOperands to select the latter.
class CSRDotNodeOp : public NaryNodeOp {
bool transA_;
bool transS_;
bool swapOperands_;
public:
CSRDotNodeOp(const Shape& A_shape, Expr A_values, Expr A_indices, Expr A_offsets, Expr B, bool transA)
: NaryNodeOp({ A_values, A_indices, A_offsets, B }, newShape(A_shape, A_values, A_indices, A_offsets, B, transA)), transA_(transA) {
matchOrAbort<IndexType>(A_indices->value_type());
matchOrAbort<IndexType>(A_offsets->value_type());
CSRDotNodeOp(const Shape& S_shape, Expr S_values, Expr S_indices, Expr S_offsets, Expr D, bool transS, bool swapOperands)
: NaryNodeOp({ S_values, S_indices, S_offsets, D }, newShape(S_shape, S_values, S_indices, S_offsets, D, transS, swapOperands)),
transS_(transS), swapOperands_(swapOperands) {
matchOrAbort<IndexType>(S_indices->value_type());
matchOrAbort<IndexType>(S_offsets->value_type());
}
Shape newShape(const Shape& A_shape, Expr A_values, Expr A_indices, Expr A_offsets, Expr B, bool transA) {
ABORT_IF(A_values->shape().size() != 1 || A_indices->shape().size() != 1 || A_offsets->shape().size() != 1,
Shape newShape(const Shape& S_shape, Expr S_values, Expr S_indices, Expr S_offsets, Expr D, bool transS, bool swapOperands) {
ABORT_IF(S_values->shape().size() != 1 || S_indices->shape().size() != 1 || S_offsets->shape().size() != 1,
"Sparse matrix components must all be vectors");
ABORT_IF(A_values->shape() != A_indices->shape(),
ABORT_IF(S_values->shape() != S_indices->shape(),
"Sparse matrix values and indices must have the same shape");
ABORT_IF(A_shape.size() != 2,
ABORT_IF(S_shape.size() != 2,
"Sparse matrix must have rank 2");
ABORT_IF(A_offsets->shape()[0] - 1 != A_shape[0],
ABORT_IF(S_offsets->shape()[0] - 1 != S_shape[0],
"Sparse matrix offset vector has incorrect size");
auto outShape = B->shape();
ABORT_IF((transA ? A_shape[0] : A_shape[1] != B->shape()[0]),
"Matrix product requires dimensions to match");
outShape.set(0, transA ? A_shape[1] : A_shape[0]);
auto outShape = D->shape();
ABORT_IF(S_shape[transS == swapOperands ? 1 : 0] != outShape[-(int)swapOperands],
"Matrix product requires inner dimensions to match");
outShape.set(-(int)swapOperands, S_shape[transS != swapOperands]);
return outShape;
}
NodeOps forwardOps() override {
// C = op(S) x D, or C = D x op(S) if swapOperands
return {NodeOp(CSRProd(val_,
graph()->allocator(),
child(0)->val(), child(1)->val(), child(2)->val(),
child(3)->val(),
/*transA=*/transA_, /*beta=*/0))};
/*transS=*/transS_, /*swapOperands=*/swapOperands_, /*beta=*/0))};
}
NodeOps backwardOps() override {
return {nullptr, // can't backprop into the sparse matrix pieces (the gradient is dense)
return {nullptr, // can't backprop into the sparse matrix (the gradient is dense)
nullptr,
nullptr,
NodeOp(CSRProd(child(3)->grad(), // child(3) = B
NodeOp(CSRProd(child(3)->grad(), // child(3) = D
graph()->allocator(),
child(0)->val(), child(1)->val(), child(2)->val(), // children(0..2) = S
adj_,
/*transA=*/!transA_, /*beta=*/1))};
/*transS=*/!transS_, /*swapOperands=*/swapOperands_, /*beta=*/1))};
}
const std::string type() override { return "csr_dot"; }
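
The (transS == swapOperands) index arithmetic in newShape() is compact; the following standalone sketch spells out the same rule for plain {rows, cols} shapes (a hypothetical helper for illustration, not Marian code):

#include <array>
#include <cassert>

// Shape rule of CSRDotNodeOp for 2-D shapes {rows, cols}: the axis of D that
// gets replaced in the output is axis 0 for C = op(S) x D (swapOperands=false)
// and the last axis for C = D x op(S) (swapOperands=true).
std::array<int, 2> csrDotShape(std::array<int, 2> S, std::array<int, 2> D,
                               bool transS, bool swapOperands) {
  int inner = S[transS == swapOperands ? 1 : 0]; // dim of op(S) that must match D
  int outer = S[transS != swapOperands ? 1 : 0]; // dim of op(S) that survives in C
  int axis  = swapOperands ? 1 : 0;              // which axis of D is replaced
  assert(inner == D[axis] && "Matrix product requires inner dimensions to match");
  std::array<int, 2> C = D;
  C[axis] = outer;
  return C;
}

// e.g. S = {2, 4}, D = {4, 5}: csrDotShape(S, D, false, false) == {2, 5},
// and csrDotShape(S, {2, 5}, true, false) == {4, 5}.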

View File

@@ -174,8 +174,9 @@ void CSRProd(marian::Tensor C,
const marian::Tensor& A_offsets,
const marian::Tensor& B,
bool transA,
bool swapOperands,
float beta) {
C, A_values, A_indices, A_offsets, B, transA, beta;
C, A_values, A_indices, A_offsets, B, transA, swapOperands, beta; // "use" the arguments to silence unused-parameter warnings
ABORT("CSRProd is not yet implemented for CPU");
}
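
Until a CPU kernel lands, the semantics of CSRProd can be read off a naive reference for the simplest case, transS=false and swapOperands=false, written here with plain vectors instead of marian::Tensor (a sketch, not the planned implementation):

#include <vector>

// Naive reference for C = S x D: S is CSR with shape rows x inner, D is
// dense row-major inner x cols, C is rows x cols. As in BLAS, beta scales
// any existing content of C (beta == 0 ignores it entirely).
void csrProdRef(std::vector<float>& C,
                const std::vector<float>& S_values,
                const std::vector<int>& S_indices,
                const std::vector<int>& S_offsets,
                const std::vector<float>& D,
                int rows, int cols, float beta) {
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < cols; ++j) {
      float sum = (beta == 0.f) ? 0.f : beta * C[i * cols + j];
      for (int p = S_offsets[i]; p < S_offsets[i + 1]; ++p)
        sum += S_values[p] * D[S_indices[p] * cols + j];
      C[i * cols + j] = sum;
    }
}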

View File

@@ -95,6 +95,7 @@ void Prod(marian::Tensor C,
#endif
}
#if 0 // @TODO: remove, then rename from .cu to .cpp
__global__ void gAddBias(float* out,
const float* bias,
size_t length,
@@ -108,7 +109,6 @@ __global__ void gAddBias(float* out,
}
}
#if 0 // @TODO: remove, then rename from .cu to .cpp
void AddBias(marian::Tensor C, const marian::Tensor bias) {
cudaSetDevice(C->getDeviceId().no);
@@ -227,113 +227,112 @@ void ProdBatched(marian::Tensor C,
allocator->free(mp_cptr);
}
// C = op(S) x D if not swapOperands else C = D x op(S)
// op(S) = S if not transA else S^T
void CSRProd(marian::Tensor C,
Ptr<Allocator> allocator,
const marian::Tensor& A_values,
const marian::Tensor& A_indices,
const marian::Tensor& A_offsets,
const marian::Tensor& B,
bool transA,
const marian::Tensor& S_values,
const marian::Tensor& S_indices,
const marian::Tensor& S_offsets,
const marian::Tensor& D,
bool transS,
bool swapOperands,
float beta) {
cudaSetDevice(C->getDeviceId().no);
auto cusparseHandle = std::static_pointer_cast<gpu::Backend>(C->getBackend())
->getCusparseHandle();
// dimensions
// interpret tensor dimensions as matrix dimensions
const auto& shapeC = C->shape();
const auto& shapeB = B->shape();
auto rowsC = shapeC[0];
const auto& shapeD = D->shape();
// If swapOperands, S and D are swapped (C = D x S instead of C = S x D).
// In that case, read all dimensions in the next six lines as if they were reversed in order.
auto rowsC = shapeC[-(int)swapOperands];
auto colsC = shapeC.elements() / rowsC;
auto rowsB = shapeB[0];
auto colsB = shapeB.elements() / rowsB;
auto rowsA = transA ? rowsB : rowsC;
auto colsA = transA ? rowsC : rowsB;
ABORT_IF((transA ? colsA : rowsA) != rowsC || (transA ? rowsA : colsA) != rowsB || colsB != colsC, "Inconsistent dimensions in CSR product");
auto rowsD = shapeD[-(int)swapOperands];
auto colsD = shapeD.elements() / rowsD;
auto rowsS = transS ? rowsD : rowsC;
auto colsS = transS ? rowsC : rowsD;
ABORT_IF(colsD != colsC, "Inconsistent outer dimensions in CSR product");
if (swapOperands) { // make rowsX actual row dimensions again, likewise colsX
std::swap(rowsC, colsC);
std::swap(rowsD, colsD);
std::swap(rowsS, colsS);
}
// sparse arrays
auto numValues = A_values->shape().elements();
auto numOffsets = A_offsets->shape().elements() - 1; // -1 since last value is length
ABORT_IF(numOffsets != (transA ? rowsB : rowsC), "CSR offset array dimension mismatch: n={}, transA={}, rowsB={}, rowsC={}", numOffsets,transA, rowsB, rowsC);
ABORT_IF(numOffsets != (transA ? rowsB : rowsC), "CSR offset array dimension mismatch");
ABORT_IF(A_values->shape() != A_indices->shape(), "CSR values and indices must have the same size");
auto numValues = S_values->shape().elements();
auto numOffsets = S_offsets->shape().elements() - 1; // -1 since last value is length
ABORT_IF(numOffsets != rowsS, "Unexpected number of rows in CSR argument");
ABORT_IF(S_values->shape() != S_indices->shape(), "CSR values and indices must have the same size");
float alpha = 1;
// Marian uses row-major storage, but CUSPARSE/CUBLAS assume column-major.
// Hence, we compute C = spA * B as C' = B' * spA', where B' and C' are
// column-major views on the data of B and C, and likewise, spA' is
// the CSR matrix reinterpreted as a CSC matrix.
if (transA) {
// cusparse does not support this specific version of transpose; do it explicitly
auto At_values = allocator->alloc<float>(numValues);
auto At_indices = allocator->alloc<int>(numValues);
auto At_offsets = allocator->alloc<int>(colsA + 1);
Ptr<MemoryPiece> St_values, St_indices, St_offsets;
if (transS != swapOperands) {
// Cusparse gemmi() does not support this specific version of transpose, and csrmm() is non-deterministic.
// Hence, we transpose the matrix explicitly.
// Note that gemmi() expects a CSC, while csrmm() expects a CSR; hence, the strange condition (transS != swapOperands) above.
St_values = allocator->alloc<float>(numValues);
St_indices = allocator->alloc<int>(numValues);
St_offsets = allocator->alloc<int>(colsS + 1);
// transpose the second argument
CUSPARSE_CHECK(cusparseScsr2csc(cusparseHandle,
/*m=*/ rowsA, // number of rows of matrix
/*n=*/ colsA, // number of columns of matrix
/*m=*/ rowsS, // number of rows of matrix
/*n=*/ colsS, // number of columns of matrix
/*nnz=*/ (int)numValues,
/*csrcVal=*/ A_values->data<float>(), // second arg
/*csrcRowPtr=*/ (int*)A_offsets->data<IndexType>(),
/*csrcColInd=*/ (int*)A_indices->data<IndexType>(),
/*cscVal=*/ At_values->data<float>(), // transposed version goes here
/*cscRowInd=*/ At_indices->data<int>(),
/*cscColPtr=*/ At_offsets->data<int>(),
/*csrcVal=*/ S_values ->data<float>(),
/*csrcRowPtr=*/ (int*)S_offsets->data<IndexType>(),
/*csrcColInd=*/ (int*)S_indices->data<IndexType>(),
/*cscVal=*/ St_values ->data<float>(), // transposed version goes here
/*cscRowInd=*/ St_indices->data<int>(),
/*cscColPtr=*/ St_offsets->data<int>(),
/*copyValues=*/ CUSPARSE_ACTION_NUMERIC,
/*idxBase=*/ CUSPARSE_INDEX_BASE_ZERO));
CUSPARSE_CHECK(cusparseSgemmi(cusparseHandle,
/*m=*/ colsB, // #rows of A = #cols of row-major B
/*n=*/ rowsC, // #cols of B and C = #rows of row-major C
/*k=*/ rowsB, // #cols of A = #rows of row-major B
std::swap(rowsS, colsS); // these variables now represent the dims of the explicitly transposed object
}
if (swapOperands) {
// C = D x S for row-major matrices
// Implemented via cusparse as C' = S' x D' ("csrmm") where C' and D' are column-major,
// and S' is CSR (if not transS then we make a transposed copy).
cusparseMatDescr_t descrA;
CUSPARSE_CHECK(cusparseCreateMatDescr(&descrA));
cusparseSetMatType (descrA, CUSPARSE_MATRIX_TYPE_GENERAL);
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO);
CUSPARSE_CHECK(cusparseScsrmm(cusparseHandle,
CUSPARSE_OPERATION_NON_TRANSPOSE, // (we explicitly transposed above)
/*m=*/ rowsS, // #rows of first (CSR) factor (the transpose was done explicitly)
/*n=*/ rowsC, // #cols of second (col-major) factor and (col-major) result = #rows of row-major C
/*k=*/ colsS, // #cols of first (CSR) factor
/*nnz=*/ (int)numValues,
&alpha,
/*A=*/ B->data(),
/*lda=*/ colsB, // stride
/*cscValB=*/ At_values->data<float>(), // second arg, transposed
/*cscRowPtrB=*/ At_offsets->data<int>(),
/*cscColIndB=*/ At_indices->data<int>(),
&alpha, descrA,
/*csrValA=*/ St_values ? St_values ->data<float>() : S_values ->data<float>(),
/*csrRowPtrA=*/ St_offsets ? St_offsets->data<int>() : (int*)S_offsets->data<IndexType>(),
/*csrColIndA=*/ St_indices ? St_indices->data<int>() : (int*)S_indices->data<IndexType>(),
D->data(),
/*ldb=*/ colsD, // stride
&beta,
C->data(),
/*ldc=*/ colsC)); // stride
allocator->free(At_values);
allocator->free(At_indices);
allocator->free(At_offsets);
cusparseDestroyMatDescr(descrA);
}
else {
// C = S x D for row-major matrices
// Implemented via cusparse as C' = D' x S' ("gemmi") where C' and D' are column-major.
CUSPARSE_CHECK(cusparseSgemmi(cusparseHandle,
/*m=*/ colsB, // #rows of A = #cols of row-major B
/*n=*/ rowsC, // #cols of B and C = #rows of row-major C
/*k=*/ rowsB, // #cols of A = #rows of row-major B
/*m=*/ colsD, // #rows of first (col-major) factor = #cols of row-major D
/*n=*/ rowsC, // #cols of second (CSC) factor and (col-major) result = #rows of row-major C
/*k=*/ rowsD, // #cols of first (col-major) factor = #rows of row-major D
/*nnz=*/ (int)numValues,
&alpha,
/*A=*/ B->data(),
/*lda=*/ colsB, // stride
/*cscValB=*/ A_values->data<float>(), // second arg
/*cscRowPtrB=*/ (int*)A_offsets->data<IndexType>(),
/*cscColIndB=*/ (int*)A_indices->data<IndexType>(),
/*A=*/ D->data(),
/*lda=*/ colsD, // stride
/*cscValB=*/ St_values ? St_values ->data<float>() : S_values ->data<float>(),
/*cscRowPtrB=*/ St_offsets ? St_offsets->data<int>() : (int*)S_offsets->data<IndexType>(),
/*cscColIndB=*/ St_indices ? St_indices->data<int>() : (int*)S_indices->data<IndexType>(),
&beta,
C->data(),
/*ldc=*/ colsC)); // stride
}
#if 0
// Incorrect code that assumes col-major matrices. Reuse that later for dense x sparse.
cusparseMatDescr_t descrA;
CUSPARSE_CHECK(cusparseCreateMatDescr(&descrA));
cusparseSetMatType (descrA, CUSPARSE_MATRIX_TYPE_GENERAL);
cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO);
CUSPARSE_CHECK(cusparseScsrmm(cusparseHandle,
transA ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
/*m=*/ rowsA, // #rows of sparse A
/*n=*/ colsB, // #cols of dense B and C
/*k=*/ colsA, // #cols of sparse A
/*nnz=*/ (int)numValues,
&alpha, descrA,
/*csrValA=*/ A_values->data<float>(),
/*csrRowPtrA=*/ (int*)A_offsets->data<IndexType>(),
/*csrColIndA=*/ (int*)A_indices->data<IndexType>(),
B->data(),
/*ldb=*/ rowsB,
&beta,
C->data(),
/*ldc=*/ rowsC));
cusparseDestroyMatDescr(descrA);
#endif
if(St_values ) allocator->free(St_values );
if(St_indices) allocator->free(St_indices);
if(St_offsets) allocator->free(St_offsets);
}
} // namespace gpu
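
The column-major trick that both branches above rely on is easiest to see in isolation: reinterpreting a row-major buffer as column-major yields the transpose, so a column-major routine asked for C' = D' x S' computes the row-major C = S x D on unchanged buffers. A toy standalone demonstration (not Marian code):

#include <cassert>
#include <vector>

int main() {
  // X is 2x3 row-major: [[1,2,3],[4,5,6]].
  std::vector<float> X = {1, 2, 3, 4, 5, 6};
  auto rowMajor = [&](int i, int j) { return X[i * 3 + j]; }; // 2x3 view
  auto colMajor = [&](int i, int j) { return X[j * 3 + i]; }; // same buffer as col-major 3x2
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j)
      assert(rowMajor(i, j) == colMajor(j, i)); // the col-major view is X^T
  return 0;
}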
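Likewise, what cusparseScsr2csc produces can be stated as a short CPU reference: the CSC arrays of S, which are exactly the CSR arrays of S^T. A sketch assuming plain std::vector storage (not a drop-in replacement for the cusparse call):

#include <cstddef>
#include <vector>

// CSR -> CSC by counting sort: histogram column counts, prefix-sum them into
// column start offsets, then scatter each entry; row order within a column
// is preserved.
void csr2csc(int rows, int cols,
             const std::vector<float>& val, const std::vector<int>& colInd,
             const std::vector<int>& rowPtr,
             std::vector<float>& tVal, std::vector<int>& tRowInd,
             std::vector<int>& tColPtr) {
  size_t nnz = val.size();
  tVal.resize(nnz); tRowInd.resize(nnz);
  tColPtr.assign((size_t)cols + 1, 0);
  for (size_t p = 0; p < nnz; ++p)   // count entries per column
    tColPtr[(size_t)colInd[p] + 1]++;
  for (int c = 0; c < cols; ++c)     // prefix sum -> column start offsets
    tColPtr[c + 1] += tColPtr[c];
  std::vector<int> next(tColPtr.begin(), tColPtr.end() - 1);
  for (int r = 0; r < rows; ++r)     // scatter into the transposed layout
    for (int p = rowPtr[r]; p < rowPtr[r + 1]; ++p) {
      int q = next[colInd[p]]++;
      tVal[q] = val[p];
      tRowInd[q] = r;
    }
}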

View File

@@ -40,6 +40,7 @@ void CSRProd(marian::Tensor C,
const marian::Tensor& A_offsets,
const marian::Tensor& B,
bool transA,
bool swapOperands,
float beta = 0);
} // namespace gpu
} // namespace marian

View File

@@ -77,7 +77,7 @@ void Reduce(Functor functor, marian::Tensor out, Tensors... tensors) {
// clang-format off
DISPATCH7(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float)
DISPATCH8(ProdBatched, marian::Tensor, Ptr<Allocator>, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
DISPATCH8(CSRProd, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, float)
DISPATCH9(CSRProd, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float)
DISPATCH2(Softmax, marian::Tensor, marian::Tensor)
DISPATCH3(SoftmaxGrad, marian::Tensor, marian::Tensor, marian::Tensor)
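
The DISPATCHn macros hide a simple device-dispatch idiom; the toy illustration below shows the general pattern only (an assumption, not Marian's actual macro expansion):

#include <iostream>

// Dispatch idiom behind DISPATCH9 and friends: one free function that
// forwards to the cpu:: or gpu:: implementation of the same name, based on
// which device the output tensor lives on.
enum class DeviceType { cpu, gpu };

namespace cpu { void CSRProd() { std::cout << "cpu::CSRProd\n"; } }
namespace gpu { void CSRProd() { std::cout << "gpu::CSRProd\n"; } }

void CSRProd(DeviceType device) {
  if (device == DeviceType::gpu)
    gpu::CSRProd();
  else
    cpu::CSRProd();
}

int main() { CSRProd(DeviceType::cpu); }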

View File

@@ -318,15 +318,16 @@ void tests(DeviceType device) {
auto B = graph->param("B", {3, 2}, inits::from_vector(vB));
auto C = dot(A, B);
// CSR dot product
std::vector<float> vS({1, 0, 0, 1,
// CSR dot product, tested against dense product on the same values
std::vector<float> vS({1, 0, 0, 1, // sparse
0, 0, 1, 1.5});
std::vector<float> vR({1, 2, 3, 1.2, 5.6,
std::vector<float> vD({1, 2, 3, 1.2, 5.6, // dense
4, 5, 6, 2.3, 6.7,
7, 8, 9, 3.4, 7.8,
1, 1, 2, 4.5, 8.9});
auto S = graph->param("S", { 2, 4 }, inits::from_vector(vS));
auto R = graph->param("R", { 4, 5 }, inits::from_vector(vR));
auto S = graph->param("S", { 2, 4 }, inits::from_vector(vS));
auto D = graph->param("D", { 4, 5 }, inits::from_vector(vD));
auto DT = graph->param("DT", { 5, 4 }, inits::from_vector(vD)); // example matrix with transposed dimensions
std::vector<float> SV; // create CSR version of S
std::vector<IndexType> SI, SO;
SO.push_back((IndexType)SI.size());
@@ -340,37 +341,54 @@
}
SO.push_back((IndexType)SI.size());
}
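
The loop above builds SV/SI/SO row by row from the dense vS; a self-contained sketch of that dense-to-CSR conversion, assuming a row-major vS with dimensions rows x cols (not the test's exact code):

#include <cstdint>
#include <vector>

using IndexType = uint32_t; // assumed to match Marian's index type

// Dense -> CSR: nonzeros go to SV (values) and SI (column indices), and SO
// records the running nonzero count once before row 0 and once after each row.
void denseToCsr(const std::vector<float>& vS, int rows, int cols,
                std::vector<float>& SV, std::vector<IndexType>& SI,
                std::vector<IndexType>& SO) {
  SO.push_back((IndexType)SI.size()); // offset of row 0
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      float v = vS[i * cols + j];
      if (v != 0) {
        SV.push_back(v);            // nonzero value
        SI.push_back((IndexType)j); // its column index
      }
    }
    SO.push_back((IndexType)SI.size()); // row i ends here
  }
}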
auto SxRs = csr_dot(
auto SxDd = dot(S, D);
auto STxSxDd = dot(S, SxDd, /*transA=*/true);
auto SxDs = csr_dot( // sparse x dense
S->shape(),
graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
R);
auto SxRd = dot(S, R);
auto STxRs = csr_dot( // and transpose; use result of previous since dimensions match
D);
auto STxSxDs = csr_dot( // transpose(sparse) x dense; we use result of previous since dimensions match
S->shape(),
graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
SxRd, /*transA=*/true);
auto STxRd = dot(S, SxRd, /*transA=*/true);
SxDd, /*transS=*/true);
auto DTxSTd = dot(DT, S, /*transA=*/false, /*transB=*/true);
auto DTxSTxSd = dot(DTxSTd, S);
auto DTxSTs = dot_csr( // dense x transpose(sparse)
DT,
S->shape(),
graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
/*transS=*/true);
auto DTxSTxSs = dot_csr( // dense x sparse; uses result of previous since dimensions match
DTxSTd,
S->shape(),
graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32));
CHECK(C->shape() == Shape({2, 2, 2}));
CHECK(SxRs->shape() == SxRd->shape());
CHECK(STxRs->shape() == STxRd->shape());
CHECK(SxDs->shape() == SxDd->shape());
CHECK(STxSxDs->shape() == STxSxDd->shape());
CHECK(DTxSTs->shape() == DTxSTd->shape());
CHECK(DTxSTxSs->shape() == DTxSTxSd->shape());
graph->forward();
C->val()->get(values);
CHECK(values == vC);
SxRd->val()->get(values2); // dense
SxRs->val()->get(values); // sparse
CHECK(values == values2); // must be the same
STxRd->val()->get(values2);
STxRs->val()->get(values);
CHECK(values == values2);
// dense and sparse operation results must be the same
SxDd ->val()->get(values2); SxDs ->val()->get(values); CHECK(values == values2);
STxSxDd ->val()->get(values2); STxSxDs ->val()->get(values); CHECK(values == values2);
DTxSTd ->val()->get(values2); DTxSTs ->val()->get(values); CHECK(values == values2);
DTxSTxSd->val()->get(values2); DTxSTxSs->val()->get(values); CHECK(values == values2);
}
SECTION("affine transformation") {

View File

@@ -1089,7 +1089,7 @@
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
</None>
<None Include="..\src\tensors\gpu\element.inc" />
<None Include="..\src\tensors\gpu\prod.cu">
<None Include="..\src\tensors\gpu\prod.cpp">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
</None>
<None Include="..\src\tensors\gpu\sparse.cu">

View File

@@ -1714,9 +1714,6 @@
<None Include="..\src\tensors\gpu\element.inc">
<Filter>tensors\gpu</Filter>
</None>
<None Include="..\src\tensors\gpu\prod.cu">
<Filter>tensors\gpu</Filter>
</None>
<None Include="..\src\tensors\gpu\sparse.cu">
<Filter>tensors\gpu</Filter>
</None>
@@ -1852,6 +1849,9 @@
<None Include="..\src\examples\README.md">
<Filter>examples</Filter>
</None>
<None Include="..\src\tensors\gpu\prod.cpp">
<Filter>tensors\gpu</Filter>
</None>
</ItemGroup>
<ItemGroup>
<Text Include="..\src\3rd_party\sentencepiece\src\CMakeLists.txt">