softmax benchmarks

Marcin Junczys-Dowmunt 2016-09-22 00:06:33 +02:00
parent d0b1d7fb6d
commit 7e73abb9c2
5 changed files with 90 additions and 61 deletions

View File

@@ -11,43 +11,71 @@
using namespace marian;
int main() {
int d = 4;
template <class F>
void testForward(F f, size_t l,
const Shape& shape,
const std::string& desc) {
Tensor in(shape);
Tensor out(shape);
Tensor in({d, d});
Tensor out({d, d});
Tensor adj({d, d}, 1);
auto f = uniform(-5, 5);
f(in);
std::cerr << in.Debug() << std::endl;
{
boost::timer::cpu_timer timer;
for(int i = 0; i < 1; ++i) {
Tensor grad({d, d});
CudnnLogSoftmax(out, in);
CudnnLogSoftmaxGrad(grad, adj, in);
std::cerr << in.Debug() << std::endl;
std::cerr << adj.Debug() << std::endl;
std::cerr << grad.Debug() << std::endl;
}
std::cerr << timer.format(5, "%ws") << std::endl;
}
{
boost::timer::cpu_timer timer;
for(int i = 0; i < 1; ++i) {
Tensor grad({d, d});
CudnnLogSoftmax(out, in);
LogSoftmaxGrad(grad, adj, in);
std::cerr << in.Debug() << std::endl;
std::cerr << adj.Debug() << std::endl;
std::cerr << grad.Debug() << std::endl;
}
std::cerr << timer.format(5, "%ws") << std::endl;
}
uniform(-5, 5)(in);
std::cout << desc << ": ";
boost::timer::cpu_timer timer;
for(int i = 0; i < l; ++i) {
f(out, in);
if(i % 100 == 0)
std::cout << ".";
}
std::cout << timer.format(5, "%ws") << std::endl;
}
template <class F>
void testBackward(F f, size_t l,
const Shape& shape,
const std::string& desc) {
Tensor in(shape);
Tensor adj(shape, 1);
Tensor grad(shape);
uniform(-5, 5)(in);
std::cout << desc << ": ";
boost::timer::cpu_timer timer;
for(int i = 0; i < l; ++i) {
f(grad, adj, in);
if(i % 100 == 0)
std::cout << ".";
}
std::cout << timer.format(5, "%ws") << std::endl;
}
int main() {
int l = 1000;
std::vector<Shape> shapes = {
{1000, 1000},
{80, 50000},
{50000, 80},
};
for(auto& shape : shapes) {
std::cout << "Testing shape: " << shape[0] << "x" << shape[1] << std::endl << std::endl;
std::cout << "Softmax forward" << std::endl;
testForward(CudnnSoftmax, l, shape, "CuDNN ");
testForward(Softmax, l, shape, "Marian");
std::cout << std::endl;
std::cout << "Softmax backward" << std::endl;
testBackward(CudnnSoftmaxGrad, l, shape, "CuDNN ");
testBackward(SoftmaxGrad, l, shape, "Marian");
std::cout << std::endl;
std::cout << "Log-softmax backward" << std::endl;
testBackward(CudnnLogSoftmaxGrad, l, shape, "CuDNN ");
testBackward(LogSoftmaxGrad, l, shape, "Marian");
std::cout << std::endl;
}
return 0;
}
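
A minimal sketch of extending this harness, assuming CudnnLogSoftmax keeps the (out, in) signature used in the code removed above; these lines are not part of this commit and would slot into the shape loop in main():

// Sketch only: benchmark the log-softmax forward pass with the same harness.
// Assumes CudnnLogSoftmax(out, in) is still available; any callable with an
// (out, in) signature works, including a lambda.
std::cout << "Log-softmax forward" << std::endl;
testForward(CudnnLogSoftmax, l, shape, "CuDNN ");
testForward([](Tensor out, Tensor in) { CudnnLogSoftmax(out, in); },
            l, shape, "lambda");
std::cout << std::endl;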

View File

@@ -32,7 +32,6 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
layers.emplace_back(dropout(x, value=0.2));
}
else {
//layers.emplace_back(reluplus(dot(layers.back(), weights.back()), biases.back()));
layers.emplace_back(dropout(relu(dot(layers.back(), weights.back()) + biases.back()), value=0.5));
}
@@ -45,8 +44,7 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
auto scores = named(dot(layers.back(), weights.back()) + biases.back(),
"scores");
//auto cost = mean(cross_entropy(scores, y), axis=0);
auto cost = mean(-sum(y * logsoftmax(scores), axis=1), axis=0);
auto cost = mean(cross_entropy(scores, y), axis=0);
auto costreg = named(
cost, "cost"
);
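
The cost change above swaps the hand-written -sum(y * logsoftmax(scores), axis=1) for cross_entropy(scores, y); for a one-hot (or otherwise normalized) y the two are the same quantity. A minimal CPU sketch of that equivalence, using made-up example values and plain C++ rather than Marian expressions:

// Checks that -sum(y * logsoftmax(s)) equals the usual cross-entropy for a
// one-hot label. Plain C++, example values only; not Marian code.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> s = {1.0f, 2.0f, 0.5f};  // scores for one example
  size_t label = 1;                           // index of the one-hot target

  float max = *std::max_element(s.begin(), s.end());
  float sum = 0.f;
  for (float v : s) sum += std::exp(v - max);
  float logZ = max + std::log(sum);           // log of the partition function

  float ce1 = 0.f;                            // -sum_c y_c * logsoftmax(s)_c
  for (size_t c = 0; c < s.size(); ++c)
    ce1 -= (c == label ? 1.f : 0.f) * (s[c] - logZ);
  float ce2 = logZ - s[label];                // cross-entropy for a one-hot y
  std::printf("%f %f\n", ce1, ce2);           // prints the same value twice
  return 0;
}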

View File

@@ -15,7 +15,7 @@ struct UnaryNodeOp : public Node {
void backward_debug(Float delta) {
using namespace std;
cerr << "UnaryNodeOp::" << typeid(*this).name() << "::backward_debug()" << endl;
cerr << "UnaryNodeOp::" << typeid(*this).name() << "::backward_numeric()" << endl;
std::vector<float> preCalcGradA = StoreTensorInVec(a_->grad());
//output("preCalcGradA", preCalcGradA);
@@ -132,7 +132,7 @@ struct DropoutNodeOp : public UnaryNodeOp {
if(!mask_)
mask_.allocate(val_.shape());
auto f = [] __device__ (float& mask, float drop) {
return mask = drop;
};
@@ -205,8 +205,7 @@ struct LogSoftmaxNodeOp : public UnaryNodeOp {
// Based on the description for softmax, we have for logsoftmax:
// J' * dy = dy - exp(p) * sum(dy)
// where p is the cached log-softmax output, so exp(p) gives the probabilities.
CudnnLogSoftmaxGrad(a_->grad(), adj_, val_);
//LogSoftmaxGrad(a_->grad(), adj_, val_);
LogSoftmaxGrad(a_->grad(), adj_, val_);
}
virtual std::string graphviz() {
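
backward_debug (partially shown above) appears to be a numeric gradient check: it stores the analytically computed gradient and takes a step size delta. As a reminder of the underlying idea only, and not of Marian's actual implementation, a central-difference estimate for a single input element looks like this:

// Central-difference estimate of d f / d x[i] for a scalar-valued f.
// Sketch of the general technique; names are illustrative, not taken from Marian.
#include <functional>
#include <vector>

float NumericGrad(const std::function<float(const std::vector<float>&)>& f,
                  std::vector<float> x, size_t i, float delta = 1e-3f) {
  x[i] += delta;
  float plus = f(x);                    // f(x + delta * e_i)
  x[i] -= 2 * delta;
  float minus = f(x);                   // f(x - delta * e_i)
  return (plus - minus) / (2 * delta);  // compare against the analytic grad[i]
}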

View File

@@ -99,18 +99,20 @@ void CudnnLogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
cudaDeviceSynchronize();
}
__global__ void gSubtractMax(float* out, size_t rows, size_t cols) {
__global__ void gSubtractMax(float* out, const float* in,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if (j < rows) {
extern __shared__ float _share[];
float* _max = _share + blockDim.x;
float* sp = out + j * cols;
_max[threadIdx.x] = sp[threadIdx.x];
const float* inRow = in + j * cols;
float* outRow = out + j * cols;
_max[threadIdx.x] = inRow[threadIdx.x];
for(int tid = 1; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if (id < cols) {
if (sp[id] > _max[threadIdx.x]) _max[threadIdx.x] = sp[id];
if (inRow[id] > _max[threadIdx.x]) _max[threadIdx.x] = inRow[id];
}
}
__syncthreads();
@@ -129,23 +131,24 @@ __global__ void gSubtractMax(float* out, size_t rows, size_t cols) {
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
sp[id] -= _max[0];
outRow[id] = inRow[id] - _max[0];
}
}
}
}
void SubtractMax(Tensor* Out) {
void SubtractMax(Tensor out, Tensor in) {
// in is an m-by-k matrix, passed as input.
// The max element of each row of in is computed and subtracted from that row.
// The result is written to out.
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
size_t m = out.shape()[0];
size_t k = out.shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
gSubtractMax<<<blocks, threads, shared>>>(out.data(),
in.data(), m, k);
cudaStreamSynchronize(0);
}
@@ -183,17 +186,18 @@ __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
}
}
void Softmax(Tensor* Out) {
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
void Softmax(Tensor out, Tensor in) {
size_t m = out.shape()[0];
size_t k = out.shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
// Subtract the max rowwise for numerical stability (safe softmax).
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
gSubtractMax<<<blocks, threads, shared>>>(out.data(),
in.data(), m, k);
cudaStreamSynchronize(0);
gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
gSoftMax<<<blocks, threads, shared>>>(out.data(), m, k);
cudaStreamSynchronize(0);
}
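
Both gSubtractMax and gSoftMax now read from in and write to out. For sanity-checking them, a row-wise CPU reference of the same safe-softmax computation (subtract the row max, exponentiate, normalize) fits in a few lines; this is a sketch using std::vector rather than Marian tensors and is not part of this commit:

// CPU reference for one row: out[j] = exp(in[j] - max(in)) / sum_k exp(in[k] - max(in)).
#include <algorithm>
#include <cmath>
#include <vector>

void SoftmaxRowReference(const std::vector<float>& in, std::vector<float>& out) {
  float max = *std::max_element(in.begin(), in.end());  // safe softmax: shift by the row max
  out.resize(in.size());
  float sum = 0.f;
  for (size_t j = 0; j < in.size(); ++j) {
    out[j] = std::exp(in[j] - max);
    sum += out[j];
  }
  for (float& v : out) v /= sum;  // normalize to probabilities
}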
@@ -267,7 +271,7 @@ __global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val,
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
_sum[threadIdx.x] += expf(valRow[id]) * adjRow[id]; // exp because we cached the logsoftmax
_sum[threadIdx.x] += adjRow[id];
}
}
__syncthreads();
@@ -283,7 +287,7 @@ __global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val,
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
gradRow[id] += adjRow[id] - _sum[0];
gradRow[id] += adjRow[id] - (expf(valRow[id]) * _sum[0]);
}
}
}
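
The change to gLogSoftmaxGrad replaces grad[j] += adj[j] - sum_k(exp(val[k]) * adj[k]) with grad[j] += adj[j] - exp(val[j]) * sum_k(adj[k]), the vector-Jacobian product of log-softmax when val caches the log-probabilities. A row-wise CPU sketch of the corrected update, in plain C++ and not taken from the Marian sources:

// CPU reference for one row of the corrected gLogSoftmaxGrad:
//   grad[j] += adj[j] - exp(val[j]) * sum_k adj[k]
// where val holds the cached log-softmax output, so exp(val) are the probabilities.
#include <cmath>
#include <vector>

void LogSoftmaxGradRowReference(std::vector<float>& grad,
                                const std::vector<float>& adj,
                                const std::vector<float>& val) {
  float sum = 0.f;
  for (float a : adj) sum += a;                   // sum of incoming gradients
  for (size_t j = 0; j < grad.size(); ++j)
    grad[j] += adj[j] - std::exp(val[j]) * sum;
}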

View File

@@ -153,9 +153,9 @@ void Element(Functor functor,
void ClipNorm(Tensor out, float threshold);
void SubtractMax(Tensor* Out);
void SubtractMax(Tensor out, Tensor in);
void Softmax(Tensor* Out);
void Softmax(Tensor out, Tensor in);
void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);