From c28ba2e67f62aad1d9a3acbe6a15d82ac2bb9792 Mon Sep 17 00:00:00 2001 From: Roman Grundkiewicz Date: Wed, 14 Sep 2016 18:18:58 +0200 Subject: [PATCH 01/11] add param random initializers --- src/marian.h | 1 + src/param_initializers.h | 34 ++++++++++++++++++++++++++++++++++ src/test.cu | 1 + 3 files changed, 36 insertions(+) create mode 100644 src/param_initializers.h diff --git a/src/marian.h b/src/marian.h index 8c987ccf..0876d4cd 100644 --- a/src/marian.h +++ b/src/marian.h @@ -5,4 +5,5 @@ #include "graph_operators.h" #include "expressions.h" #include "expression_operators.h" +#include "param_initializers.h" diff --git a/src/param_initializers.h b/src/param_initializers.h new file mode 100644 index 00000000..ab781064 --- /dev/null +++ b/src/param_initializers.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include +#include + +#include "tensor.h" + +namespace marian { + +void zeros(Tensor t) { + std::vector vals(t.size(), 0.0f); + thrust::copy(vals.begin(), vals.end(), t.begin()); +} + +void ones(Tensor t) { + std::vector vals(t.size(), 1.0f); + thrust::copy(vals.begin(), vals.end(), t.begin()); +} + +void randreal(Tensor t) { + std::random_device device; + std::default_random_engine engine(device()); + std::uniform_real_distribution<> dist(0, 1); + auto gen = std::bind(dist, engine); + + std::vector vals(t.size()); + std::generate(begin(vals), end(vals), gen); + + thrust::copy(vals.begin(), vals.end(), t.begin()); +} + +} // namespace marian diff --git a/src/test.cu b/src/test.cu index 0285e3a5..a86c60ee 100644 --- a/src/test.cu +++ b/src/test.cu @@ -20,6 +20,7 @@ int main(int argc, char** argv) { Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y"); Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0"); + // Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=randreal); Expr b = param(shape={1, LABEL_SIZE}, name="b0"); Expr z = dot(x, w) + b; From ab1300291e740c0850521b91fef6a09a8f43276c Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Wed, 14 Sep 2016 18:40:39 +0200 Subject: [PATCH 02/11] compile error --- src/CMakeLists.txt | 15 ++++++++------- src/npz_converter.cpp | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2aa3fdc9..11b36053 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -2,34 +2,35 @@ include_directories(.) add_library(libcommon OBJECT - exception.cpp cnpy/cnpy.cpp + exception.cpp + #npz_converter.cpp ) cuda_add_executable( marian - test.cu expressions.cu - tensor_operators.cu + test.cu tensor.cu + tensor_operators.cu $ ) cuda_add_executable( train_mnist - train_mnist.cu expressions.cu - tensor_operators.cu + train_mnist.cu tensor.cu + tensor_operators.cu $ ) cuda_add_executable( validate_mnist - validate_mnist.cu expressions.cu - tensor_operators.cu + validate_mnist.cu tensor.cu + tensor_operators.cu $ ) diff --git a/src/npz_converter.cpp b/src/npz_converter.cpp index 1ecbc11c..771ff11c 100644 --- a/src/npz_converter.cpp +++ b/src/npz_converter.cpp @@ -1,4 +1,4 @@ -#include "common/npz_converter.h" +#include "npz_converter.h" From ea04f8a6baf692520aeed5de56dd1cefd18df712 Mon Sep 17 00:00:00 2001 From: Maximiliana Behnke Date: Wed, 14 Sep 2016 18:56:13 +0200 Subject: [PATCH 03/11] Modify single layer training script, add 2-layer training script --- scripts/train_test_model_multi.py | 72 +++++++++++++++++++ ...st_model.py => train_test_model_single.py} | 8 +-- 2 files changed, 76 insertions(+), 4 deletions(-) create mode 100755 scripts/train_test_model_multi.py rename scripts/{train_test_model.py => train_test_model_single.py} (91%) diff --git a/scripts/train_test_model_multi.py b/scripts/train_test_model_multi.py new file mode 100755 index 00000000..67ae0131 --- /dev/null +++ b/scripts/train_test_model_multi.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python + +import sys +import os +import numpy as np +from keras.datasets import mnist +from keras.utils import np_utils +from keras.models import Sequential +from keras.layers import Dense +from keras.layers import Dropout + +def softmax(x): + return np.exp(x) / np.sum(np.exp(x), axis=1)[:, None] + + +def baseline_model(pixels_count, classes_count): + model = Sequential() + model.add(Dense(100, input_dim=pixels_count, init='normal', activation='tanh')) + model.add(Dense(classes_count, input_dim=100, init='normal', activation='softmax')) + model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) + return model + + +if __name__ == "__main__": + ### Load trainset from mnist + + (X_train, y_train), (X_test, y_test) = mnist.load_data() + + ### Flatten pictures into vectors + + pixels_count = X_train.shape[1] * X_train.shape[2] + X_train = X_train.reshape(X_train.shape[0], pixels_count).astype('float32') + print "X shape: ", X_train.shape + + X_test = X_test.reshape(X_test.shape[0], pixels_count).astype('float32') + + ### Normalize data to (0, 1) + + X_train = X_train / 255 + X_test = X_test / 255 + + ### Change classes to one hot encoding matrixes + + y_train = np_utils.to_categorical(y_train) + classes_count = y_train.shape[1] + print "Y shape: ", y_train.shape + + y_test = np_utils.to_categorical(y_test) + + # Train weight matrix + + # Build the model + model = baseline_model(pixels_count, classes_count) + # Fit the model + model.fit(X_train, y_train, validation_data=(X_test, y_test), nb_epoch=10, batch_size=200, verbose=2) + # Final evaluation of the model + scores = model.evaluate(X_test, y_test, verbose=0) + print("Baseline Error: %.2f%%" % (100-scores[1]*100)) + + ### Weight and bias matrixes - we extract them from the model + + # weights_ones = np.ones((pixels_count, classes_count)) + # print weights_ones.shape + + weights1, bias1, weights2, bias2 = model.get_weights() + ### Save model to npz files + if not os.path.exists("test_model_multi"): + os.makedirs("test_model_multi") + # np.savez("test_model_multi/model", *model) + np.savez("test_model_multi/model", weights1 = weights1, bias1 = bias1, weights2 = weights2, bias2 = bias2) + + print "Model saved! Check test_model_multi directory" diff --git a/scripts/train_test_model.py b/scripts/train_test_model_single.py similarity index 91% rename from scripts/train_test_model.py rename to scripts/train_test_model_single.py index 4f3236a9..f3a769b4 100755 --- a/scripts/train_test_model.py +++ b/scripts/train_test_model_single.py @@ -84,8 +84,8 @@ if __name__ == "__main__": # print np.count_nonzero(lr)i ### Save model to npz files - if not os.path.exists("test_model"): - os.makedirs("test_model") - np.savez("test_model/model", weights = weights, bias = bias) + if not os.path.exists("test_model_single"): + os.makedirs("test_model_single") + np.savez("test_model_single/model", weights = weights, bias = bias) - print "Model saved! Check test_model directory" + print "Model saved! Check test_model_single directory" From e7ea44c689d39075d50d64c912b910ad4f81c756 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Wed, 14 Sep 2016 19:06:50 +0200 Subject: [PATCH 04/11] compile error --- src/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 11b36053..9d6e8bf4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -4,7 +4,6 @@ include_directories(.) add_library(libcommon OBJECT cnpy/cnpy.cpp exception.cpp - #npz_converter.cpp ) cuda_add_executable( From 823a3a624a240aedfc3b8fc765e228c40e81a9d8 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Wed, 14 Sep 2016 19:51:26 +0200 Subject: [PATCH 05/11] found not-working conf --- src/tensor.h | 103 ++++++++++---------- src/tensor_operators.cu | 6 +- src/test.cu | 207 +++++++++++++++++++--------------------- src/validate_mnist.cu | 44 ++++++--- 4 files changed, 182 insertions(+), 178 deletions(-) diff --git a/src/tensor.h b/src/tensor.h index b9c81a91..0f6029d8 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -13,27 +12,27 @@ namespace marian { -struct Handles { - cudnnHandle_t cudnnHandle; - cublasHandle_t cublasHandle; - - cudnnOpTensorDescriptor_t add; - - Handles() { - cudnnCreate(&cudnnHandle); - cublasCreate(&cublasHandle); - cudnnCreateOpTensorDescriptor(&add); - cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN); - } - - ~Handles() { - cudnnDestroy(cudnnHandle); - cublasDestroy(cublasHandle); - cudnnDestroyOpTensorDescriptor(add); - } -}; - -const Handles handles; +//struct Handles { +// //cudnnHandle_t cudnnHandle; +// //cublasHandle_t cublasHandle; +// +// //cudnnOpTensorDescriptor_t add; +// +// Handles() { +// cudnnCreate(&cudnnHandle); +// cublasCreate(&cublasHandle); +// cudnnCreateOpTensorDescriptor(&add); +// cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN); +// } +// +// ~Handles() { +// cudnnDestroy(cudnnHandle); +// cublasDestroy(cublasHandle); +// cudnnDestroyOpTensorDescriptor(add); +// } +//}; +// +//const Handles handles; // typedef std::vector Shape; @@ -60,17 +59,17 @@ class TensorImpl { private: Shape shape_; thrust::device_vector data_; - cudnnTensorDescriptor_t desc_; + //cudnnTensorDescriptor_t desc_; size_t tno_; static size_t tensorCounter; - cudnnDataType_t dataType() { - switch(sizeof(Float)) { - case 2: return CUDNN_DATA_HALF; - case 8: return CUDNN_DATA_DOUBLE; - default: return CUDNN_DATA_FLOAT; - } - } + //cudnnDataType_t dataType() { + // switch(sizeof(Float)) { + // case 2: return CUDNN_DATA_HALF; + // case 8: return CUDNN_DATA_DOUBLE; + // default: return CUDNN_DATA_FLOAT; + // } + //} public: typedef Float value_type; @@ -90,28 +89,28 @@ class TensorImpl { int size = GetTotalSize(shape_); data_.resize(size, value); - cudnnCreateTensorDescriptor(&desc_); - switch (shape_.size()) { - case 1: - cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), - shape_[0], 1, 1, 1); break; - case 2: - cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), - shape_[0], shape_[1], 1, 1); break; - case 3: - cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), - shape_[0], shape_[1], shape_[2], 1); break; - case 4: - cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), - shape_[0], shape_[1], shape_[2], shape_[3]); break; - } + //cudnnCreateTensorDescriptor(&desc_); + //switch (shape_.size()) { + // case 1: + // cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), + // shape_[0], 1, 1, 1); break; + // case 2: + // cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), + // shape_[0], shape_[1], 1, 1); break; + // case 3: + // cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), + // shape_[0], shape_[1], shape_[2], 1); break; + // case 4: + // cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), + // shape_[0], shape_[1], shape_[2], shape_[3]); break; + //} } TensorImpl(const TensorImpl&) = delete; TensorImpl(TensorImpl&&) = delete; ~TensorImpl() { - cudnnDestroyTensorDescriptor(desc_); + //cudnnDestroyTensorDescriptor(desc_); } value_type operator[](size_t i) const { @@ -146,9 +145,9 @@ class TensorImpl { return thrust::raw_pointer_cast(data_.data()); } - cudnnTensorDescriptor_t desc() const { - return desc_; - } + //cudnnTensorDescriptor_t desc() const { + // return desc_; + //} size_t id() const { return tno_; @@ -246,9 +245,9 @@ class Tensor { return pimpl_->shape(); } - cudnnTensorDescriptor_t desc() const { - return pimpl_->desc(); - } + //cudnnTensorDescriptor_t desc() const { + // return pimpl_->desc(); + //} void set(value_type value) { pimpl_->set(value); diff --git a/src/tensor_operators.cu b/src/tensor_operators.cu index 2aa96331..e9e09ee6 100644 --- a/src/tensor_operators.cu +++ b/src/tensor_operators.cu @@ -130,7 +130,11 @@ Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B, Tensor Prod(Tensor C, const Tensor A, const Tensor B, bool transA, bool transB, Float beta) { - return Prod(handles.cublasHandle, C, A, B, transA, transB, beta); + cublasHandle_t cublasHandle; + cublasCreate(&cublasHandle); + Tensor temp = Prod(cublasHandle, C, A, B, transA, transB, beta); + cublasDestroy(cublasHandle); + return temp; } } \ No newline at end of file diff --git a/src/test.cu b/src/test.cu index 0285e3a5..629c1bc2 100644 --- a/src/test.cu +++ b/src/test.cu @@ -2,9 +2,9 @@ #include "marian.h" #include "mnist.h" -using namespace std; - int main(int argc, char** argv) { + cudaSetDevice(0); + /*int numImg = 0;*/ /*auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg);*/ /*auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numImg);*/ @@ -12,117 +12,104 @@ int main(int argc, char** argv) { using namespace marian; using namespace keywords; - const size_t BATCH_SIZE = 500; - const size_t IMAGE_SIZE = 784; - const size_t LABEL_SIZE = 10; - - Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X"); - Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y"); + Expr x = input(shape={1, 2}); + Expr y = input(shape={1, 2}); - Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0"); - Expr b = param(shape={1, LABEL_SIZE}, name="b0"); - - Expr z = dot(x, w) + b; - Expr lr = softmax(z, axis=1, name="pred"); - Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost"); - //cerr << "x=" << Debug(lr.val().shape()) << endl; + Expr w = param(shape={2, 2}, name="W0"); + //Expr b = param(shape={1, 2}, name="b0"); - int numofdata; - //vector images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE); - //vector labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE); - vector images = datasets::mnist::ReadImages("../examples/mnist/train-images-idx3-ubyte", numofdata, IMAGE_SIZE); - vector labels = datasets::mnist::ReadLabels("../examples/mnist/train-labels-idx1-ubyte", numofdata, LABEL_SIZE); - cerr << "images=" << images.size() << " labels=" << labels.size() << endl; - cerr << "numofdata=" << numofdata << endl; - - size_t startInd = 0; - size_t startIndData = 0; - while (startInd < numofdata) { - size_t batchSize = (startInd + BATCH_SIZE < numofdata) ? BATCH_SIZE : numofdata - startInd; - cerr << "startInd=" << startInd - << " startIndData=" << startIndData - << " batchSize=" << batchSize << endl; - - Tensor tx({numofdata, IMAGE_SIZE}, 1); - Tensor ty({numofdata, LABEL_SIZE}, 1); - - tx.Load(images.begin() + startIndData, images.begin() + startIndData + batchSize * IMAGE_SIZE); - ty.Load(labels.begin() + startInd, labels.begin() + startInd + batchSize); - - //cerr << "tx=" << Debug(tx.shape()) << endl; - //cerr << "ty=" << Debug(ty.shape()) << endl; - - x = tx; - y = ty; - - cerr << "x=" << Debug(x.val().shape()) << endl; - cerr << "y=" << Debug(y.val().shape()) << endl; - - - graph.forward(batchSize); - - cerr << "w=" << Debug(w.val().shape()) << endl; - cerr << "b=" << Debug(b.val().shape()) << endl; - std::cerr << "z: " << Debug(z.val().shape()) << endl; - std::cerr << "lr: " << Debug(lr.val().shape()) << endl; - std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ; - - //std::cerr << "scores=" << scores.val().Debug() << endl; - //std::cerr << "lr=" << lr.val().Debug() << endl; - - graph.backward(); - - //std::cerr << graph["pred"].val()[0] << std::endl; - - startInd += batchSize; - startIndData += batchSize * IMAGE_SIZE; - } + std::cerr << "Building model..."; + auto predict = softmax(dot(x, w), + axis=1, name="pred"); + auto graph = -mean(sum(y * log(predict), axis=1), + axis=0, name="cost"); - - // XOR - /* - Expr x = input(shape={whatevs, 2}, name="X"); - Expr y = input(shape={whatevs, 2}, name="Y"); - - Expr w = param(shape={2, 1}, name="W0"); - Expr b = param(shape={1, 1}, name="b0"); - - Expr n5 = dot(x, w); - Expr n6 = n5 + b; - Expr lr = softmax(n6, axis=1, name="pred"); - cerr << "lr=" << lr.Debug() << endl; - - Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost"); - - Tensor tx({4, 2}, 1); - Tensor ty({4, 1}, 1); - cerr << "tx=" << tx.Debug() << endl; - cerr << "ty=" << ty.Debug() << endl; - - tx.Load("../examples/xor/train.txt"); - ty.Load("../examples/xor/label.txt"); - */ - -#if 0 - hook0(graph); - graph.autodiff(); - std::cerr << graph["cost"].val()[0] << std::endl; - //hook1(graph); - for(auto p : graph.params()) { - auto update = _1 = _1 - alpha * _2; - Element(update, p.val(), p.grad()); - } - hook2(graph); + Tensor x1t({1, 2}); + std::vector xv = { 0.6, 0.1 }; + thrust::copy(xv.begin(), xv.end(), x1t.begin()); + + Tensor x2t({1, 2}); + std::vector yv = { 0, 1 }; + thrust::copy(yv.begin(), yv.end(), x2t.begin()); + + x = x1t; + y = x2t; + + graph.forward(1); + graph.backward(); + + std::cerr << graph.val().Debug() << std::endl; + std::cerr << w.grad().Debug() << std::endl; + //std::cerr << b.grad().Debug() << std::endl; + +// using namespace marian; +// using namespace keywords; +// +// const size_t BATCH_SIZE = 500; +// const size_t IMAGE_SIZE = 784; +// const size_t LABEL_SIZE = 10; +// +// Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X"); +// Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y"); +// +// Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0"); +// Expr b = param(shape={1, LABEL_SIZE}, name="b0"); +// +// Expr z = dot(x, w) + b; +// Expr lr = softmax(z, axis=1, name="pred"); +// Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost"); +// //cerr << "x=" << Debug(lr.val().shape()) << endl; +// +// int numofdata; +// //vector images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE); +// //vector labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE); +// vector images = datasets::mnist::ReadImages("../examples/mnist/train-images-idx3-ubyte", numofdata, IMAGE_SIZE); +// vector labels = datasets::mnist::ReadLabels("../examples/mnist/train-labels-idx1-ubyte", numofdata, LABEL_SIZE); +// cerr << "images=" << images.size() << " labels=" << labels.size() << endl; +// cerr << "numofdata=" << numofdata << endl; +// +// size_t startInd = 0; +// size_t startIndData = 0; +// while (startInd < numofdata) { +// size_t batchSize = (startInd + BATCH_SIZE < numofdata) ? BATCH_SIZE : numofdata - startInd; +// cerr << "startInd=" << startInd +// << " startIndData=" << startIndData +// << " batchSize=" << batchSize << endl; +// +// Tensor tx({numofdata, IMAGE_SIZE}, 1); +// Tensor ty({numofdata, LABEL_SIZE}, 1); +// +// tx.Load(images.begin() + startIndData, images.begin() + startIndData + batchSize * IMAGE_SIZE); +// ty.Load(labels.begin() + startInd, labels.begin() + startInd + batchSize); +// +// //cerr << "tx=" << Debug(tx.shape()) << endl; +// //cerr << "ty=" << Debug(ty.shape()) << endl; +// +// x = tx; +// y = ty; +// +// cerr << "x=" << Debug(x.val().shape()) << endl; +// cerr << "y=" << Debug(y.val().shape()) << endl; +// +// +// graph.forward(batchSize); +// +// cerr << "w=" << Debug(w.val().shape()) << endl; +// cerr << "b=" << Debug(b.val().shape()) << endl; +// std::cerr << "z: " << Debug(z.val().shape()) << endl; +// std::cerr << "lr: " << Debug(lr.val().shape()) << endl; +// std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ; +// +// //std::cerr << "scores=" << scores.val().Debug() << endl; +// //std::cerr << "lr=" << lr.val().Debug() << endl; +// +// //graph.backward(); +// +// //std::cerr << graph["pred"].val()[0] << std::endl; +// +// startInd += batchSize; +// startIndData += batchSize * IMAGE_SIZE; +// } - auto opt = adadelta(cost_function=cost, - eta=0.9, gamma=0.1, - set_batch=set, - before_update=before, - after_update=after, - set_valid=valid, - validation_freq=100, - verbose=1, epochs=3, early_stopping=10); - opt.run(); -#endif return 0; } diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index a42fa881..023aba8b 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -7,13 +7,16 @@ using namespace marian; using namespace keywords; int main(int argc, char** argv) { + + cudaSetDevice(0); + const size_t IMAGE_SIZE = 784; const size_t LABEL_SIZE = 10; int numofdata; std::cerr << "Loading test set..."; std::vector testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE); - std::vectortestLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE); + std::vector testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE); std::cerr << "\tDone." << std::endl; std::cerr << "Loading model params..."; @@ -27,11 +30,11 @@ int main(int argc, char** argv) { Shape bShape; converter.Load("bias", bData, bShape); - auto initW = [&wData](Tensor t) { + auto initW = [wData](Tensor t) { thrust::copy(wData.begin(), wData.end(), t.begin()); }; - auto initB = [&bData](Tensor t) { + auto initB = [bData](Tensor t) { thrust::copy(bData.begin(), bData.end(), t.begin()); }; @@ -39,24 +42,35 @@ int main(int argc, char** argv) { Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X"); - + Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y"); + Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=initW); Expr b = param(shape={1, LABEL_SIZE}, name="b0", init=initB); std::cerr << "Building model..."; - auto scores = dot(x, w) + b; - auto predict = softmax(scores, axis=1, name="pred"); + auto predict = softmax(dot(x, w) + b, + axis=1, name="pred"); + auto graph = -mean(sum(y * log(predict), axis=1), + axis=0, name="cost"); + std::cerr << "\tDone." << std::endl; Tensor xt({numofdata, IMAGE_SIZE}); xt.Load(testImages); - - predict.forward(numofdata); - + + Tensor yt({numofdata, LABEL_SIZE}); + yt.Load(testLabels); + + x = xt; + y = yt; + + graph.forward(numofdata); auto results = predict.val(); + graph.backward(); + + std::cerr << b.grad().Debug() << std::endl; size_t acc = 0; - for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) { size_t correct = 0; size_t predicted = 0; @@ -65,11 +79,11 @@ int main(int argc, char** argv) { if (results[i + j] > results[i + predicted]) predicted = j; } acc += (correct == predicted); - std::cerr << "corect: " << correct << " | " << predicted << "("; - for (size_t j = 0; j < LABEL_SIZE; ++j) { - std::cerr << results[i+j] << " "; - } - std::cerr << std::endl; + //std::cerr << "corect: " << correct << " | " << predicted << "("; + //for (size_t j = 0; j < LABEL_SIZE; ++j) { + // std::cerr << results[i+j] << " "; + //} + //std::cerr << std::endl; } std::cerr << "ACC: " << float(acc)/numofdata << std::endl; From 41bab63f70df6e0708fb29c75a1762995a9d4291 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Wed, 14 Sep 2016 19:56:08 +0200 Subject: [PATCH 06/11] cmake. Compiles -j without errors --- src/CMakeLists.txt | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9d6e8bf4..306b15c3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -6,33 +6,39 @@ add_library(libcommon OBJECT exception.cpp ) +cuda_add_library(common_cuda + expressions.cu + test.cu + tensor.cu + tensor_operators.cu +) + +target_link_libraries(common_cuda) + cuda_add_executable( marian - expressions.cu - test.cu - tensor.cu tensor_operators.cu $ ) +target_link_libraries(marian common_cuda) + cuda_add_executable( train_mnist - expressions.cu train_mnist.cu - tensor.cu - tensor_operators.cu $ ) +target_link_libraries(train_mnist common_cuda) + cuda_add_executable( validate_mnist - expressions.cu validate_mnist.cu - tensor.cu - tensor_operators.cu $ ) +target_link_libraries(validate_mnist common_cuda) + foreach(exec marian train_mnist validate_mnist) target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn) cuda_add_cublas_to_target(${exec}) From ab8bef8a03d075493dd7b72f5d250fed38ca2edc Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Wed, 14 Sep 2016 22:37:27 +0200 Subject: [PATCH 07/11] create marian lib --- src/CMakeLists.txt | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 306b15c3..d837834d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,43 +1,37 @@ include_directories(.) -add_library(libcommon OBJECT +cuda_add_library(marian_lib cnpy/cnpy.cpp exception.cpp -) - -cuda_add_library(common_cuda expressions.cu test.cu tensor.cu tensor_operators.cu ) -target_link_libraries(common_cuda) +target_link_libraries(marian_lib) cuda_add_executable( marian tensor_operators.cu - $ ) -target_link_libraries(marian common_cuda) +target_link_libraries(marian marian_lib) cuda_add_executable( train_mnist train_mnist.cu - $ ) -target_link_libraries(train_mnist common_cuda) +target_link_libraries(train_mnist marian_lib) cuda_add_executable( validate_mnist validate_mnist.cu - $ ) -target_link_libraries(validate_mnist common_cuda) +target_link_libraries(validate_mnist marian_lib) foreach(exec marian train_mnist validate_mnist) target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn) From 9d4bc1d1fdc9066ec21a75426be4b1cb9befc0e8 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Wed, 14 Sep 2016 22:42:51 +0200 Subject: [PATCH 08/11] create marian lib --- src/CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d837834d..365df389 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,7 +5,6 @@ cuda_add_library(marian_lib cnpy/cnpy.cpp exception.cpp expressions.cu - test.cu tensor.cu tensor_operators.cu ) @@ -14,7 +13,7 @@ target_link_libraries(marian_lib) cuda_add_executable( marian - tensor_operators.cu + test.cu ) target_link_libraries(marian marian_lib) From 5b9e6c05c5357f1ffd67e39acab81c453bf1a3f0 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Wed, 14 Sep 2016 23:00:02 +0200 Subject: [PATCH 09/11] eclipse --- marian/.cproject | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/marian/.cproject b/marian/.cproject index 195ef668..2d8c666b 100644 --- a/marian/.cproject +++ b/marian/.cproject @@ -56,11 +56,11 @@ - - + + - + From 0be1b07308f70af085f55b8a3a56a74209185bb7 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Wed, 14 Sep 2016 23:17:53 +0200 Subject: [PATCH 10/11] faster set/get --- src/sgd.h | 4 +-- src/tensor.cu | 83 ++----------------------------------------- src/tensor.h | 80 ++++++++--------------------------------- src/validate_mnist.cu | 29 ++++++++------- 4 files changed, 34 insertions(+), 162 deletions(-) diff --git a/src/sgd.h b/src/sgd.h index 298cd358..0dab8df0 100644 --- a/src/sgd.h +++ b/src/sgd.h @@ -60,8 +60,8 @@ class SGD { std::vector y(yData_.begin() + startId * numClasses_, yData_.begin() + endId * numClasses_); - xt.Load(x); - yt.Load(y); + xt.set(x); + yt.set(y); } void UpdateModel() { diff --git a/src/tensor.cu b/src/tensor.cu index 09355b21..fea21926 100644 --- a/src/tensor.cu +++ b/src/tensor.cu @@ -5,91 +5,12 @@ using namespace std; namespace marian { -inline std::vector Tokenize(const std::string& str, - const std::string& delimiters = " \t") -{ - std::vector tokens; - // Skip delimiters at beginning. - std::string::size_type lastPos = str.find_first_not_of(delimiters, 0); - // Find first "non-delimiter". - std::string::size_type pos = str.find_first_of(delimiters, lastPos); - - while (std::string::npos != pos || std::string::npos != lastPos) { - // Found a token, add it to the vector. - tokens.push_back(str.substr(lastPos, pos - lastPos)); - // Skip delimiters. Note the "not_of" - lastPos = str.find_first_not_of(delimiters, pos); - // Find next "non-delimiter" - pos = str.find_first_of(delimiters, lastPos); - } - - return tokens; -} - -//! convert string to variable of type T. Used to reading floats, int etc from files -template -T Scan(const std::string &input) -{ - std::stringstream stream(input); - T ret; - stream >> ret; - return ret; -} - -//! convert vectors of string to vectors of type T variables -template -inline std::vector Scan(const std::vector< std::string > &input) -{ - std::vector output(input.size()); - for (size_t i = 0 ; i < input.size() ; i++) { - output[i] = Scan( input[i] ); - } - return output; -} - -//! tokenise input string to vector of type T -template -inline std::vector Tokenize( const std::string &input - , const std::string& delimiters = " \t") -{ - std::vector stringVector = Tokenize(input, delimiters); - return Scan( stringVector ); -} - - -void Tensor::Load(const std::string &path) -{ - size_t totSize = GetTotalSize(pimpl_->shape()); - cerr << "totSize=" << totSize << endl; - std::vector hostData(totSize); - - fstream strm; - strm.open(path.c_str()); - - string line; - size_t ind = 0; - while ( getline (strm, line) ) - { - cerr << line << '\n'; - vector toks = Tokenize(line); - for (size_t i = 0; i < toks.size(); ++i) { - hostData[ind] = toks[i]; - } - - ++ind; - } - strm.close(); - - Load(hostData.begin(), hostData.begin()); -} - -void Tensor::Load(const std::vector& data) +void Tensor::set(const std::vector& data) { pimpl_->set(data.begin(), data.end()); } - -void Tensor::Load(const std::vector::const_iterator &begin, const std::vector::const_iterator &end) +void Tensor::set(const std::vector::const_iterator &begin, const std::vector::const_iterator &end) { pimpl_->set(begin, end); } diff --git a/src/tensor.h b/src/tensor.h index 0f6029d8..af9069de 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -12,30 +12,6 @@ namespace marian { -//struct Handles { -// //cudnnHandle_t cudnnHandle; -// //cublasHandle_t cublasHandle; -// -// //cudnnOpTensorDescriptor_t add; -// -// Handles() { -// cudnnCreate(&cudnnHandle); -// cublasCreate(&cublasHandle); -// cudnnCreateOpTensorDescriptor(&add); -// cudnnSetOpTensorDescriptor(add, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN); -// } -// -// ~Handles() { -// cudnnDestroy(cudnnHandle); -// cublasDestroy(cublasHandle); -// cudnnDestroyOpTensorDescriptor(add); -// } -//}; -// -//const Handles handles; - -// typedef std::vector Shape; - inline std::string Debug(const Shape &shape) { std::stringstream strm; @@ -59,18 +35,9 @@ class TensorImpl { private: Shape shape_; thrust::device_vector data_; - //cudnnTensorDescriptor_t desc_; size_t tno_; static size_t tensorCounter; - //cudnnDataType_t dataType() { - // switch(sizeof(Float)) { - // case 2: return CUDNN_DATA_HALF; - // case 8: return CUDNN_DATA_DOUBLE; - // default: return CUDNN_DATA_FLOAT; - // } - //} - public: typedef Float value_type; @@ -89,30 +56,11 @@ class TensorImpl { int size = GetTotalSize(shape_); data_.resize(size, value); - //cudnnCreateTensorDescriptor(&desc_); - //switch (shape_.size()) { - // case 1: - // cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), - // shape_[0], 1, 1, 1); break; - // case 2: - // cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), - // shape_[0], shape_[1], 1, 1); break; - // case 3: - // cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), - // shape_[0], shape_[1], shape_[2], 1); break; - // case 4: - // cudnnSetTensor4dDescriptor(desc_, CUDNN_TENSOR_NCHW, dataType(), - // shape_[0], shape_[1], shape_[2], shape_[3]); break; - //} } TensorImpl(const TensorImpl&) = delete; TensorImpl(TensorImpl&&) = delete; - ~TensorImpl() { - //cudnnDestroyTensorDescriptor(desc_); - } - value_type operator[](size_t i) const { return data_[i]; } @@ -145,10 +93,6 @@ class TensorImpl { return thrust::raw_pointer_cast(data_.data()); } - //cudnnTensorDescriptor_t desc() const { - // return desc_; - //} - size_t id() const { return tno_; } @@ -158,12 +102,13 @@ class TensorImpl { } void set(const std::vector::const_iterator &begin, const std::vector::const_iterator &end) { - size_t totSize = GetTotalSize(shape()); - //std::cerr << "tensor size=" << totSize << " vector size=" << values.size() << std::endl; - //assert(totSize == values.size()); thrust::copy(begin, end, data_.begin()); } + void get(std::vector::iterator out) { + thrust::copy(data_.begin(), data_.end(), out); + } + std::string Debug() const { std::stringstream strm; @@ -245,10 +190,6 @@ class Tensor { return pimpl_->shape(); } - //cudnnTensorDescriptor_t desc() const { - // return pimpl_->desc(); - //} - void set(value_type value) { pimpl_->set(value); } @@ -273,10 +214,17 @@ class Tensor { std::cerr << std::endl; } - void Load(const std::string &path); - void Load(const std::vector& data); - void Load(const std::vector::const_iterator &begin, const std::vector::const_iterator &end); + //void Load(const std::string &path); + void set(const std::vector& data); + void set(const std::vector::const_iterator &begin, const std::vector::const_iterator &end); + void get(std::vector::iterator out) { + pimpl_->get(out); + } + + void get(std::vector &vout) { + pimpl_->get(vout.begin()); + } }; } diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index 023aba8b..58697d46 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -31,21 +31,21 @@ int main(int argc, char** argv) { converter.Load("bias", bData, bShape); auto initW = [wData](Tensor t) { - thrust::copy(wData.begin(), wData.end(), t.begin()); + t.set(wData.begin(), wData.end()); }; auto initB = [bData](Tensor t) { - thrust::copy(bData.begin(), bData.end(), t.begin()); + t.set(bData.begin(), bData.end()); }; std::cerr << "\tDone." << std::endl; - Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X"); - Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y"); + auto x = input(shape={whatevs, IMAGE_SIZE}, name="X"); + auto y = input(shape={whatevs, LABEL_SIZE}, name="Y"); - Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=initW); - Expr b = param(shape={1, LABEL_SIZE}, name="b0", init=initB); + auto w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=initW); + auto b = param(shape={1, LABEL_SIZE}, name="b0", init=initB); std::cerr << "Building model..."; auto predict = softmax(dot(x, w) + b, @@ -53,13 +53,13 @@ int main(int argc, char** argv) { auto graph = -mean(sum(y * log(predict), axis=1), axis=0, name="cost"); - std::cerr << "\tDone." << std::endl; + std::cerr << "Done." << std::endl; Tensor xt({numofdata, IMAGE_SIZE}); - xt.Load(testImages); + xt.set(testImages); Tensor yt({numofdata, LABEL_SIZE}); - yt.Load(testLabels); + yt.set(testLabels); x = xt; y = yt; @@ -68,6 +68,9 @@ int main(int argc, char** argv) { auto results = predict.val(); graph.backward(); + std::vector resultsv(results.size()); + results.get(resultsv); + std::cerr << b.grad().Debug() << std::endl; size_t acc = 0; @@ -76,14 +79,14 @@ int main(int argc, char** argv) { size_t predicted = 0; for (size_t j = 0; j < LABEL_SIZE; ++j) { if (testLabels[i+j]) correct = j; - if (results[i + j] > results[i + predicted]) predicted = j; + if (resultsv[i + j] > resultsv[i + predicted]) predicted = j; } acc += (correct == predicted); - //std::cerr << "corect: " << correct << " | " << predicted << "("; + //std::cerr << correct << " | " << predicted << " ( "; //for (size_t j = 0; j < LABEL_SIZE; ++j) { - // std::cerr << results[i+j] << " "; + // std::cerr << resultsv[i+j] << " "; //} - //std::cerr << std::endl; + //std::cerr << ")" << std::endl; } std::cerr << "ACC: " << float(acc)/numofdata << std::endl; From aab15d66e637b330b4086ce1c22f5745cc31a53c Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Wed, 14 Sep 2016 23:31:11 +0200 Subject: [PATCH 11/11] operators --- src/tensor.cu | 11 +++++++++++ src/tensor.h | 8 ++++++-- src/validate_mnist.cu | 13 +++++-------- 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/tensor.cu b/src/tensor.cu index fea21926..0c3e8a3e 100644 --- a/src/tensor.cu +++ b/src/tensor.cu @@ -15,5 +15,16 @@ void Tensor::set(const std::vector::const_iterator &begin, const std::vec pimpl_->set(begin, end); } +Tensor& operator<<(Tensor& t, const std::vector &vec) { + t.set(vec); + return t; +} + +std::vector& operator<<(std::vector &vec, const Tensor& t) { + t.get(vec); + return vec; +} + + } diff --git a/src/tensor.h b/src/tensor.h index af9069de..66690427 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -218,13 +218,17 @@ class Tensor { void set(const std::vector& data); void set(const std::vector::const_iterator &begin, const std::vector::const_iterator &end); - void get(std::vector::iterator out) { + void get(std::vector::iterator out) const { pimpl_->get(out); } - void get(std::vector &vout) { + void get(std::vector &vout) const { pimpl_->get(vout.begin()); } }; +Tensor& operator<<(Tensor& t, const std::vector &vec); + +std::vector& operator<<(std::vector &vec, const Tensor& t); + } diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index 58697d46..510b8dd4 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -31,11 +31,11 @@ int main(int argc, char** argv) { converter.Load("bias", bData, bShape); auto initW = [wData](Tensor t) { - t.set(wData.begin(), wData.end()); + t.set(wData); }; auto initB = [bData](Tensor t) { - t.set(bData.begin(), bData.end()); + t.set(bData); }; std::cerr << "\tDone." << std::endl; @@ -56,20 +56,17 @@ int main(int argc, char** argv) { std::cerr << "Done." << std::endl; Tensor xt({numofdata, IMAGE_SIZE}); - xt.set(testImages); - Tensor yt({numofdata, LABEL_SIZE}); - yt.set(testLabels); - x = xt; - y = yt; + x = xt << testImages; + y = yt << testLabels; graph.forward(numofdata); auto results = predict.val(); graph.backward(); std::vector resultsv(results.size()); - results.get(resultsv); + resultsv << results; std::cerr << b.grad().Debug() << std::endl;