address code review comments

This commit is contained in:
Marcin Junczys-Dowmunt 2019-02-04 20:26:46 -08:00
parent 9f129279b9
commit d121ba4726
5 changed files with 158 additions and 103 deletions

View File

@ -1,4 +1,22 @@
#!/usr/bin/env python3
"""
This script takes a Tensorflow BERT checkpoint and a model description in a JSON file and converts
it to a Marian weight file with numpy weights and an internal YAML description.
This works with checkpoints from https://github.com/google-research/bert
Assuming a BERT checkpoint laid out like this:
drwxr-xr-x 2 marcinjd marcinjd 4.0K Nov 23 16:39 .
-rw-r--r-- 1 marcinjd marcinjd 521 Nov 23 16:38 bert_config.json
-rw-r--r-- 1 marcinjd marcinjd 682M Nov 23 16:39 bert_model.ckpt.data-00000-of-00001
-rw-r--r-- 1 marcinjd marcinjd 8.5K Nov 23 16:39 bert_model.ckpt.index
-rw-r--r-- 1 marcinjd marcinjd 888K Nov 23 16:39 bert_model.ckpt.meta
-rw-r--r-- 1 marcinjd marcinjd 973K Nov 23 16:37 vocab.txt
usage:
./bert.py --bert_prefix bert_model.ckpt --bert_config bert_config.json --marian bert.npz
"""
import tensorflow as tf
import numpy as np
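For orientation, the core of such a conversion is just reading tensors out of the checkpoint and re-saving them under Marian-compatible names. A minimal sketch using the tf/np imports above (the function name and the flat name mapping are illustrative only, not the script's actual logic, which also rewrites names according to the JSON config and attaches the YAML description):

def dump_checkpoint(ckpt_prefix, npz_path):
    # read every variable from the TF checkpoint, e.g. ckpt_prefix = "bert_model.ckpt"
    reader = tf.train.load_checkpoint(ckpt_prefix)
    tensors = {}
    for name in reader.get_variable_to_shape_map():
        # flatten TF scope names into plain npz keys (illustrative renaming only)
        tensors[name.replace("/", "_")] = reader.get_tensor(name)
    np.savez(npz_path, **tensors)   # e.g. npz_path = "bert.npz"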

View File

@ -110,8 +110,8 @@ protected:
/**
* @brief Accumulation rule for losses
* In the default case this would just be a sum, see SumMultiRationalLoss, but there are
* special cases like ScaledMultiRationalLoss (scales other losses according to the first label count)
* or MeanMultiRationalLoss (sum of means) where the accumulation is more complex.
*/
virtual Expr accumulateLoss(const RationalLoss& current) = 0;
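For intuition, the three rules mentioned above differ only in how the partial (loss, labelCount) pairs are folded into one value; a rough illustrative sketch in Python (hypothetical, not the actual C++ classes):

def accumulate_losses(parts, rule="sum"):
    # parts: list of (loss, labelCount) pairs from the partial losses
    first_labels = parts[0][1]
    if rule == "sum":      # SumMultiRationalLoss: plain sum of the losses
        return sum(loss for loss, _ in parts)
    if rule == "scaled":   # ScaledMultiRationalLoss: rescale each loss to the first label count
        return sum(loss * first_labels / labels for loss, labels in parts)
    if rule == "mean":     # MeanMultiRationalLoss: sum of per-label means
        return sum(loss / labels for loss, labels in parts)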
@ -294,7 +294,7 @@ protected:
lossSum = sum(lossSum, axes_[i]);
// the reduction factor tells over how many labels we reduced in total.
float reducedLabels = (float)loss->shape().elements() / (float)lossSum->shape().elements();
return RationalLoss(lossSum, reducedLabels);
}
@ -331,12 +331,15 @@ protected:
virtual Expr compute(Expr logits, Expr labelIndices,
Expr mask = nullptr, Expr labelWeights = nullptr) override {
logits = atleast_3d(logits); // safeguard against 2d classifier output, adds 1 on the left, non-op.
logits = atleast_3d(logits); // we always assume that a time and a batch dimension exist.
// For BERT training or classification the time dimension is lost.
// This safeguards against 2d classifier output: it adds a 1 on the left and is a non-op otherwise.
Expr ce = cross_entropy(logits, labelIndices);
if(labelSmoothing_ > 0) {
// @TODO: add this to CE kernels instead
// Label smoothing (see https://arxiv.org/pdf/1512.00567.pdf, section 7)
// We compute smoothed H(q',p) = (1 - eps) * H(q,p) + eps * H(u,p) where H(q,p) is the normal cross-entropy
// and H(u,p) penalizes deviation of p from u, u being uniform distribution over vocab V => u_v = 1/|V|.
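A NumPy sketch of that smoothed loss for a single example (assumed reference code, not the kernel): H(u,p) is just the negative mean of the log-probabilities.

import numpy as np

def smoothed_cross_entropy(logits, label, eps):
    # numerically stable log-softmax over the vocabulary axis
    x = logits - logits.max()
    logp = x - np.log(np.exp(x).sum())
    ce_data = -logp[label]    # H(q,p): cross-entropy against the one-hot label
    ce_unif = -logp.mean()    # H(u,p): cross-entropy against the uniform distribution u_v = 1/|V|
    return (1.0 - eps) * ce_data + eps * ce_unif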

View File

@ -18,7 +18,6 @@ void IsNan(const Tensor in, Ptr<Allocator> allocator, bool& isNan, bool& isInf,
ABORT("Not implemented");
}
inline float stableSigmoid(float x) {
if(x >= 0) {
float z = expf(-x);

View File

@ -34,9 +34,9 @@ __global__ void gIsNan(T* in, int length, bool* isNan, bool* isInf, bool zero) {
if(index < length) {
if(isnan((float)in[index])) {
if(zero) in[index] = (T)0.f;
*isNan = true;
}
else if(isinf((float)in[index])) {
if(zero) in[index] = (T)0.f;
*isInf = true;
}
@ -406,30 +406,41 @@ void TransposeNDGrad(Tensor out, Tensor in, const std::vector<int>& vAxis) {
}
}
// Computes the softmax
// in  - input tensor
// out - output tensor
// We compute the softmax over the cols (last dimension);
// rows are time, batch or beam dimensions.
// Number of threads is the number of cols, capped at MAX_THREADS;
// number of blocks is the number of rows, capped at MAX_BLOCKS.
__global__ void gSoftmax(float* out,
functional::Shape outShape,
const float* in) {
int rows = outShape.elements() / outShape.back();
int cols = outShape.back();
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
float* so = out + j * cols;
for(int bid = 0; bid < rows; bid += gridDim.x) { // loop over blocks of rows
int j = bid + blockIdx.x; // blockIdx.x - row index (within block of rows)
if(j < rows) { // compute softmax over one row, row elements distributed over threads
float* so = out + j * cols; // pointer to the row's output data
const float* sp = in + j * cols;
extern __shared__ float _share[];
// determine max (used below to improve numeric stability)
float* _max = _share;
_max[threadIdx.x] = -CUDA_FLT_MAX; // mask
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
if(sp[id] > _max[threadIdx.x])
_max[threadIdx.x] = sp[id];
_max[threadIdx.x] = -CUDA_FLT_MAX; // [threadIdx.x = relative column index within a block of columns]
// find max over column indices that have the same relative column index (=threadIdx.x) across all blocks of columns
for(int tid = 0; tid < cols; tid += blockDim.x) { // loop over blocks of columns, blockDim.x = number of columns per block
// threadIdx.x = column index within block of columns; we reduce over columns within a block, then over blocks
int i = tid + threadIdx.x;
if(i < cols) {
if(sp[i] > _max[threadIdx.x])
_max[threadIdx.x] = sp[i];
}
}
__syncthreads();
// reduce the per-thread partial maxima to the row maximum via tree reduction
int len = blockDim.x;
while(len != 1) {
__syncthreads();
@ -443,20 +454,22 @@ __global__ void gSoftmax(float* out,
}
__syncthreads();
float max = _max[0];
__syncthreads();
float* _sum = _share + blockDim.x;
__syncthreads(); // @TODO: do we need this?
// compute denominator
float* _sum = _share;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
float ex = __expf(sp[id] - max);
so[id] = ex;
int i = tid + threadIdx.x;
if(i < cols) {
// @TODO: is it faster to cache the result of expf() in GPU RAM, or would it be faster to recompute it below?
float ex = __expf(sp[i] - max);
so[i] = ex;
_sum[threadIdx.x] += ex;
}
}
__syncthreads();
// now reduce the per-thread partial sums over all column blocks
len = blockDim.x;
while(len != 1) {
__syncthreads();
@ -466,13 +479,17 @@ __global__ void gSoftmax(float* out,
len = (len + 1) >> 1;
}
__syncthreads();
// produce final output data
float sum = _sum[0];
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
so[id] = so[id] / _sum[0];
int i = tid + threadIdx.x;
if(i < cols) {
so[i] = so[i] / sum;
}
}
}
__syncthreads();
}
}
@ -484,11 +501,12 @@ void Softmax(Tensor out, Tensor in) {
int blocks = std::min(MAX_BLOCKS, (int)m);
int threads = std::min(MAX_THREADS, (int)k);
int shared = sizeof(float) * threads * 2;
int shared = sizeof(float) * threads;
gSoftmax<<<blocks, threads, shared>>>(out->data(), out->shape(), in->data());
}
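The shared-memory request drops from 2 * threads to threads floats because the kernel now reuses the same buffer for the max and the sum reductions (_sum = _share), the two phases being separated by __syncthreads(). For reference, per row the kernel computes the usual numerically stable softmax; a NumPy sketch (assumed reference code, not part of the commit):

import numpy as np

def softmax_rows(x):
    # x: [rows, cols]; softmax over the last dimension, shifting by the row max for stability
    m = x.max(axis=-1, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=-1, keepdims=True)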
// @TODO: refactor to reuse code from softmax, add comments
__global__ void gLogSoftmax(float* out,
const functional::Shape outShape,
const float* in) {
@ -528,7 +546,7 @@ __global__ void gLogSoftmax(float* out,
float max = _max[0];
__syncthreads();
float* _sum = _share + blockDim.x;
float* _sum = _share;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
@ -553,9 +571,10 @@ __global__ void gLogSoftmax(float* out,
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols)
so[id] -= __logf(_sum[0]);
}
}
__syncthreads();
}
}
@ -567,7 +586,7 @@ void LogSoftmax(Tensor out, Tensor in) {
int blocks = std::min(MAX_BLOCKS, (int)m);
int threads = std::min(MAX_THREADS, (int)k);
int shared = sizeof(float) * threads * 2;
int shared = sizeof(float) * threads;
gLogSoftmax<<<blocks, threads, shared>>>(
out->data(), out->shape(), in->data());
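gLogSoftmax follows the same launch geometry; per row it computes (x - max) - log(sum(exp(x - max))), which is why the final step subtracts __logf(_sum[0]) from the already max-shifted values. A NumPy sketch (assumed reference, not part of the commit):

import numpy as np

def log_softmax_rows(x):
    m = x.max(axis=-1, keepdims=True)
    shifted = x - m
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))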
@ -615,9 +634,11 @@ __global__ void gSoftmaxGrad(float* grad,
}
}
}
__syncthreads();
}
}
// @TODO: refactor with logsoftmax, add math
void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
cudaSetDevice(adj->getDeviceId().no);
// grad and val are both m-by-k matrices, passed as input.
@ -671,6 +692,7 @@ __global__ void gLogSoftmaxGrad(float* grad,
gradRow[id] += adjRow[id] - (expf(valRow[id]) * _sum[0]);
}
}
__syncthreads();
}
}
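The math the @TODO above asks for, for reference: with y = softmax(x) and incoming gradient a, the backward pass is dx_i = y_i * (a_i - sum_j a_j * y_j), which is what gSoftmaxGrad accumulates; with y = log_softmax(x) it is dx_i = a_i - exp(y_i) * sum_j a_j, matching the gLogSoftmaxGrad line above.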
@ -1179,7 +1201,7 @@ __global__ void gCrossEntropyPick(float* out,
float max = _max[0];
__syncthreads();
float* _sum = _share + blockDim.x;
float* _sum = _share;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
@ -1206,6 +1228,7 @@ __global__ void gCrossEntropyPick(float* out,
}
}
}
__syncthreads();
}
}
@ -1223,7 +1246,7 @@ void CrossEntropyPick(Tensor out, Tensor in, Tensor indices) {
int blocks = std::min(MAX_BLOCKS, (int)rows);
int threads = std::min(MAX_THREADS, (int)cols);
int shared = sizeof(float) * threads * 2;
int shared = sizeof(float) * threads;
gCrossEntropyPick<<<blocks, threads, shared>>>(
out->data(), out->shape(), in->data(), in->shape(), indices->data<IndexType>());
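Per row, gCrossEntropyPick computes the cross-entropy of the picked label directly from the logits, logsumexp(x) - x[idx], using the same max-shift for stability; its backward (below) is the familiar adj * (softmax(x) - onehot(idx)). A NumPy sketch (assumed reference, not part of the commit):

import numpy as np

def cross_entropy_pick(x, idx):
    # x: [rows, cols] logits, idx: [rows] picked label indices; returns per-row CE
    m = x.max(axis=-1)
    lse = m + np.log(np.exp(x - m[:, None]).sum(axis=-1))   # row-wise logsumexp
    return lse - x[np.arange(x.shape[0]), idx]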
@ -1269,7 +1292,7 @@ __global__ void gCrossEntropyPickBackward(float* out,
float max = _max[0];
__syncthreads();
float* _sum = _share + blockDim.x;
float* _sum = _share;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
@ -1298,6 +1321,7 @@ __global__ void gCrossEntropyPickBackward(float* out,
}
}
}
__syncthreads();
}
}
@ -1311,7 +1335,7 @@ void CrossEntropyPickBackward(Tensor out, Tensor adj, Tensor a, Tensor indices)
int blocks = std::min(MAX_BLOCKS, (int)rows);
int threads = std::min(MAX_THREADS, (int)cols);
int shared = sizeof(float) * threads * 2;
int shared = sizeof(float) * threads;
gCrossEntropyPickBackward<<<blocks, threads, shared>>>(
out->data(), out->shape(), adj->data(), a->data(), indices->data<IndexType>());
@ -1380,8 +1404,8 @@ __global__ void gAtt(float* out,
}
__syncthreads();
out[j] = _sum[0];
__syncthreads();
}
__syncthreads();
}
}
@ -1507,7 +1531,7 @@ __global__ void gLNormalization(float* out,
float mean = _sum[0] / cols;
__syncthreads();
float* _sqSum = _share + blockDim.x;
float* _sqSum = _share;
_sqSum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
@ -1540,6 +1564,7 @@ __global__ void gLNormalization(float* out,
}
}
}
__syncthreads();
}
}
@ -1555,7 +1580,7 @@ void LayerNormalization(Tensor out,
int blocks = std::min(MAX_BLOCKS, (int)rows);
int threads = std::min(MAX_THREADS, (int)cols);
int shared = 2 * threads * sizeof(float);
int shared = threads * sizeof(float);
gLNormalization<<<blocks, threads, shared>>>(out->data(),
in->data(),
@ -1665,6 +1690,7 @@ __global__ void gLayerNormalizationGrad(float* gradX,
}
}
}
__syncthreads();
}
}
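As with the softmax kernels, gLNormalization now reuses one shared buffer for the mean and squared-sum reductions, so the LayerNormalization launcher above requests threads instead of 2 * threads floats. Per row it computes standard layer normalization; a NumPy sketch (assumed reference; eps, gamma and beta stand for the kernel's epsilon, scale and bias parameters):

import numpy as np

def layer_norm_rows(x, gamma, beta, eps=1e-9):
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta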

View File

@ -342,77 +342,86 @@ void tests(DeviceType device) {
auto B = graph->param("B", {3, 2}, inits::from_vector(vB));
auto C = dot(A, B);
// CSR dot product, tested against dense product on the same values
// std::vector<float> vS({1, 0, 0, 1, // sparse
// 0, 0, 1, 1.5});
// std::vector<float> vD({1, 2, 3, 1.2, 5.6, // dense
// 4, 5, 6, 2.3, 6.7,
// 7, 8, 9, 3.4, 7.8,
// 1, 1, 2, 4.5, 8.9});
// auto S = graph->param("S", { 2, 4 }, inits::from_vector(vS));
// auto D = graph->param("D", { 4, 5 }, inits::from_vector(vD));
// auto DT = graph->param("DT", { 5, 4 }, inits::from_vector(vD)); // example matrix with transposed dimensions
// std::vector<float> SV; // create CSR version of S
// std::vector<IndexType> SI, SO;
// SO.push_back((IndexType)SI.size());
// for (IndexType i = 0; i < S->shape()[0]; i++) {
// for (IndexType j = 0; j < S->shape()[1]; j++) {
// auto k = 4 * i + j;
// if (vS[k] != 0) {
// SV.push_back(vS[k]);
// SI.push_back(j);
// }
// }
// SO.push_back((IndexType)SI.size());
// }
// auto SxDd = dot(S, D);
// auto STxSxDd = dot(S, SxDd, /*transA=*/true);
// auto SxDs = csr_dot( // sparse x dense
// S->shape(),
// graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
// graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
// graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
// D);
// auto STxSxDs = csr_dot( // transpose(sparse) x dense; we use result of previous since dimensions match
// S->shape(),
// graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
// graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
// graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
// SxDd, /*transS=*/true);
// auto DTxSTd = dot(DT, S, /*transA=*/false, /*transB=*/true);
// auto DTxSTxSd = dot(DTxSTd, S);
// auto DTxSTs = dot_csr( // dense x sparse
// DT,
// S->shape(),
// graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
// graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
// graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
// /*transS=*/true);
// auto DTxSTxSs = dot_csr( // dense x transpose(sparse)
// DTxSTd,
// S->shape(),
// graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
// graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
// graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32));
CHECK(C->shape() == Shape({2, 2, 2}));
// CHECK(SxDs->shape() == SxDd->shape());
// CHECK(STxSxDs->shape() == STxSxDd->shape());
// CHECK(DTxSTs->shape() == DTxSTd->shape());
// CHECK(DTxSTxSs->shape() == DTxSTxSd->shape());
graph->forward();
C->val()->get(values);
CHECK(values == vC);
}
// dense and sparse operation results must be the same
// SxDd ->val()->get(values2); SxDs ->val()->get(values); CHECK(values == values2);
// STxSxDd ->val()->get(values2); STxSxDs ->val()->get(values); CHECK(values == values2);
// DTxSTd ->val()->get(values2); DTxSTs ->val()->get(values); CHECK(values == values2);
// DTxSTxSd->val()->get(values2); DTxSTxSs->val()->get(values); CHECK(values == values2);
if(device == DeviceType::gpu) {
SECTION("csr-dot product") {
graph->clear();
values.clear();
// CSR dot product, tested against dense product on the same values
std::vector<float> vS({1, 0, 0, 1, // sparse
0, 0, 1, 1.5});
std::vector<float> vD({1, 2, 3, 1.2, 5.6, // dense
4, 5, 6, 2.3, 6.7,
7, 8, 9, 3.4, 7.8,
1, 1, 2, 4.5, 8.9});
auto S = graph->param("S", { 2, 4 }, inits::from_vector(vS));
auto D = graph->param("D", { 4, 5 }, inits::from_vector(vD));
auto DT = graph->param("DT", { 5, 4 }, inits::from_vector(vD)); // example matrix with transposed dimensions
std::vector<float> SV; // create CSR version of S
std::vector<IndexType> SI, SO;
SO.push_back((IndexType)SI.size());
for (IndexType i = 0; i < S->shape()[0]; i++) {
for (IndexType j = 0; j < S->shape()[1]; j++) {
auto k = 4 * i + j;
if (vS[k] != 0) {
SV.push_back(vS[k]);
SI.push_back(j);
}
}
SO.push_back((IndexType)SI.size());
}
auto SxDd = dot(S, D);
auto STxSxDd = dot(S, SxDd, /*transA=*/true);
auto SxDs = csr_dot( // sparse x dense
S->shape(),
graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
D);
auto STxSxDs = csr_dot( // transpose(sparse) x dense; we use result of previous since dimensions match
S->shape(),
graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
SxDd, /*transS=*/true);
auto DTxSTd = dot(DT, S, /*transA=*/false, /*transB=*/true);
auto DTxSTxSd = dot(DTxSTd, S);
auto DTxSTs = dot_csr( // dense x sparse
DT,
S->shape(),
graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32),
/*transS=*/true);
auto DTxSTxSs = dot_csr( // dense x transpose(sparse)
DTxSTd,
S->shape(),
graph->constant({(int)SV.size()}, inits::from_vector(SV), Type::float32),
graph->constant({(int)SI.size()}, inits::from_vector(SI), Type::uint32),
graph->constant({(int)SO.size()}, inits::from_vector(SO), Type::uint32));
CHECK(SxDs->shape() == SxDd->shape());
CHECK(STxSxDs->shape() == STxSxDd->shape());
CHECK(DTxSTs->shape() == DTxSTd->shape());
CHECK(DTxSTxSs->shape() == DTxSTxSd->shape());
graph->forward();
// dense and sparse operation results must be the same
SxDd ->val()->get(values2); SxDs ->val()->get(values); CHECK(values == values2);
STxSxDd ->val()->get(values2); STxSxDs ->val()->get(values); CHECK(values == values2);
DTxSTd ->val()->get(values2); DTxSTs ->val()->get(values); CHECK(values == values2);
DTxSTxSd->val()->get(values2); DTxSTxSs->val()->get(values); CHECK(values == values2);
}
}
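The SV/SI/SO vectors built in the section above are the standard CSR triplet (non-zero values, their column indices, and per-row offsets). A quick way to cross-check them outside the test, assuming SciPy is available (not part of the test itself):

import numpy as np
from scipy.sparse import csr_matrix

S = csr_matrix(np.array([[1, 0, 0, 1],
                         [0, 0, 1, 1.5]]))
print(S.data)     # SV: [1.  1.  1.  1.5]
print(S.indices)  # SI: [0 3 2 3]
print(S.indptr)   # SO: [0 2 4]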
SECTION("affine transformation") {