Andre Martins 2016-09-22 11:01:32 +01:00
commit 55fa57c762
13 changed files with 263 additions and 163 deletions

View File

@ -16,13 +16,13 @@ cuda_add_library(marian_lib
target_link_libraries(marian_lib)
cuda_add_executable(
marian
test.cu
dropout_benchmark
dropout_benchmark.cu
)
cuda_add_executable(
test_cudnn
cudnn.cu
softmax_benchmark
softmax_benchmark.cu
)
cuda_add_executable(
@ -45,14 +45,14 @@ cuda_add_executable(
test_nodes.cu
)
target_link_libraries(marian marian_lib)
target_link_libraries(test_cudnn marian_lib)
target_link_libraries(dropout_benchmark marian_lib)
target_link_libraries(softmax_benchmark marian_lib)
target_link_libraries(mnist_benchmark marian_lib)
target_link_libraries(validate_mnist_batch marian_lib)
target_link_libraries(validate_encoder_decoder marian_lib)
target_link_libraries(test_nodes marian_lib)
foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder test_nodes test_cudnn)
foreach(exec dropout_benchmark mnist_benchmark softmax_benchmark validate_mnist_batch validate_encoder_decoder test_nodes )
target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn curand)
cuda_add_cublas_to_target(${exec})
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")

View File

@ -1,53 +0,0 @@
#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <cudnn.h>
#include <boost/timer/timer.hpp>
#include "tensor.h"
#include "tensor_operators.h"
#include "param_initializers.h"
using namespace marian;
int main() {
int d = 4;
Tensor in({d, d});
Tensor out({d, d});
Tensor adj({d, d}, 1);
auto f = uniform(-5, 5);
f(in);
std::cerr << in.Debug() << std::endl;
{
boost::timer::cpu_timer timer;
for(int i = 0; i < 1; ++i) {
Tensor grad({d, d});
CudnnLogSoftmax(out, in);
CudnnLogSoftmaxGrad(grad, adj, in);
std::cerr << in.Debug() << std::endl;
std::cerr << adj.Debug() << std::endl;
std::cerr << grad.Debug() << std::endl;
}
std::cerr << timer.format(5, "%ws") << std::endl;
}
{
boost::timer::cpu_timer timer;
for(int i = 0; i < 1; ++i) {
Tensor grad({d, d});
CudnnLogSoftmax(out, in);
LogSoftmaxGrad(grad, adj, in);
std::cerr << in.Debug() << std::endl;
std::cerr << adj.Debug() << std::endl;
std::cerr << grad.Debug() << std::endl;
}
std::cerr << timer.format(5, "%ws") << std::endl;
}
return 0;
}

View File

@ -32,7 +32,6 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
layers.emplace_back(dropout(x, value=0.2));
}
else {
//layers.emplace_back(reluplus(dot(layers.back(), weights.back()), biases.back()));
layers.emplace_back(dropout(relu(dot(layers.back(), weights.back()) + biases.back()), value=0.5));
}
@ -45,8 +44,7 @@ ExpressionGraph build_graph(const std::vector<int>& dims) {
auto scores = named(dot(layers.back(), weights.back()) + biases.back(),
"scores");
//auto cost = mean(cross_entropy(scores, y), axis=0);
auto cost = mean(-sum(y * logsoftmax(scores), axis=1), axis=0);
auto cost = mean(cross_entropy(scores, y), axis=0);
auto costreg = named(
cost, "cost"
);
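For reference, the switch back to cross_entropy assumes it computes the same row-wise negative log-likelihood that the removed hand-written expression spelled out; with s = scores and y the label matrix, the intended equivalence is

\operatorname{cross\_entropy}(s, y)_n = -\sum_j y_{nj} \log \operatorname{softmax}(s_n)_j = -\sum_j y_{nj} \operatorname{logsoftmax}(s_n)_j, \qquad \text{cost} = \tfrac{1}{N}\sum_n \operatorname{cross\_entropy}(s, y)_n .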
@ -115,7 +113,10 @@ int main(int argc, char** argv) {
std::cerr << "Done." << std::endl;
ExpressionGraph g = build_graph({IMAGE_SIZE, 2048, 2048, LABEL_SIZE});
//std::cout << g.graphviz() << std::endl;
std::ofstream viz("mnist_benchmark.dot");
viz << g.graphviz() << std::endl;
viz.close();
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
Tensor yt({BATCH_SIZE, LABEL_SIZE});

View File

@ -12,96 +12,131 @@ void Node::calc_numeric_grad(
{
using namespace std;
size_t inputSize = GetTotalSize(input.shape());
size_t valSize = GetTotalSize(val_.shape());
size_t inputSize = GetTotalSize(input.shape());
size_t valSize = GetTotalSize(val_.shape());
UTIL_THROW_IF2(inputSize != GetTotalSize(grad.shape()),
"inputSize != gradSize:" << inputSize << "!=" << GetTotalSize(grad.shape()));
UTIL_THROW_IF2(valSize != GetTotalSize(adj_.shape()),
"valSize != adjSize :" << valSize << "!=" << GetTotalSize(adj_.shape()));
UTIL_THROW_IF2(inputSize != GetTotalSize(grad.shape()),
"inputSize != gradSize:" << inputSize << "!=" << GetTotalSize(grad.shape()));
UTIL_THROW_IF2(valSize != GetTotalSize(adj_.shape()),
"valSize != adjSize :" << valSize << "!=" << GetTotalSize(adj_.shape()));
cerr << "sizes: "
<< Debug(input.shape())<< "=" << inputSize << " "
<< Debug(val_.shape()) << "=" << valSize
<< endl;
cerr << "inputSize=grad=" << Debug(input.shape())<< "=" << inputSize << " "
<< "valSize=adj_=" << Debug(val_.shape()) << "=" << valSize
<< endl;
//cerr << "input=" << input.Debug() << endl;
//cerr << "input=" << input.Debug() << endl;
//cerr << "adj_=" << adj_.Debug() << endl;
std::vector<float> origGrad(inputSize);
thrust::copy(grad.begin(), grad.end(), origGrad.begin());
cerr << "origGrad=" << grad.Debug() << endl;
//output("diffGrad", diffGrad);
std::vector<float> origGrad(inputSize);
thrust::copy(grad.begin(), grad.end(), origGrad.begin());
cerr << "origGrad=" << grad.Debug() << endl;
//output("diffGrad", diffGrad);
//output("prevCalcGrad", prevCalcGrad.begin(), prevCalcGrad.end());
//output("prevCalcGrad", prevCalcGrad.begin(), prevCalcGrad.end());
std::vector<float> inputVec(inputSize);
thrust::copy(input.begin(), input.end(), inputVec.begin());
//output("inputVec", inputVec);
std::vector<float> inputVec(inputSize);
thrust::copy(input.begin(), input.end(), inputVec.begin());
//output("inputVec", inputVec);
std::vector<float> newVal(inputSize, 0);
std::vector<float> newVal(inputSize, 0);
// LOOP thru each element in input & add delta
for (size_t inputInd = 0; inputInd < inputSize; ++inputInd) {
inputVec[inputInd] += delta;
thrust::copy(inputVec.begin(), inputVec.end(), input.begin());
forward();
for (size_t i = 0; i < valSize; ++i) {
newVal[inputInd] += val_[i];
}
inputVec[inputInd] -= delta;
}
// orig value
// LOOP thru each element in input & add delta
for (size_t inputInd = 0; inputInd < inputSize; ++inputInd) {
inputVec[inputInd] += delta;
thrust::copy(inputVec.begin(), inputVec.end(), input.begin());
//output("input", input.begin(), input.end());
forward();
Float sumValOrig = 0;
for (size_t i = 0; i < valSize; ++i) {
sumValOrig += val_[i];
newVal[inputInd] += val_[i];
}
//output("val_", val_.begin(), val_.end());
//output("newVal", newVal.begin(), newVal.end());
inputVec[inputInd] -= delta;
}
// calc gradient
//cerr << "adj_=" << adj_.Debug() << endl;
std::vector<float> adjVec(valSize);
thrust::copy(adj_.begin(), adj_.end(), adjVec.begin());
// orig value
thrust::copy(inputVec.begin(), inputVec.end(), input.begin());
forward();
std::vector<float> numericalGrad(inputSize);
for (size_t i = 0; i < numericalGrad.size(); ++i) {
numericalGrad[i] = (adjVec[i] * (newVal[i] - sumValOrig) / delta);
numericalGrad[i] += prevCalcGrad[i];
}
float sumValOrig = 0;
for (size_t i = 0; i < valSize; ++i) {
sumValOrig += val_[i];
}
// set grad results
thrust::copy(numericalGrad.begin(), numericalGrad.end(), grad.begin());
cerr << "numericalGrad=" << grad.Debug() << endl;
//output("numericalGrad", numericalGrad);
//output("newVal", newVal.begin(), newVal.end());
// print out diff between origGrad and numericalGrad
std::vector<float> diff(inputSize);
// calc gradient
//cerr << "adj_=" << adj_.Debug() << endl;
std::vector<float> adjVec(valSize);
thrust::copy(adj_.begin(), adj_.end(), adjVec.begin());
for (size_t i = 0; i < diff.size(); ++i) {
diff[i] = (origGrad[i] - numericalGrad[i]) ;
}
output("diff", diff.begin(), diff.end());
std::vector<float> numericalGrad(inputSize);
for (size_t i = 0; i < numericalGrad.size(); ++i) {
numericalGrad[i] = (newVal[i] - sumValOrig) / delta;
}
// put back origGrad
thrust::copy(origGrad.begin(), origGrad.end(), grad.begin());
broadcast(numericalGrad, adjVec);
//std::cerr << "broadcast size=" << numericalGrad.size() << " " << adjVec.size() << std::endl;
//output("adjVec=", adjVec.begin(), adjVec.end());
for (size_t i = 0; i < numericalGrad.size(); ++i) {
numericalGrad[i] *= adjVec[i];
numericalGrad[i] += prevCalcGrad[i];
}
//output("prevCalcGrad=", prevCalcGrad.begin(), prevCalcGrad.end());
//output("adjVec=", adjVec.begin(), adjVec.end());
// set grad results
thrust::copy(numericalGrad.begin(), numericalGrad.end(), grad.begin());
cerr << "numericalGrad=" << grad.Debug() << endl;
//output("numericalGrad", numericalGrad);
// print out diff between origGrad and numericalGrad
std::vector<float> diff(inputSize);
for (size_t i = 0; i < origGrad.size(); ++i) {
diff[i] = origGrad[i] - numericalGrad[i];
}
cerr << "L2-norm of difference=" << L2Norm(diff) << endl << endl;
// put back origGrad
thrust::copy(origGrad.begin(), origGrad.end(), grad.begin());
}
float Node::L2Norm(const std::vector<float> &vec) const
{
float ret = 0;
for (size_t i = 0; i < vec.size(); ++i) {
ret += vec[i] * vec[i];
}
return sqrt(ret);
}
std::vector<float> Node::StoreTensorInVec(Tensor tensor)
{
size_t totSize = GetTotalSize(tensor.shape());
std::vector<float> vec(totSize);
thrust::copy(tensor.begin(), tensor.end(), vec.begin());
return vec;
size_t totSize = GetTotalSize(tensor.shape());
std::vector<float> vec(totSize);
thrust::copy(tensor.begin(), tensor.end(), vec.begin());
return vec;
}
void Node::broadcast(const std::vector<float> &largeVec, std::vector<float> &smallVec)
{
size_t largeSize = largeVec.size();
size_t smallSize = smallVec.size();
UTIL_THROW_IF2(largeSize < smallSize,
"largeSize < smallSize:" << largeSize << "<" << smallSize);
UTIL_THROW_IF2(largeSize % smallSize,
"largeSize % smallSize != 0:" << largeSize << " " << smallSize);
smallVec.resize(largeSize);
for (size_t i = smallSize; i < largeSize; i += smallSize) {
std::copy(smallVec.begin(), smallVec.begin() + smallSize, smallVec.begin() + i);
}
}
}
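As a rough outline, calc_numeric_grad above perturbs each input element by delta, re-runs forward(), divides the change in the summed output by delta, broadcasts and multiplies in the adjoint (plus any previously accumulated gradient), and reports the L2 norm of the difference against the analytic gradient before restoring it. A minimal CPU-only sketch of the same idea, standalone and with a hypothetical forwardSum standing in for forward() plus the summation of val_:

#include <cmath>
#include <cstdio>
#include <functional>
#include <vector>

// Forward-difference gradient check (sketch).
// forwardSum(x) plays the role of running forward() and summing val_;
// adj is the adjoint already broadcast to the input size;
// analytic is the gradient produced by backward() that we want to verify.
float gradCheck(std::function<float(const std::vector<float>&)> forwardSum,
                std::vector<float> x,
                const std::vector<float>& adj,
                const std::vector<float>& analytic,
                float delta = 1e-3f) {
  std::vector<float> numeric(x.size());
  const float base = forwardSum(x);
  for (size_t i = 0; i < x.size(); ++i) {
    x[i] += delta;                                // perturb one element
    numeric[i] = (forwardSum(x) - base) / delta;  // finite difference
    numeric[i] *= adj[i];                         // chain rule with the adjoint
    x[i] -= delta;                                // restore
  }
  float l2 = 0;                                   // L2 norm of the difference
  for (size_t i = 0; i < x.size(); ++i) {
    const float d = analytic[i] - numeric[i];
    l2 += d * d;
  }
  return std::sqrt(l2);
}

int main() {
  // Example: f(x) = sum(x^2), so df/dx_i = 2*x_i; adjoint of all ones.
  auto f = [](const std::vector<float>& x) {
    float s = 0; for (float v : x) s += v * v; return s;
  };
  std::vector<float> x = {1, 2, 3};
  std::vector<float> adj(3, 1.0f);
  std::vector<float> analytic = {2, 4, 6};
  std::printf("L2-norm of difference=%f\n", gradCheck(f, x, adj, analytic));
  return 0;
}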

View File

@ -129,7 +129,8 @@ class Node : public Chainable<Tensor>,
Tensor grad,
const std::vector<float> &prevCalcGrad
);
void broadcast(const std::vector<float> &largeVec, std::vector<float> &smallVec);
float L2Norm(const std::vector<float> &vec) const;
};

View File

@ -15,7 +15,7 @@ struct BinaryNodeOp : public Node {
void backward_debug(Float delta) {
using namespace std;
cerr << "BinaryNodeOp::" << typeid(*this).name() << "::backward_numeric()" << endl;
cerr << "BinaryNodeOp::" << typeid(*this).name() << "::backward_debug()" << endl;
std::vector<float> preCalcGradA = StoreTensorInVec(a_->grad());
//output("preCalcGradA", preCalcGradA);
@ -73,7 +73,7 @@ struct DotNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("×")
ss << "\"" << this << "\" [shape=\"box\", label=" << label("")
<< ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
@ -185,7 +185,7 @@ struct MultNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("")
ss << "\"" << this << "\" [shape=\"box\", label=" << label("x")
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;

View File

@ -132,7 +132,7 @@ struct DropoutNodeOp : public UnaryNodeOp {
if(!mask_)
mask_.allocate(val_.shape());
auto f = [] __device__ (float& mask, float drop) {
return mask = drop;
};
@ -205,8 +205,7 @@ struct LogSoftmaxNodeOp : public UnaryNodeOp {
// Based on the description for softmax, we have logsoftmax:
// J * dy = dy - avg*1
// where avg = exp(p)'*dy and p is the softmax output (probabilities).
CudnnLogSoftmaxGrad(a_->grad(), adj_, val_);
//LogSoftmaxGrad(a_->grad(), adj_, val_);
LogSoftmaxGrad(a_->grad(), adj_, val_);
}

virtual std::string graphviz() {

src/softmax_benchmark.cu (new file, 81 lines)
View File

@ -0,0 +1,81 @@
#include <stdlib.h>
#include <stdio.h>
#include <time.h>
#include <cudnn.h>
#include <boost/timer/timer.hpp>
#include "tensor.h"
#include "tensor_operators.h"
#include "param_initializers.h"
using namespace marian;
template <class F>
void testForward(F f, size_t l,
const Shape& shape,
const std::string& desc) {
Tensor in(shape);
Tensor out(shape);
uniform(-5, 5)(in);
std::cout << desc << ": " << std::flush;
boost::timer::cpu_timer timer;
for(int i = 0; i < l; ++i) {
f(out, in);
if(i % 100 == 0)
std::cout << "." << std::flush;
}
std::cout << timer.format(5, "%ws") << std::endl;
}
template <class F>
void testBackward(F f, size_t l,
const Shape& shape,
const std::string& desc) {
Tensor in(shape);
Tensor adj(shape, 1);
Tensor grad(shape);
uniform(-5, 5)(in);
std::cout << desc << ": " << std::flush;
boost::timer::cpu_timer timer;
for(int i = 0; i < l; ++i) {
f(grad, adj, in);
if(i % 100 == 0)
std::cout << "." << std::flush;
}
std::cout << timer.format(5, "%ws") << std::endl;
}
int main() {
int l = 1000;
std::vector<Shape> shapes = {
{1000, 1000},
{80, 50000},
{50000, 80},
};
for(auto& shape : shapes) {
std::cout << "Testing shape: " << shape[0] << "x" << shape[1] << std::endl << std::endl;
std::cout << "Softmax forward" << std::endl;
testForward(CudnnSoftmax, l, shape, "CuDNN ");
testForward(Softmax, l, shape, "Marian");
std::cout << std::endl;
std::cout << "Softmax backward" << std::endl;
testBackward(CudnnSoftmaxGrad, l, shape, "CuDNN ");
testBackward(SoftmaxGrad, l, shape, "Marian");
std::cout << std::endl;
std::cout << "Log-softmax backward" << std::endl;
testBackward(CudnnLogSoftmaxGrad, l, shape, "CuDNN ");
testBackward(LogSoftmaxGrad, l, shape, "Marian");
std::cout << std::endl;
}
return 0;
}
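Because testForward and testBackward are templated on the callable, any kernel with a matching Tensor signature can be timed by adding one line inside the loop over shapes, either passed directly or wrapped in a lambda. A hypothetical extra entry (not part of this commit), reusing l and shape from main:

// Hypothetical additional timings inside the shape loop.
testForward(SubtractMax, l, shape, "SubtractMax ");
testForward([](Tensor out, Tensor in) { CudnnLogSoftmax(out, in); },
            l, shape, "CuDNN log-softmax forward");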

View File

@ -99,18 +99,20 @@ void CudnnLogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
cudaDeviceSynchronize();
}
__global__ void gSubtractMax(float* out, size_t rows, size_t cols) {
__global__ void gSubtractMax(float* out, const float* in,
size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if (j < rows) {
extern __shared__ float _share[];
float* _max = _share + blockDim.x;
float* sp = out + j * cols;
_max[threadIdx.x] = sp[threadIdx.x];
const float* inRow = in + j * cols;
float* outRow = out + j * cols;
_max[threadIdx.x] = inRow[threadIdx.x];
for(int tid = 1; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if (id < cols) {
if (sp[id] > _max[threadIdx.x]) _max[threadIdx.x] = sp[id];
if (inRow[id] > _max[threadIdx.x]) _max[threadIdx.x] = inRow[id];
}
}
__syncthreads();
@ -129,23 +131,24 @@ __global__ void gSubtractMax(float* out, size_t rows, size_t cols) {
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
sp[id] -= _max[0];
outRow[id] = inRow[id] - _max[0];
}
}
}
}
void SubtractMax(Tensor* Out) {
void SubtractMax(Tensor out, Tensor in) {
// Out is a m-by-k matrix, passed as input.
// The max element of each row of Out is computed and subtracted from Out.
// Out is both input and output.
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
size_t m = out.shape()[0];
size_t k = out.shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
gSubtractMax<<<blocks, threads, shared>>>(out.data(),
in.data(), m, k);
cudaStreamSynchronize(0);
}
@ -183,17 +186,18 @@ __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
}
}
void Softmax(Tensor* Out) {
size_t m = Out->shape()[0];
size_t k = Out->shape()[1];
void Softmax(Tensor out, Tensor in) {
size_t m = out.shape()[0];
size_t k = out.shape()[1];
int blocks = std::min(MAX_BLOCKS, (int) m);
int threads = std::min(MAX_THREADS, (int) k);
int shared = sizeof(float) * threads * 2;
// Subtract the max rowwise for numerical stability (safe softmax).
gSubtractMax<<<blocks, threads, shared>>>(Out->data(), m, k);
gSubtractMax<<<blocks, threads, shared>>>(out.data(),
in.data(), m, k);
cudaStreamSynchronize(0);
gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
gSoftMax<<<blocks, threads, shared>>>(out.data(), m, k);
cudaStreamSynchronize(0);
}
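The numerical-stability comment refers to the standard max-subtraction identity: shifting each row by its maximum leaves the softmax unchanged while keeping the exponentials bounded,

\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j,

which is why gSubtractMax writes the shifted row into out and gSoftMax then exponentiates and normalizes out in place.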
@ -267,7 +271,7 @@ __global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val,
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
_sum[threadIdx.x] += expf(valRow[id]) * adjRow[id]; // exp because we cached logsoftmax
_sum[threadIdx.x] += adjRow[id];
}
}
__syncthreads();
@ -283,7 +287,7 @@ __global__ void gLogSoftmaxGrad(float* grad, const float* adj, const float* val,
for(int tid = 0; tid < cols; tid += blockDim.x){
int id = tid + threadIdx.x;
if(id < cols)
gradRow[id] += adjRow[id] - _sum[0];
gradRow[id] += adjRow[id] - (expf(valRow[id]) * _sum[0]);
}
}
}
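For reference, the revised kernel matches the usual log-softmax backward rule: with the cached output p = logsoftmax(x), so that softmax(x)_i = e^{p_i}, and incoming adjoint dy,

\frac{\partial L}{\partial x_i} = dy_i - e^{p_i} \sum_j dy_j,

so the shared-memory reduction now sums the adjoint alone and expf(valRow[id]) rescales that total per element, rather than summing expf(val) * adj and subtracting it uniformly as the previous version did.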

View File

@ -153,9 +153,9 @@ void Element(Functor functor,
void ClipNorm(Tensor out, float threshold);
void SubtractMax(Tensor* Out);
void SubtractMax(Tensor out, Tensor in);
void Softmax(Tensor* Out);
void Softmax(Tensor out, Tensor in);
void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val);
void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val);

View File

@ -29,14 +29,36 @@ int main(int argc, char** argv)
Expr inExpr = g.input(shape={batch_size, input_size});
Expr labelExpr = g.input(shape={batch_size, output_size});
//Expr outExpr = softmax(inExpr);
Expr outExpr = tanh(inExpr);
//Expr outExpr = - inExpr;
Expr ceExpr = cross_entropy(outExpr, labelExpr);
Expr inExpr2 = g.input(shape={batch_size, input_size});
vector<Expr> expr;
expr.emplace_back(inExpr + inExpr2);
expr.emplace_back(inExpr - expr.back());
expr.emplace_back(inExpr * expr.back());
expr.emplace_back(inExpr / expr.back());
expr.emplace_back(reluplus(inExpr, expr.back()));
//expr.emplace_back(dot(inExpr, inExpr3));
expr.emplace_back(tanh(expr.back()));
expr.emplace_back(-expr.back());
expr.emplace_back(logit(expr.back()));
expr.emplace_back(relu(expr.back()));
expr.emplace_back(log(expr.back()));
expr.emplace_back(exp(expr.back()));
expr.emplace_back(dropout(expr.back()));
//expr.emplace_back(softmax_slow(expr.back()));
expr.emplace_back(softmax(expr.back()));
Expr ceExpr = cross_entropy(expr.back(), labelExpr);
Expr cost = mean(ceExpr, axis=0);
std::cout << g.graphviz() << std::endl;
// create data
srand(0);
//srand(0);
srand(time(NULL));
std::vector<float> values(batch_size * input_size);
generate(begin(values), end(values), Rand);
@ -52,13 +74,19 @@ int main(int argc, char** argv)
inExpr = inTensor;
labelExpr = labelTensor;
// for binary expressions
std::vector<float> values2(batch_size * input_size);
generate(begin(values2), end(values2), Rand);
Tensor inTensor2({batch_size, input_size});
thrust::copy(values2.begin(), values2.end(), inTensor2.begin());
inExpr2 = inTensor2;
// train
g.forward(batch_size);
//g.backward();
g.backward_debug(0.001);
std::cout << g.graphviz() << std::endl;
/*
std::cerr << "inTensor=" << inTensor.Debug() << std::endl;

View File

@ -20,6 +20,8 @@
// SOFTWARE.
#include <chrono>
#include <iostream>
#include <fstream>
#include <boost/timer/timer.hpp>
#include "marian.h"
@ -46,7 +48,8 @@ ExpressionGraph build_graph(int source_vocabulary_size,
int num_source_tokens,
int num_target_tokens) {
std::cerr << "Building computation graph..." << std::endl;
boost::timer::cpu_timer timer;
int input_size = source_vocabulary_size;
int output_size = target_vocabulary_size;
int num_inputs = num_source_tokens;
@ -125,8 +128,7 @@ ExpressionGraph build_graph(int source_vocabulary_size,
//auto cost = named(-mean(word_cost, axis=0), "cost");
auto cost = named(mean(word_cost, axis=0), "cost");
std::cerr << "Done." << std::endl;
std::cerr << "Done in " << timer.format(5, "%ws") << std::endl;
return g;
}
@ -199,7 +201,9 @@ int main(int argc, char** argv) {
}
std::cerr << "Printing the computation graph..." << std::endl;
std::cout << graphs[0].graphviz() << std::endl;
std::ofstream viz("encoder_decoder.dot");
viz << graphs[0].graphviz() << std::endl;
viz.close();
std::cerr << "Training..." << std::endl;