mirror of https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00

address code review comments

parent 9f129279b9
commit d121ba4726
@@ -1,4 +1,22 @@
#!/usr/bin/env python3
"""
This script takes a TensorFlow BERT checkpoint and a model description in a JSON file and converts
it to a Marian weight file with numpy weights and an internal YAML description.

This works with checkpoints from https://github.com/google-research/bert

Assuming a BERT checkpoint like this:

drwxr-xr-x 2 marcinjd marcinjd 4.0K Nov 23 16:39 .
-rw-r--r-- 1 marcinjd marcinjd  521 Nov 23 16:38 bert_config.json
-rw-r--r-- 1 marcinjd marcinjd 682M Nov 23 16:39 bert_model.ckpt.data-00000-of-00001
-rw-r--r-- 1 marcinjd marcinjd 8.5K Nov 23 16:39 bert_model.ckpt.index
-rw-r--r-- 1 marcinjd marcinjd 888K Nov 23 16:39 bert_model.ckpt.meta
-rw-r--r-- 1 marcinjd marcinjd 973K Nov 23 16:37 vocab.txt

usage:

./bert.py --bert_prefix bert_model.ckpt --bert_config bert_config.json --marian bert.npz
"""

import tensorflow as tf
import numpy as np
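A minimal Python sketch of the conversion idea, assuming checkpoint variable names pass through unchanged (the real bert.py remaps TensorFlow names to Marian parameter names using the JSON model description; `convert` is an illustrative helper, not part of the script):

import numpy as np
import tensorflow as tf

def convert(bert_prefix, marian_path):
    # open the TensorFlow checkpoint and read every variable as a numpy array
    reader = tf.train.load_checkpoint(bert_prefix)
    weights = {name: reader.get_tensor(name)
               for name in reader.get_variable_to_shape_map()}
    # a Marian .npz weight file is a plain numpy archive of named arrays
    np.savez(marian_path, **weights)

convert("bert_model.ckpt", "bert.npz")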
@@ -110,8 +110,8 @@ protected:

  /**
   * @brief Accumulation rule for losses
   * In the default case this would just be a sum, see SumMultiRationalLoss, but there are
   * special cases like ScaledMultiRationalLoss (scale other losses according to first label count)
   * or MeanMultiRationalLoss (sum of means) where the accumulation is more complex.
   */
  virtual Expr accumulateLoss(const RationalLoss& current) = 0;
@@ -294,7 +294,7 @@ protected:
      lossSum = sum(lossSum, axes_[i]);

    // the reduction factor tells over how many labels we reduced in total
    float reducedLabels = (float)loss->shape().elements() / (float)lossSum->shape().elements();
    return RationalLoss(lossSum, reducedLabels);
  }

@@ -331,12 +331,15 @@ protected:

  virtual Expr compute(Expr logits, Expr labelIndices,
                       Expr mask = nullptr, Expr labelWeights = nullptr) override {
-    logits = atleast_3d(logits); // safeguard against 2d classifier output, adds 1 on the left, non-op.
+    logits = atleast_3d(logits); // we always assume a time and batch dimension exists.
+    // for BERT training or classification the time dimension is lost.
+    // Here we safeguard against 2d classifier output; adds 1 on the left, a non-op.

    Expr ce = cross_entropy(logits, labelIndices);

    if(labelSmoothing_ > 0) {
      // @TODO: add this to CE kernels instead

      // Label smoothing (see https://arxiv.org/pdf/1512.00567.pdf, section 7)
      // We compute smoothed H(q',p) = (1 - eps) * H(q,p) + eps * H(u,p) where H(q,p) is the normal cross-entropy
      // and H(u,p) penalizes deviation of p from u, u being the uniform distribution over vocab V => u_v = 1/|V|.

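A numpy sketch of the smoothed cross-entropy described in the comment above, for a single row of logits (`smoothed_ce` and `eps=0.1` are illustrative, not names from the code):

import numpy as np

def smoothed_ce(logits, label, eps=0.1):
    # H(q',p) = (1 - eps) * H(q,p) + eps * H(u,p), u uniform over the vocab
    logp = logits - np.max(logits)
    logp = logp - np.log(np.sum(np.exp(logp)))  # log-softmax
    ce = -logp[label]            # H(q,p): normal cross-entropy
    ce_uniform = -np.mean(logp)  # H(u,p) = sum_v (1/|V|) * -log p_v
    return (1 - eps) * ce + eps * ce_uniform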
@@ -18,7 +18,6 @@ void IsNan(const Tensor in, Ptr<Allocator> allocator, bool& isNan, bool& isInf,
  ABORT("Not implemented");
}

inline float stableSigmoid(float x) {
  if(x >= 0) {
    float z = expf(-x);
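For reference, the branch trick stableSigmoid implements, as a numpy sketch: it only ever exponentiates non-positive values, so the exponential cannot overflow.

import numpy as np

def stable_sigmoid(x):
    if x >= 0:
        z = np.exp(-x)          # -x <= 0, safe
        return 1.0 / (1.0 + z)
    else:
        z = np.exp(x)           # x < 0, safe
        return z / (1.0 + z)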
@@ -34,9 +34,9 @@ __global__ void gIsNan(T* in, int length, bool* isNan, bool* isInf, bool zero) {
  if(index < length) {
    if(isnan((float)in[index])) {
      if(zero) in[index] = (T)0.f;
      *isNan = true;
    }
    else if(isinf((float)in[index])) {
      if(zero) in[index] = (T)0.f;
      *isInf = true;
    }
@@ -406,30 +406,41 @@ void TransposeNDGrad(Tensor out, Tensor in, const std::vector<int>& vAxis) {
  }
}

+// Computes the softmax
+// in - input tensor
+// out - output tensor
+// we compute the softmax over the cols (last dimension)
+// rows are time, batch or beam dimensions
+// number of threads is number of cols or MAX_THREADS
+// number of blocks is number of rows or MAX_BLOCKS
__global__ void gSoftmax(float* out,
                         functional::Shape outShape,
                         const float* in) {
  int rows = outShape.elements() / outShape.back();
  int cols = outShape.back();

-  for(int bid = 0; bid < rows; bid += gridDim.x) {
-    int j = bid + blockIdx.x;
-    if(j < rows) {
-      float* so = out + j * cols;
+  for(int bid = 0; bid < rows; bid += gridDim.x) { // loop over blocks of rows
+    int j = bid + blockIdx.x; // blockIdx.x - row index (within block of rows)
+    if(j < rows) { // compute softmax over one row, row elements distributed over threads
+      float* so = out + j * cols; // pointer to row output data
      const float* sp = in + j * cols;

      extern __shared__ float _share[];

+      // determine max (used below to improve numeric stability)
      float* _max = _share;
-      _max[threadIdx.x] = -CUDA_FLT_MAX; // mask
-      for(int tid = 0; tid < cols; tid += blockDim.x) {
-        int id = tid + threadIdx.x;
-        if(id < cols) {
-          if(sp[id] > _max[threadIdx.x])
-            _max[threadIdx.x] = sp[id];
+      _max[threadIdx.x] = -CUDA_FLT_MAX; // [threadIdx.x = relative column index within a block of columns]
+      // find max over column indices that have the same relative column index (=threadIdx.x) across all blocks of columns
+      for(int tid = 0; tid < cols; tid += blockDim.x) { // loop over blocks of columns; each block is blockDim.x columns wide
+        // threadIdx.x = column index within block of columns; we reduce over columns within a block, then over blocks
+        int i = tid + threadIdx.x;
+        if(i < cols) {
+          if(sp[i] > _max[threadIdx.x])
+            _max[threadIdx.x] = sp[i];
        }
      }
      __syncthreads();
+      // max over columns within a column block via tree reduction
      int len = blockDim.x;
      while(len != 1) {
        __syncthreads();
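The tree reduction mentioned in the comment halves the number of active slots each round; a Python sketch of the same pattern over a plain buffer (`tree_reduce_max` is illustrative; the kernel does this in shared memory with one thread per pair):

def tree_reduce_max(buf):
    # mirrors the kernel's while(len != 1) loop with skip = (len + 1) >> 1
    length = len(buf)
    while length != 1:
        skip = (length + 1) >> 1           # distance between the two halves
        for tid in range(length >> 1):     # each "thread" folds one pair
            buf[tid] = max(buf[tid], buf[tid + skip])
        length = (length + 1) >> 1
    return buf[0]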
@@ -443,20 +454,22 @@ __global__ void gSoftmax(float* out,
      }
      __syncthreads();
      float max = _max[0];
-      __syncthreads();
-
-      float* _sum = _share + blockDim.x;
+      __syncthreads(); // @TODO: do we need this?
+
+      // compute denominator
+      float* _sum = _share;
      _sum[threadIdx.x] = 0.0;
      for(int tid = 0; tid < cols; tid += blockDim.x) {
-        int id = tid + threadIdx.x;
-        if(id < cols) {
-          float ex = __expf(sp[id] - max);
-          so[id] = ex;
+        int i = tid + threadIdx.x;
+        if(i < cols) {
+          // @TODO: is it faster to cache the result of expf() in GPU RAM, or would it be faster to recompute it below?
+          float ex = __expf(sp[i] - max);
+          so[i] = ex;
          _sum[threadIdx.x] += ex;
        }
      }
      __syncthreads();
+      // now reduce over all columns within the block
      len = blockDim.x;
      while(len != 1) {
        __syncthreads();
@@ -466,13 +479,17 @@ __global__ void gSoftmax(float* out,
        len = (len + 1) >> 1;
      }
      __syncthreads();

+      // produce final output data
+      float sum = _sum[0];
      for(int tid = 0; tid < cols; tid += blockDim.x) {
-        int id = tid + threadIdx.x;
-        if(id < cols) {
-          so[id] = so[id] / _sum[0];
+        int i = tid + threadIdx.x;
+        if(i < cols) {
+          so[i] = so[i] / sum;
        }
      }
    }
    __syncthreads();
  }
}

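For reference, a numpy sketch of the row-wise computation gSoftmax performs (subtract the row max for numeric stability, exponentiate, normalize):

import numpy as np

def softmax_rows(x):
    # x: (rows, cols); softmax over the last dimension, as in gSoftmax
    m = np.max(x, axis=-1, keepdims=True)    # row max, for numeric stability
    e = np.exp(x - m)                        # numerator
    return e / np.sum(e, axis=-1, keepdims=True)

Because `_max` and `_sum` now reuse the same shared-memory buffer `_share` instead of occupying two adjacent halves, the launchers below request `sizeof(float) * threads` of shared memory rather than twice that.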
@@ -484,11 +501,12 @@ void Softmax(Tensor out, Tensor in) {

  int blocks = std::min(MAX_BLOCKS, (int)m);
  int threads = std::min(MAX_THREADS, (int)k);
-  int shared = sizeof(float) * threads * 2;
+  int shared = sizeof(float) * threads;

  gSoftmax<<<blocks, threads, shared>>>(out->data(), out->shape(), in->data());
}

+// @TODO: refactor to reuse code from softmax, add comments
__global__ void gLogSoftmax(float* out,
                            const functional::Shape outShape,
                            const float* in) {
@@ -528,7 +546,7 @@ __global__ void gLogSoftmax(float* out,
      float max = _max[0];
      __syncthreads();

-      float* _sum = _share + blockDim.x;
+      float* _sum = _share;

      _sum[threadIdx.x] = 0.0;
      for(int tid = 0; tid < cols; tid += blockDim.x) {
@@ -553,9 +571,10 @@ __global__ void gLogSoftmax(float* out,
      for(int tid = 0; tid < cols; tid += blockDim.x) {
        int id = tid + threadIdx.x;
        if(id < cols)
          so[id] -= __logf(_sum[0]);
      }
    }
+    __syncthreads();
  }
}

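Likewise, a numpy sketch of the log-softmax computed by gLogSoftmax; the final subtraction corresponds to the `so[id] -= __logf(_sum[0])` line:

import numpy as np

def log_softmax_rows(x):
    # x: (rows, cols); log-softmax over the last dimension
    m = np.max(x, axis=-1, keepdims=True)
    shifted = x - m                                   # so[id] = sp[id] - max
    s = np.sum(np.exp(shifted), axis=-1, keepdims=True)
    return shifted - np.log(s)                        # so[id] -= log(sum)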
@@ -567,7 +586,7 @@ void LogSoftmax(Tensor out, Tensor in) {

  int blocks = std::min(MAX_BLOCKS, (int)m);
  int threads = std::min(MAX_THREADS, (int)k);
-  int shared = sizeof(float) * threads * 2;
+  int shared = sizeof(float) * threads;

  gLogSoftmax<<<blocks, threads, shared>>>(
      out->data(), out->shape(), in->data());
@@ -615,9 +634,11 @@ __global__ void gSoftmaxGrad(float* grad,
        }
      }
    }
+    __syncthreads();
  }
}

+// @TODO: refactor with logsoftmax, add math
void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
  cudaSetDevice(adj->getDeviceId().no);
  // grad and val are both m-by-k matrices, passed as input.
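The math the @TODO asks for: with y = softmax(x) and upstream gradient a, the softmax backward pass is dL/dx = y * (a - sum(a * y)) per row. A numpy sketch, as a hedged reading of what the kernel accumulates:

import numpy as np

def softmax_grad_rows(adj, val):
    # val = softmax output y, adj = upstream gradient a; both (rows, cols)
    inner = np.sum(adj * val, axis=-1, keepdims=True)  # sum(a * y) per row
    return val * (adj - inner)                         # y * (a - sum(a * y))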
@@ -671,6 +692,7 @@ __global__ void gLogSoftmaxGrad(float* grad,
        gradRow[id] += adjRow[id] - (expf(valRow[id]) * _sum[0]);
      }
    }
+    __syncthreads();
  }
}

@@ -1179,7 +1201,7 @@ __global__ void gCrossEntropyPick(float* out,
      float max = _max[0];
      __syncthreads();

-      float* _sum = _share + blockDim.x;
+      float* _sum = _share;
      _sum[threadIdx.x] = 0.0;
      for(int tid = 0; tid < cols; tid += blockDim.x) {
        int id = tid + threadIdx.x;
@@ -1206,6 +1228,7 @@ __global__ void gCrossEntropyPick(float* out,
        }
      }
    }
+    __syncthreads();
  }
}

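gCrossEntropyPick fuses the log-softmax with picking the gold label's score; a numpy sketch of the per-row value it produces, reconstructed (with hedging) from the kernel's max and sum reductions:

import numpy as np

def cross_entropy_pick(x, idx):
    # x: (rows, cols) logits; idx: (rows,) gold label index per row
    # CE = log(sum(exp(x - max))) - (x[idx] - max), per row
    m = np.max(x, axis=-1)
    logsum = np.log(np.sum(np.exp(x - m[:, None]), axis=-1))
    picked = x[np.arange(len(idx)), idx] - m
    return logsum - picked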
@@ -1223,7 +1246,7 @@ void CrossEntropyPick(Tensor out, Tensor in, Tensor indices) {

  int blocks = std::min(MAX_BLOCKS, (int)rows);
  int threads = std::min(MAX_THREADS, (int)cols);
-  int shared = sizeof(float) * threads * 2;
+  int shared = sizeof(float) * threads;

  gCrossEntropyPick<<<blocks, threads, shared>>>(
      out->data(), out->shape(), in->data(), in->shape(), indices->data<IndexType>());
@@ -1269,7 +1292,7 @@ __global__ void gCrossEntropyPickBackward(float* out,
      float max = _max[0];
      __syncthreads();

-      float* _sum = _share + blockDim.x;
+      float* _sum = _share;
      _sum[threadIdx.x] = 0.0;
      for(int tid = 0; tid < cols; tid += blockDim.x) {
        int id = tid + threadIdx.x;
@@ -1298,6 +1321,7 @@ __global__ void gCrossEntropyPickBackward(float* out,
        }
      }
    }
+    __syncthreads();
  }
}

@@ -1311,7 +1335,7 @@ void CrossEntropyPickBackward(Tensor out, Tensor adj, Tensor a, Tensor indices)

  int blocks = std::min(MAX_BLOCKS, (int)rows);
  int threads = std::min(MAX_THREADS, (int)cols);
-  int shared = sizeof(float) * threads * 2;
+  int shared = sizeof(float) * threads;

  gCrossEntropyPickBackward<<<blocks, threads, shared>>>(
      out->data(), out->shape(), adj->data(), a->data(), indices->data<IndexType>());
@@ -1380,8 +1404,8 @@ __global__ void gAtt(float* out,
      }
      __syncthreads();
      out[j] = _sum[0];
-      __syncthreads();
    }
+    __syncthreads();
  }
}

@@ -1507,7 +1531,7 @@ __global__ void gLNormalization(float* out,
      float mean = _sum[0] / cols;
      __syncthreads();

-      float* _sqSum = _share + blockDim.x;
+      float* _sqSum = _share;

      _sqSum[threadIdx.x] = 0.0;
      for(int tid = 0; tid < cols; tid += blockDim.x) {
@@ -1540,6 +1564,7 @@ __global__ void gLNormalization(float* out,
        }
      }
    }
+    __syncthreads();
  }
}

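For reference, a numpy sketch of the row-wise layer normalization gLNormalization computes (`gamma`, `beta`, and the `eps` default mirror the kernel's arguments only loosely; they are illustrative):

import numpy as np

def layer_norm_rows(x, gamma, beta, eps=1e-9):
    # x: (rows, cols); normalize each row to zero mean and unit variance,
    # then scale and shift: gamma * (x - mean) / sqrt(var + eps) + beta
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.mean((x - mean) ** 2, axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta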
@@ -1555,7 +1580,7 @@ void LayerNormalization(Tensor out,

  int blocks = std::min(MAX_BLOCKS, (int)rows);
  int threads = std::min(MAX_THREADS, (int)cols);
-  int shared = 2 * threads * sizeof(float);
+  int shared = threads * sizeof(float);

  gLNormalization<<<blocks, threads, shared>>>(out->data(),
                                               in->data(),
@@ -1665,6 +1690,7 @@ __global__ void gLayerNormalizationGrad(float* gradX,
        }
      }
    }
+    __syncthreads();
  }
}

@@ -342,77 +342,86 @@ void tests(DeviceType device) {
    auto B = graph->param("B", {3, 2}, inits::from_vector(vB));
    auto C = dot(A, B);

-    // CSR dot product, tested against dense product on the same values
-    // std::vector<float> vS({1, 0, 0, 1, // sparse
-    //                        0, 0, 1, 1.5});
-    // std::vector<float> vD({1, 2, 3, 1.2, 5.6, // dense
-    //                        4, 5, 6, 2.3, 6.7,
-    //                        7, 8, 9, 3.4, 7.8,
-    //                        1, 1, 2, 4.5, 8.9});
-    // auto S  = graph->param("S",  { 2, 4 }, inits::from_vector(vS));
-    // auto D  = graph->param("D",  { 4, 5 }, inits::from_vector(vD));
-    // auto DT = graph->param("DT", { 5, 4 }, inits::from_vector(vD)); // example matrix with transposed dimensions
-    // std::vector<float> SV; // create CSR version of S
-    // std::vector<IndexType> SI, SO;
-    // SO.push_back((IndexType)SI.size());
-    // for (IndexType i = 0; i < S->shape()[0]; i++) {
-    //   for (IndexType j = 0; j < S->shape()[1]; j++) {
-    //     auto k = 4 * i + j;
-    //     if (vS[k] != 0) {
-    //       SV.push_back(vS[k]);
-    //       SI.push_back(j);
-    //     }
-    //   }
-    //   SO.push_back((IndexType)SI.size());
-    // }
-
-    // auto SxDd    = dot(S, D);
-    // auto STxSxDd = dot(S, SxDd, /*transA=*/true);
-    // auto SxDs = csr_dot( // sparse x dense
-    //     S->shape(),
-    //     graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
-    //     graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
-    //     graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
-    //     D);
-    // auto STxSxDs = csr_dot( // transpose(sparse) x dense; we use result of previous since dimensions match
-    //     S->shape(),
-    //     graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
-    //     graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
-    //     graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
-    //     SxDd, /*transS=*/true);
-
-    // auto DTxSTd   = dot(DT, S, /*transA=*/false, /*transB=*/true);
-    // auto DTxSTxSd = dot(DTxSTd, S);
-    // auto DTxSTs = dot_csr( // dense x sparse
-    //     DT,
-    //     S->shape(),
-    //     graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
-    //     graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
-    //     graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
-    //     /*transS=*/true);
-    // auto DTxSTxSs = dot_csr( // dense x transpose(sparse)
-    //     DTxSTd,
-    //     S->shape(),
-    //     graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
-    //     graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
-    //     graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32));
-
    CHECK(C->shape() == Shape({2, 2, 2}));
-    // CHECK(SxDs->shape() == SxDd->shape());
-    // CHECK(STxSxDs->shape() == STxSxDd->shape());
-    // CHECK(DTxSTs->shape() == DTxSTd->shape());
-    // CHECK(DTxSTxSs->shape() == DTxSTxSd->shape());

    graph->forward();

    C->val()->get(values);
    CHECK(values == vC);
  }

-  // dense and sparse operation results must be the same
-  // SxDd    ->val()->get(values2); SxDs    ->val()->get(values); CHECK(values == values2);
-  // STxSxDd ->val()->get(values2); STxSxDs ->val()->get(values); CHECK(values == values2);
-  // DTxSTd  ->val()->get(values2); DTxSTs  ->val()->get(values); CHECK(values == values2);
-  // DTxSTxSd->val()->get(values2); DTxSTxSs->val()->get(values); CHECK(values == values2);
+  if(device == DeviceType::gpu) {
+    SECTION("csr-dot product") {
+      graph->clear();
+      values.clear();
+      // CSR dot product, tested against dense product on the same values
+      std::vector<float> vS({1, 0, 0, 1, // sparse
+                             0, 0, 1, 1.5});
+      std::vector<float> vD({1, 2, 3, 1.2, 5.6, // dense
+                             4, 5, 6, 2.3, 6.7,
+                             7, 8, 9, 3.4, 7.8,
+                             1, 1, 2, 4.5, 8.9});
+      auto S  = graph->param("S",  { 2, 4 }, inits::from_vector(vS));
+      auto D  = graph->param("D",  { 4, 5 }, inits::from_vector(vD));
+      auto DT = graph->param("DT", { 5, 4 }, inits::from_vector(vD)); // example matrix with transposed dimensions
+      std::vector<float> SV; // create CSR version of S
+      std::vector<IndexType> SI, SO;
+      SO.push_back((IndexType)SI.size());
+      for (IndexType i = 0; i < S->shape()[0]; i++) {
+        for (IndexType j = 0; j < S->shape()[1]; j++) {
+          auto k = 4 * i + j;
+          if (vS[k] != 0) {
+            SV.push_back(vS[k]);
+            SI.push_back(j);
+          }
+        }
+        SO.push_back((IndexType)SI.size());
+      }
+
+      auto SxDd    = dot(S, D);
+      auto STxSxDd = dot(S, SxDd, /*transA=*/true);
+      auto SxDs = csr_dot( // sparse x dense
+          S->shape(),
+          graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
+          graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
+          graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
+          D);
+      auto STxSxDs = csr_dot( // transpose(sparse) x dense; we use result of previous since dimensions match
+          S->shape(),
+          graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
+          graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
+          graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
+          SxDd, /*transS=*/true);
+
+      auto DTxSTd   = dot(DT, S, /*transA=*/false, /*transB=*/true);
+      auto DTxSTxSd = dot(DTxSTd, S);
+      auto DTxSTs = dot_csr( // dense x sparse
+          DT,
+          S->shape(),
+          graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
+          graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
+          graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
+          /*transS=*/true);
+      auto DTxSTxSs = dot_csr( // dense x transpose(sparse)
+          DTxSTd,
+          S->shape(),
+          graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
+          graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
+          graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32));
+
+      CHECK(SxDs->shape() == SxDd->shape());
+      CHECK(STxSxDs->shape() == STxSxDd->shape());
+      CHECK(DTxSTs->shape() == DTxSTd->shape());
+      CHECK(DTxSTxSs->shape() == DTxSTxSd->shape());
+
+      graph->forward();
+
+      // dense and sparse operation results must be the same
+      SxDd    ->val()->get(values2); SxDs    ->val()->get(values); CHECK(values == values2);
+      STxSxDd ->val()->get(values2); STxSxDs ->val()->get(values); CHECK(values == values2);
+      DTxSTd  ->val()->get(values2); DTxSTs  ->val()->get(values); CHECK(values == values2);
+      DTxSTxSd->val()->get(values2); DTxSTxSs->val()->get(values); CHECK(values == values2);
+    }
+  }

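The SV/SI/SO triple built in the test is exactly the CSR data/indices/indptr layout; a scipy sketch over the same S and D values that also checks sparse-times-dense against the dense product:

import numpy as np
from scipy.sparse import csr_matrix

S = np.array([[1, 0, 0, 1],
              [0, 0, 1, 1.5]])
D = np.array([[1, 2, 3, 1.2, 5.6],
              [4, 5, 6, 2.3, 6.7],
              [7, 8, 9, 3.4, 7.8],
              [1, 1, 2, 4.5, 8.9]])
S_csr = csr_matrix(S)
print(S_csr.data, S_csr.indices, S_csr.indptr)  # SV, SI, SO
assert np.allclose(S_csr @ D, S @ D)            # sparse x dense == dense x dense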
SECTION("affine transformation") {