diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 365df389..cb121111 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -29,10 +29,15 @@ cuda_add_executable(
   validate_mnist
   validate_mnist.cu
 )
+cuda_add_executable(
+  validate_mnist_batch
+  validate_mnist_batch.cu
+)
 
 target_link_libraries(validate_mnist marian_lib)
+target_link_libraries(validate_mnist_batch marian_lib)
 
-foreach(exec marian train_mnist validate_mnist)
+foreach(exec marian train_mnist validate_mnist validate_mnist_batch)
   target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn)
   cuda_add_cublas_to_target(${exec})
   set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
diff --git a/src/expression_operators.h b/src/expression_operators.h
index 3d42400f..878f9882 100644
--- a/src/expression_operators.h
+++ b/src/expression_operators.h
@@ -94,7 +94,6 @@ Expr broadcast(Shape bShape, Expr a) {
                      "Cannot broadcast tensor dimension "
                      << dimA << " to " << dimB);
       if(dimA == 1 && dimB != 1) {
-        std::cerr << "Broadcasting dim " << i << " from " << dimA << " to " << dimB << std::endl;
         if(i == 0) {
           Expr one = ones(keywords::shape={bShape[0], 1});
           a = dot(one, a);
diff --git a/src/expressions.cu b/src/expressions.cu
index a95b1bef..b2ff90ba 100644
--- a/src/expressions.cu
+++ b/src/expressions.cu
@@ -24,22 +24,17 @@ ChainPtr Expr::node() {
 
 void Expr::forward(size_t batchSize) {
   UTIL_THROW_IF2(pimpl_.get() != Chainable<Tensor>::stack.back(),
-                 "Trying to call forward on non-root of computation graph");
-  std::cerr << "forward:" << std::endl;
-
+                 "Trying to call forward on non-root of computation graph");
   for(auto&& v : Chainable<Tensor>::stack) {
     v->allocate(batchSize);
   }
-
   for(auto&& v : Chainable<Tensor>::stack)
     v->forward();
 }
 
 void Expr::backward() {
   UTIL_THROW_IF2(pimpl_.get() != Chainable<Tensor>::stack.back(),
-                 "Trying to call backward on non-root of computation graph");
-  std::cerr << "backward:" << std::endl;
-
+                 "Trying to call backward on non-root of computation graph");
   for(auto&& v : Chainable<Tensor>::stack)
     v->set_zero_adjoint();
 
@@ -56,7 +51,6 @@ Expr::operator ChainPtr() {
 std::string Expr::Debug() const
 {
   stringstream strm;
-  //const Chainable<Tensor> &ct = *pimpl_;
   const Shape &shape = pimpl_->shape();
   strm << marian::Debug(shape);
   return strm.str();
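Note on the forward/backward hunk above: marian evaluates an expression by walking the global tape Chainable<Tensor>::stack in topological order, first allocating every node and then calling forward() on each; backward() zeroes all adjoints before the reverse-mode traversal. A minimal host-side sketch of that tape pattern in plain C++ (names are illustrative, not the marian API):

#include <functional>
#include <vector>

// Toy reverse-mode tape: nodes are appended in topological order,
// forward() runs front to back, backward() runs back to front.
struct TapeNode {
  std::function<void()> forward;
  std::function<void()> backward;
};

struct Tape {
  std::vector<TapeNode> nodes;
  void forward() {
    for (auto& n : nodes) n.forward();
  }
  void backward() {
    for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) it->backward();
  }
};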
diff --git a/src/graph_operators.h b/src/graph_operators.h
index bf5a3336..f231103b 100644
--- a/src/graph_operators.h
+++ b/src/graph_operators.h
@@ -42,7 +42,8 @@ struct ParamNode : public Node {
   template <typename ...Args>
   ParamNode(Args ...args)
   : Node(args...),
-    init_(Get<std::function<void(Tensor)>>(keywords::init, [](Tensor){ }))
+    init_(Get<std::function<void(Tensor)>>(keywords::init, [](Tensor){ })),
+    initialized_(false)
   {
     UTIL_THROW_IF2(!Has(keywords::shape) &&
                    !Has(keywords::lazy_shape),
@@ -51,14 +52,18 @@ struct ParamNode : public Node {
 
   void forward() {}
   void backward() {}
-
+
   virtual void allocate(size_t batchSize) {
     val_.allocate(shape_);
-    init_(val_);
+    if(!initialized_) {
+      init_(val_);
+      initialized_ = true;
+    }
   }
 
 private:
   std::function<void(Tensor)> init_;
+  bool initialized_;
 };
 
@@ -66,7 +71,9 @@ struct UnaryNodeOp : public Node {
 
   template <typename ...Args>
   UnaryNodeOp(ChainPtr a, Args ...args)
-  : Node(args...), a_(a) {}
+  : Node(keywords::shape=a->shape(), //@TODO: Check keywords?
+         args...),
+    a_(a) {}
 };
 
 struct SigmoidNodeOp : public UnaryNodeOp {
@@ -125,24 +132,25 @@ struct ArgmaxOp : public UnaryNodeOp {
 
   void forward() {
     //val_ = Argmax(a_->val(), axis_);
+    UTIL_THROW2("Not implemented");
   }
 
-  void backward() {}
+  void backward() {
+    UTIL_THROW2("Not implemented");
+  }
 
  private:
   int axis_;
 };
 
-
+// @TODO, make this numerically safe(r):
+// softmax(X) = softmax_safe(X - max(X, axis=1))
+// Probably best to do this directly in Softmax
+// function.
 struct SoftmaxNodeOp : public UnaryNodeOp {
   template <typename ...Args>
   SoftmaxNodeOp(ChainPtr a, Args ...args)
-  : UnaryNodeOp(a, keywords::shape=newShape(a),
-                args...) { }
-
-  Shape newShape(ChainPtr a) {
-    Shape shape = a->shape();
-    return shape;
-  }
+  : UnaryNodeOp(a, args...) { }
 
   void forward() {
     // B = softmax(A).
@@ -164,8 +172,8 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
 
 struct LogNodeOp : public UnaryNodeOp {
   template <typename ...Args>
-  LogNodeOp(Args ...args)
-  : UnaryNodeOp(args...) {}
+  LogNodeOp(ChainPtr a, Args ...args)
+  : UnaryNodeOp(a, args...) {}
 
   void forward() {
     Element(_1 = Log(_2), val_, a_->val());
@@ -180,13 +188,7 @@ struct LogNodeOp : public UnaryNodeOp {
 struct ExpNodeOp : public UnaryNodeOp {
   template <typename ...Args>
   ExpNodeOp(ChainPtr a, Args ...args)
-  : UnaryNodeOp(a, keywords::shape=newShape(a),
-                args...) { }
-
-  Shape newShape(ChainPtr a) {
-    Shape shape = a->shape();
-    return shape;
-  }
+  : UnaryNodeOp(a, args...) { }
 
   void forward() {
     Element(_1 = Exp(_2), val_, a_->val());
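The @TODO above records the standard stability trick: softmax is invariant under subtracting a per-row constant, so softmax(X - max(X, axis=1)) gives the same result while keeping every exponent non-positive. A rough CPU reference of the safe variant (illustrative sketch only; the real work belongs in the CUDA Softmax kernel):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically safe row-wise softmax on a dense rows x cols matrix.
void softmax_safe(std::vector<float>& a, size_t rows, size_t cols) {
  for (size_t r = 0; r < rows; ++r) {
    float* row = a.data() + r * cols;
    float mx = *std::max_element(row, row + cols);
    float sum = 0.0f;
    for (size_t c = 0; c < cols; ++c) {
      row[c] = std::exp(row[c] - mx);  // shifting by the row max cannot overflow
      sum += row[c];
    }
    for (size_t c = 0; c < cols; ++c)
      row[c] /= sum;
  }
}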
diff --git a/src/param_initializers.h b/src/param_initializers.h
index ab781064..5a04a25c 100644
--- a/src/param_initializers.h
+++ b/src/param_initializers.h
@@ -22,7 +22,7 @@ void ones(Tensor t) {
 void randreal(Tensor t) {
   std::random_device device;
   std::default_random_engine engine(device());
-  std::uniform_real_distribution<> dist(0, 1);
+  std::uniform_real_distribution<> dist(0, 0.1);
   auto gen = std::bind(dist, engine);
 
   std::vector<float> vals(t.size());
diff --git a/src/sgd.h b/src/sgd.h
index 298cd358..0dab8df0 100644
--- a/src/sgd.h
+++ b/src/sgd.h
@@ -60,8 +60,8 @@ class SGD {
     std::vector<float> y(yData_.begin() + startId * numClasses_,
                          yData_.begin() + endId * numClasses_);
 
-    xt.Load(x);
-    yt.Load(y);
+    xt.set(x);
+    yt.set(y);
   }
 
   void UpdateModel() {
diff --git a/src/tensor.cu b/src/tensor.cu
index 09355b21..0c3e8a3e 100644
--- a/src/tensor.cu
+++ b/src/tensor.cu
@@ -5,94 +5,26 @@ using namespace std;
 
 namespace marian {
 
-inline std::vector<std::string> Tokenize(const std::string& str,
-                                         const std::string& delimiters = " \t")
-{
-  std::vector<std::string> tokens;
-  // Skip delimiters at beginning.
-  std::string::size_type lastPos = str.find_first_not_of(delimiters, 0);
-  // Find first "non-delimiter".
-  std::string::size_type pos = str.find_first_of(delimiters, lastPos);
-
-  while (std::string::npos != pos || std::string::npos != lastPos) {
-    // Found a token, add it to the vector.
-    tokens.push_back(str.substr(lastPos, pos - lastPos));
-    // Skip delimiters.  Note the "not_of"
-    lastPos = str.find_first_not_of(delimiters, pos);
-    // Find next "non-delimiter"
-    pos = str.find_first_of(delimiters, lastPos);
-  }
-
-  return tokens;
-}
-
-//! convert string to variable of type T. Used to reading floats, int etc from files
-template<typename T>
-T Scan(const std::string &input)
-{
-  std::stringstream stream(input);
-  T ret;
-  stream >> ret;
-  return ret;
-}
-
-//! convert vectors of string to vectors of type T variables
-template<typename T>
-inline std::vector<T> Scan(const std::vector< std::string > &input)
-{
-  std::vector<T> output(input.size());
-  for (size_t i = 0 ; i < input.size() ; i++) {
-    output[i] = Scan<T>( input[i] );
-  }
-  return output;
-}
-
-//! tokenise input string to vector of type T
-template<typename T>
-inline std::vector<T> Tokenize( const std::string &input
-                              , const std::string& delimiters = " \t")
-{
-  std::vector<std::string> stringVector = Tokenize(input, delimiters);
-  return Scan<T>( stringVector );
-}
-
-
-void Tensor::Load(const std::string &path)
-{
-  size_t totSize = GetTotalSize(pimpl_->shape());
-  cerr << "totSize=" << totSize << endl;
-  std::vector<float> hostData(totSize);
-
-  fstream strm;
-  strm.open(path.c_str());
-
-  string line;
-  size_t ind = 0;
-  while ( getline (strm, line) )
-  {
-    cerr << line << '\n';
-    vector<float> toks = Tokenize<float>(line);
-    for (size_t i = 0; i < toks.size(); ++i) {
-      hostData[ind] = toks[i];
-    }
-
-    ++ind;
-  }
-  strm.close();
-
-  Load(hostData.begin(), hostData.begin());
-}
-
-void Tensor::Load(const std::vector<float>& data)
+void Tensor::set(const std::vector<float>& data)
 {
   pimpl_->set(data.begin(), data.end());
 }
 
-void Tensor::Load(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end)
+void Tensor::set(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end)
 {
   pimpl_->set(begin, end);
 }
 
+Tensor& operator<<(Tensor& t, const std::vector<float> &vec) {
+  t.set(vec);
+  return t;
+}
+
+std::vector<float>& operator<<(std::vector<float> &vec, const Tensor& t) {
+  t.get(vec);
+  return vec;
+}
+
+
 }
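The two operator<< overloads introduced above give a symmetric shorthand for host/device transfer: t << vec copies a host vector into a tensor via set(), and vec << t copies device data back out via get(). A self-contained toy analogue of the pattern (plain C++, no GPU; ToyTensor is hypothetical and only mirrors the intended usage):

#include <cassert>
#include <vector>

// Host-only stand-in mirroring the Tensor streaming API.
struct ToyTensor {
  std::vector<float> data;
  void set(const std::vector<float>& v) { data = v; }
  void get(std::vector<float>& v) const { v = data; }
};

ToyTensor& operator<<(ToyTensor& t, const std::vector<float>& v) { t.set(v); return t; }
std::vector<float>& operator<<(std::vector<float>& v, const ToyTensor& t) { t.get(v); return v; }

int main() {
  ToyTensor t;
  std::vector<float> in{1.f, 2.f, 3.f}, out;
  t << in;   // host -> "device"
  out << t;  // "device" -> host
  assert(out == in);
  return 0;
}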
diff --git a/src/tensor.h b/src/tensor.h
index 0f6029d8..b13e55fe 100644
--- a/src/tensor.h
+++ b/src/tensor.h
@@ -12,30 +12,6 @@
 
 namespace marian {
 
-//struct Handles {
-//  //cudnnHandle_t cudnnHandle;
-//  //cublasHandle_t cublasHandle;
-//
-//  //cudnnOpTensorDescriptor_t add;
-//
-//  Handles() {
-//    cudnnCreate(&cudnnHandle);
-//    cublasCreate(&cublasHandle);
-//    cudnnCreateOpTensorDescriptor(&add);
-//    cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN);
-//  }
-//
-//  ~Handles() {
-//    cudnnDestroy(cudnnHandle);
-//    cublasDestroy(cublasHandle);
-//    cudnnDestroyOpTensorDescriptor(add);
-//  }
-//};
-//
-//const Handles handles;
-
-// typedef std::vector<int> Shape;
-
 inline std::string Debug(const Shape &shape)
 {
   std::stringstream strm;
@@ -59,18 +35,9 @@ class TensorImpl {
   private:
     Shape shape_;
     thrust::device_vector<Float> data_;
-    //cudnnTensorDescriptor_t desc_;
     size_t tno_;
     static size_t tensorCounter;
 
-    //cudnnDataType_t dataType() {
-    //  switch(sizeof(Float)) {
-    //    case 2: return CUDNN_DATA_HALF;
-    //    case 8: return CUDNN_DATA_DOUBLE;
-    //    default: return CUDNN_DATA_FLOAT;
-    //  }
-    //}
-
   public:
     typedef Float value_type;
 
@@ -85,34 +52,13 @@ class TensorImpl {
       UTIL_THROW_IF2(shape_.size() < 1 || shape_.size() > 4,
                      "Wrong number of dimensions: " << shape_.size());
 
-      std::cerr << "Allocating : " << shape[0] << " " << shape[1] << std::endl;
-
       int size = GetTotalSize(shape_);
       data_.resize(size, value);
-      //cudnnCreateTensorDescriptor(&desc_);
-      //switch (shape_.size()) {
-      //  case 1:
-      //    cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
-      //                               shape_[0], 1, 1, 1); break;
-      //  case 2:
-      //    cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
-      //                               shape_[0], shape_[1], 1, 1); break;
-      //  case 3:
-      //    cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
-      //                               shape_[0], shape_[1], shape_[2], 1); break;
-      //  case 4:
-      //    cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(),
-      //                               shape_[0], shape_[1], shape_[2], shape_[3]); break;
-      //}
     }
 
     TensorImpl(const TensorImpl&) = delete;
     TensorImpl(TensorImpl&&) = delete;
 
-    ~TensorImpl() {
-      //cudnnDestroyTensorDescriptor(desc_);
-    }
-
     value_type operator[](size_t i) const {
       return data_[i];
     }
@@ -145,10 +91,6 @@ class TensorImpl {
       return thrust::raw_pointer_cast(data_.data());
     }
 
-    //cudnnTensorDescriptor_t desc() const {
-    //  return desc_;
-    //}
-
     size_t id() const {
       return tno_;
     }
@@ -158,12 +100,13 @@ class TensorImpl {
     }
 
     void set(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end) {
-      size_t totSize = GetTotalSize(shape());
-      //std::cerr << "tensor size=" << totSize << " vector size=" << values.size() << std::endl;
-      //assert(totSize == values.size());
       thrust::copy(begin, end, data_.begin());
     }
 
+    void get(std::vector<float>::iterator out) {
+      thrust::copy(data_.begin(), data_.end(), out);
+    }
+
     std::string Debug() const
     {
       std::stringstream strm;
@@ -233,22 +176,18 @@ class Tensor {
       return pimpl_->begin();
     }
 
-    auto end() -> decltype( pimpl_->begin() ) {
-      return pimpl_->begin();
+    auto end() -> decltype( pimpl_->end() ) {
+      return pimpl_->end();
     }
 
-    auto end() const -> decltype( pimpl_->begin() ) {
-      return pimpl_->begin();
+    auto end() const -> decltype( pimpl_->end() ) {
+      return pimpl_->end();
     }
 
     const Shape& shape() const {
       return pimpl_->shape();
     }
 
-    //cudnnTensorDescriptor_t desc() const {
-    //  return pimpl_->desc();
-    //}
-
     void set(value_type value) {
       pimpl_->set(value);
     }
@@ -273,10 +212,21 @@ class Tensor {
       std::cerr << std::endl;
     }
 
-    void Load(const std::string &path);
-    void Load(const std::vector<float>& data);
-    void Load(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end);
+    //void Load(const std::string &path);
+    void set(const std::vector<float>& data);
+    void set(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end);
+
+    void get(std::vector<float>::iterator out) const {
+      pimpl_->get(out);
+    }
+
+    void get(std::vector<float> &vout) const {
+      pimpl_->get(vout.begin());
+    }
 };
 
+Tensor& operator<<(Tensor& t, const std::vector<float> &vec);
+
+std::vector<float>& operator<<(std::vector<float> &vec, const Tensor& t);
+
 }
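The end() fix above is worth underlining: the old overloads returned pimpl_->begin(), so every [begin(), end()) range over a Tensor was empty and range-based algorithms silently did nothing. A tiny reproduction of that failure mode (plain C++, illustrative):

#include <iostream>
#include <numeric>
#include <vector>

// A container whose end() wrongly returns begin(): every range is empty.
struct BadRange {
  std::vector<float> v{1.f, 2.f, 3.f};
  std::vector<float>::iterator begin() { return v.begin(); }
  std::vector<float>::iterator end() { return v.begin(); }  // bug: should be v.end()
};

int main() {
  BadRange r;
  // Accumulating over [begin, begin) touches no elements.
  std::cout << std::accumulate(r.begin(), r.end(), 0.0f) << std::endl;  // prints 0, not 6
  return 0;
}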
diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu
index 023aba8b..2226eb16 100644
--- a/src/validate_mnist.cu
+++ b/src/validate_mnist.cu
@@ -2,6 +2,7 @@
 #include "marian.h"
 #include "mnist.h"
 #include "npz_converter.h"
+#include "param_initializers.h"
 
 using namespace marian;
 using namespace keywords;
@@ -12,80 +13,88 @@ int main(int argc, char** argv) {
 
   const size_t IMAGE_SIZE = 784;
   const size_t LABEL_SIZE = 10;
-  int numofdata;
-
+  int BATCH_SIZE = 10000;
+
   std::cerr << "Loading test set...";
-  std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
-  std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
-  std::cerr << "\tDone." << std::endl;
+  std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE);
+  std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
+  std::cerr << "Done." << std::endl;
 
   std::cerr << "Loading model params...";
   NpzConverter converter("../scripts/test_model/model.npz");
 
-  std::vector<float> wData;
-  Shape wShape;
+  std::vector<float> wData, bData;
+  Shape wShape, bShape;
   converter.Load("weights", wData, wShape);
-
-  std::vector<float> bData;
-  Shape bShape;
   converter.Load("bias", bData, bShape);
-
-  auto initW = [wData](Tensor t) {
-    thrust::copy(wData.begin(), wData.end(), t.begin());
-  };
-
-  auto initB = [bData](Tensor t) {
-    thrust::copy(bData.begin(), bData.end(), t.begin());
-  };
-
-  std::cerr << "\tDone." << std::endl;
-
-
-  Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X");
-  Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y");
-
-  Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=initW);
-  Expr b = param(shape={1, LABEL_SIZE}, name="b0", init=initB);
+  std::cerr << "Done." << std::endl;
 
   std::cerr << "Building model...";
-  auto predict = softmax(dot(x, w) + b,
-                         axis=1, name="pred");
-  auto graph = -mean(sum(y * log(predict), axis=1),
-                     axis=0, name="cost");
-  std::cerr << "\tDone." << std::endl;
+  auto x = input(shape={whatevs, IMAGE_SIZE});
+  auto y = input(shape={whatevs, LABEL_SIZE});
+
+  auto w = param(shape={IMAGE_SIZE, LABEL_SIZE},
+                 init=[wData](Tensor t) { t.set(wData); });
+  auto b = param(shape={1, LABEL_SIZE},
+                 init=[bData](Tensor t) { t.set(bData); });
 
-  Tensor xt({numofdata, IMAGE_SIZE});
-  xt.Load(testImages);
+  auto probs = softmax(dot(x, w) + b, axis=1);
+  auto graph = -mean(sum(y * log(probs), axis=1), axis=0);
 
-  Tensor yt({numofdata, LABEL_SIZE});
-  yt.Load(testLabels);
-
-  x = xt;
-  y = yt;
-
-  graph.forward(numofdata);
-  auto results = predict.val();
-  graph.backward();
-
-  std::cerr << b.grad().Debug() << std::endl;
+  std::cerr << "Done." << std::endl;
 
+  Tensor xt({BATCH_SIZE, IMAGE_SIZE});
+  Tensor yt({BATCH_SIZE, LABEL_SIZE});
+
+  x = xt << testImages;
+  y = yt << testLabels;
+
+  graph.forward(BATCH_SIZE);
+
+  auto results = probs.val();
+  std::vector<float> resultsv(results.size());
+  resultsv << results;
+
   size_t acc = 0;
   for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
     size_t correct = 0;
-    size_t predicted = 0;
+    size_t probsed = 0;
     for (size_t j = 0; j < LABEL_SIZE; ++j) {
       if (testLabels[i+j]) correct = j;
-      if (results[i + j] > results[i + predicted]) predicted = j;
+      if (resultsv[i + j] > resultsv[i + probsed]) probsed = j;
     }
-    acc += (correct == predicted);
-    //std::cerr << "corect: " << correct << " | " << predicted << "(";
-    //for (size_t j = 0; j < LABEL_SIZE; ++j) {
-    //  std::cerr << results[i+j] << " ";
-    //}
-    //std::cerr << std::endl;
+    acc += (correct == probsed);
   }
-  std::cerr << "ACC: " << float(acc)/numofdata << std::endl;
+  std::cerr << "Cost: " << graph.val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
+
+  float eta = 0.1;
+  for (size_t j = 0; j < 10; ++j) {
+    for(size_t i = 0; i < 60; ++i) {
+      graph.backward();
+
+      auto update_rule = _1 -= eta * _2;
+      Element(update_rule, w.val(), w.grad());
+      Element(update_rule, b.val(), b.grad());
+
+      graph.forward(BATCH_SIZE);
+    }
+    std::cerr << "Epoch: " << j << std::endl;
+
+    auto results = probs.val();
+    std::vector<float> resultsv(results.size());
+    resultsv << results;
+
+    size_t acc = 0;
+    for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
+      size_t correct = 0;
+      size_t probsed = 0;
+      for (size_t j = 0; j < LABEL_SIZE; ++j) {
+        if (testLabels[i+j]) correct = j;
+        if (resultsv[i + j] > resultsv[i + probsed]) probsed = j;
+      }
+      acc += (correct == probsed);
+    }
+    std::cerr << "Cost: " << graph.val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
+  }
 
   return 0;
 }
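The training loop added above is plain gradient descent: each step applies w -= eta * grad element-wise on the GPU through Element(_1 -= eta * _2, w.val(), w.grad()). For reference, a host-side sketch of one such update step (illustrative, not the marian kernel):

#include <vector>

// One SGD step on host memory: params[i] -= eta * grads[i].
void sgd_step(std::vector<float>& params, const std::vector<float>& grads, float eta) {
  for (size_t i = 0; i < params.size(); ++i)
    params[i] -= eta * grads[i];
}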
diff --git a/src/validate_mnist_batch.cu b/src/validate_mnist_batch.cu
new file mode 100644
index 00000000..ac4e7359
--- /dev/null
+++ b/src/validate_mnist_batch.cu
@@ -0,0 +1,113 @@
+
+#include "marian.h"
+#include "mnist.h"
+#include "npz_converter.h"
+
+using namespace marian;
+using namespace keywords;
+
+int main(int argc, char** argv) {
+
+  cudaSetDevice(0);
+
+  const size_t IMAGE_SIZE = 784;
+  const size_t LABEL_SIZE = 10;
+  const size_t BATCH_SIZE = 24;
+  int numofdata;
+
+  std::cerr << "Loading test set...";
+  std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
+  std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
+  std::cerr << "\tDone." << std::endl;
+
+  std::cerr << "Loading model params...";
+  NpzConverter converter("../scripts/test_model/model.npz");
+
+  std::vector<float> wData;
+  Shape wShape;
+  converter.Load("weights", wData, wShape);
+
+  std::vector<float> bData;
+  Shape bShape;
+  converter.Load("bias", bData, bShape);
+
+  auto initW = [wData](Tensor t) {
+    t.set(wData);
+  };
+
+  auto initB = [bData](Tensor t) {
+    t.set(bData);
+  };
+
+  std::cerr << "\tDone." << std::endl;
+
+
+  auto x = input(shape={whatevs, IMAGE_SIZE}, name="X");
+  auto y = input(shape={whatevs, LABEL_SIZE}, name="Y");
+
+  auto w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=initW);
+  auto b = param(shape={1, LABEL_SIZE}, name="b0", init=initB);
+
+  std::cerr << "Building model...";
+  auto predict = softmax(dot(x, w) + b, axis=1, name="pred");
+
+  std::cerr << "Done." << std::endl;
+
+  Tensor xt({BATCH_SIZE, IMAGE_SIZE});
+
+  size_t acc = 0;
+  size_t startId = 0;
+  size_t endId = startId + BATCH_SIZE;
+
+  while (endId < numofdata) {
+    std::vector<float> tmp(testImages.begin() + (startId * IMAGE_SIZE),
+                           testImages.begin() + (endId * IMAGE_SIZE));
+    xt << tmp;
+    x = xt;
+
+    predict.forward(BATCH_SIZE);
+
+    std::vector<float> results(LABEL_SIZE * BATCH_SIZE);
+    results << predict.val();
+
+    for (size_t i = 0; i < BATCH_SIZE * LABEL_SIZE; i += LABEL_SIZE) {
+      size_t correct = 0;
+      size_t predicted = 0;
+      for (size_t j = 0; j < LABEL_SIZE; ++j) {
+        if (testLabels[startId * LABEL_SIZE + i + j]) correct = j;
+        if (results[i + j] > results[i + predicted]) predicted = j;
+      }
+      acc += (correct == predicted);
+    }
+
+    startId += BATCH_SIZE;
+    endId += BATCH_SIZE;
+  }
+
+  if (endId != numofdata) {
+    endId = numofdata;
+    if (endId - startId > 0) {
+      std::vector<float> tmp(testImages.begin() + (startId * IMAGE_SIZE),
+                             testImages.begin() + (endId * IMAGE_SIZE));
+      xt << tmp;
+      x = xt;
+
+      predict.forward(endId - startId);
+
+      std::vector<float> results(LABEL_SIZE * BATCH_SIZE);
+      results << predict.val();
+
+      for (size_t i = 0; i < (endId - startId) * LABEL_SIZE; i += LABEL_SIZE) {
+        size_t correct = 0;
+        size_t predicted = 0;
+        for (size_t j = 0; j < LABEL_SIZE; ++j) {
+          if (testLabels[startId * LABEL_SIZE + i + j]) correct = j;
+          if (results[i + j] > results[i + predicted]) predicted = j;
+        }
+        acc += (correct == predicted);
+      }
+    }
+  }
+  std::cerr << "ACC: " << float(acc)/numofdata << std::endl;
+
+  return 0;
+}
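The batch loop in validate_mnist_batch.cu slides a fixed window of BATCH_SIZE examples over the test set and then repeats the evaluation body once more for the leftover examples. The same index arithmetic can be folded into a single loop that clamps the final batch, as in this compact sketch (plain C++, illustrative; forEachBatch is a hypothetical helper):

#include <algorithm>
#include <cstddef>

// Visit consecutive [start, end) ranges of size batchSize; the last
// batch may be smaller and covers the remainder.
template <typename Fn>
void forEachBatch(std::size_t numExamples, std::size_t batchSize, Fn fn) {
  for (std::size_t start = 0; start < numExamples; start += batchSize) {
    std::size_t end = std::min(start + batchSize, numExamples);
    fn(start, end);  // e.g. slice testImages[start * IMAGE_SIZE, end * IMAGE_SIZE)
  }
}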