diff --git a/scripts/train_test_model_multi.py b/scripts/train_test_model_multi.py
index 67ae0131..0bb71414 100755
--- a/scripts/train_test_model_multi.py
+++ b/scripts/train_test_model_multi.py
@@ -3,6 +3,7 @@
 import sys
 import os
 import numpy as np
+import time
 from keras.datasets import mnist
 from keras.utils import np_utils
 from keras.models import Sequential
@@ -15,7 +16,11 @@ def softmax(x):
 
 def baseline_model(pixels_count, classes_count):
     model = Sequential()
-    model.add(Dense(100, input_dim=pixels_count, init='normal', activation='tanh'))
+    model.add(Dense(2000, input_dim=pixels_count, init='normal', activation='tanh'))
+    model.add(Dense(2000, init='normal', activation='tanh'))
+    model.add(Dense(2000, init='normal', activation='tanh'))
+    model.add(Dense(2000, init='normal', activation='tanh'))
+    model.add(Dense(2000, init='normal', activation='tanh'))
     model.add(Dense(classes_count, input_dim=100, init='normal', activation='softmax'))
     model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
     return model
@@ -52,21 +57,24 @@ if __name__ == "__main__":
     # Build the model
     model = baseline_model(pixels_count, classes_count)
     # Fit the model
-    model.fit(X_train, y_train, validation_data=(X_test, y_test), nb_epoch=10, batch_size=200, verbose=2)
+
+    start = time.time()
+    model.fit(X_train, y_train, nb_epoch=10, batch_size=200, verbose=2)
+    print "Time elapsed", time.time() - start, "s"
     # Final evaluation of the model
     scores = model.evaluate(X_test, y_test, verbose=0)
-    print("Baseline Error: %.2f%%" % (100-scores[1]*100))
+    print("Accuracy: %.2f%%" % (scores[1] * 100))
 
     ### Weight and bias matrixes - we extract them from the model
     # weights_ones = np.ones((pixels_count, classes_count))
     # print weights_ones.shape
-    weights1, bias1, weights2, bias2 = model.get_weights()
+    #weights1, bias1, weights2, bias2 = model.get_weights()
 
     ### Save model to npz files
-    if not os.path.exists("test_model_multi"):
-        os.makedirs("test_model_multi")
+    #if not os.path.exists("test_model_multi"):
+    #    os.makedirs("test_model_multi")
 
     # np.savez("test_model_multi/model", *model)
-    np.savez("test_model_multi/model", weights1 = weights1, bias1 = bias1, weights2 = weights2, bias2 = bias2)
+    #np.savez("test_model_multi/model", weights1 = weights1, bias1 = bias1, weights2 = weights2, bias2 = bias2)
 
-    print "Model saved! Check test_model_multi directory"
+    #print "Model saved! Check test_model_multi directory"
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 401d4d1f..b23985fd 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -21,30 +21,25 @@ cuda_add_executable(
 target_link_libraries(marian marian_lib)
 
 cuda_add_executable(
-  train_mnist
-  train_mnist.cu
+  mnist_benchmark
+  mnist_benchmark.cu
 )
-target_link_libraries(train_mnist marian_lib)
-
-cuda_add_executable(
-  validate_mnist
-  validate_mnist.cu
-)
 
 cuda_add_executable(
   validate_mnist_batch
   validate_mnist_batch.cu
 )
+
 cuda_add_executable(
   validate_encoder_decoder
   validate_encoder_decoder.cu
 )
 
-target_link_libraries(validate_mnist marian_lib)
+target_link_libraries(mnist_benchmark marian_lib)
 target_link_libraries(validate_mnist_batch marian_lib)
 target_link_libraries(validate_encoder_decoder marian_lib)
 
-foreach(exec marian train_mnist validate_mnist validate_mnist_batch validate_encoder_decoder)
+foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder)
   target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn)
   cuda_add_cublas_to_target(${exec})
   set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
diff --git a/src/mnist_benchmark.cu b/src/mnist_benchmark.cu
new file mode 100644
index 00000000..5f0bc705
--- /dev/null
+++ b/src/mnist_benchmark.cu
@@ -0,0 +1,164 @@
+#include <algorithm>
+#include <cstdlib>
+#include <boost/timer/timer.hpp>
+
+#include "marian.h"
+#include "mnist.h"
+#include "npz_converter.h"
+#include "optimizers.h"
+
+using namespace marian;
+using namespace keywords;
+
+const size_t IMAGE_SIZE = 784;
+const size_t LABEL_SIZE = 10;
+int BATCH_SIZE = 200;
+
+ExpressionGraph build_graph(const std::vector<int>& dims) {
+  std::cerr << "Building model... ";
+  boost::timer::cpu_timer timer;
+
+  ExpressionGraph g;
+  auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
+  auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
+
+  std::vector<Expr> layers, weights, biases;
+  for(int i = 0; i < dims.size() - 1; ++i) {
+    int in = dims[i];
+    int out = dims[i+1];
+
+    if(i == 0) {
+      layers.emplace_back(x);
+    }
+    else {
+      layers.emplace_back(tanh(dot(layers.back(), weights.back()) + biases.back()));
+    }
+
+    weights.emplace_back(
+      g.param(shape={in, out},
+              init=normal()));
+    biases.emplace_back(
+      g.param(shape={1, out},
+              init=normal()));
+
+  }
+
+  auto probs = named(
+    softmax(dot(layers.back(), weights.back()) + biases.back()),
+    "probs"
+  );
+
+  auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
+  auto costreg = named(
+    cost, "cost"
+  );
+
+  std::cerr << timer.format(5, "%ws") << std::endl;
+  return g;
+}
+
+void shuffle(std::vector<float>& x, std::vector<float>& y, size_t dimx, size_t dimy) {
+  std::srand(std::time(0));
+  std::vector<size_t> ind;
+  for(size_t i = 0; i < x.size() / dimx; ++i) {
+    ind.push_back(i);
+  }
+
+  std::random_shuffle(ind.begin(), ind.end());
+
+  std::vector<float> xShuffled(x.size());
+  std::vector<float> yShuffled(y.size());
+
+  int j = 0;
+  for(auto i : ind) {
+    std::copy(x.begin() + j * dimx, x.begin() + j * dimx + dimx, xShuffled.begin() + i * dimx);
+    std::copy(y.begin() + j * dimy, y.begin() + j * dimy + dimy, yShuffled.begin() + i * dimy);
+    j++;
+  }
+
+  x = xShuffled;
+  y = yShuffled;
+}
+
+int main(int argc, char** argv) {
+
+  int trainRows, testRows;
+
+  std::cerr << "Loading train set...";
+  std::vector<float> trainImages = datasets::mnist::ReadImages("../examples/mnist/train-images-idx3-ubyte", trainRows, IMAGE_SIZE);
+  std::vector<float> trainLabels = datasets::mnist::ReadLabels("../examples/mnist/train-labels-idx1-ubyte", trainRows, LABEL_SIZE);
+  std::cerr << "Done." << std::endl;
+
+  std::cerr << "Loading test set...";
+  std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", testRows, IMAGE_SIZE);
+  std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", testRows, LABEL_SIZE);
+  std::cerr << "Done." << std::endl;
+
+  ExpressionGraph g = build_graph({IMAGE_SIZE, 2000, 2000, 2000, 2000, 2000, LABEL_SIZE});
+
+  Tensor xt({BATCH_SIZE, IMAGE_SIZE});
+  Tensor yt({BATCH_SIZE, LABEL_SIZE});
+
+
+  boost::timer::cpu_timer total;
+  Adam opt;
+  for(int i = 1; i <= 10; ++i) {
+    boost::timer::cpu_timer timer;
+    shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE);
+    float cost = 0;
+    for(int j = 0; j < trainRows / BATCH_SIZE; j++) {
+      size_t xBatch = IMAGE_SIZE * BATCH_SIZE;
+      auto xbegin = trainImages.begin() + j * xBatch;
+      auto xend = xbegin + xBatch;
+      xt.set(xbegin, xend);
+
+      size_t yBatch = LABEL_SIZE * BATCH_SIZE;
+      auto ybegin = trainLabels.begin() + j * yBatch;
+      auto yend = ybegin + yBatch;
+      yt.set(ybegin, yend);
+
+      g["x"] = xt;
+      g["y"] = yt;
+
+      opt(g, BATCH_SIZE);
+      cost += g["cost"].val()[0];
+    }
+    std::cerr << "Epoch: " << i << " - Cost: " << cost / trainRows * BATCH_SIZE << " - " << timer.format(3, "%ws") << std::endl;
+  }
+  std::cerr << "Total: " << total.format(3, "%ws") << std::endl;
+
+  std::vector<float> results;
+  for(int j = 0; j < testRows / BATCH_SIZE; j++) {
+    size_t xBatch = IMAGE_SIZE * BATCH_SIZE;
+    auto xbegin = testImages.begin() + j * xBatch;
+    auto xend = xbegin + xBatch;
+    xt.set(xbegin, xend);
+    yt.set(0);
+
+    g["x"] = xt;
+    g["y"] = yt;
+
+    g.forward(BATCH_SIZE);
+
+    std::vector<float> bResults;
+    bResults << g["probs"].val();
+    results.insert(results.end(), bResults.begin(), bResults.end());
+  }
+
+  size_t acc = 0;
+  for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
+    size_t correct = 0;
+    size_t proposed = 0;
+    for (size_t j = 0; j < LABEL_SIZE; ++j) {
+      if (testLabels[i + j])
+        correct = j;
+      if (results[i + j] > results[i + proposed])
+        proposed = j;
+    }
+    acc += (correct == proposed);
+  }
+  std::cerr << "Accuracy: " << float(acc) / testRows << std::endl;
+
+  return 0;
+}
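Note on shuffle() above: it permutes images and labels through one shared index permutation, so example rows and label rows stay paired. std::random_shuffle, used there, is deprecated since C++14 and removed in C++17. Below is a minimal sketch of the same permutation built on the <random> engines instead; it is an alternative for reference, not what this patch compiles:

    #include <algorithm>
    #include <numeric>
    #include <random>
    #include <vector>

    // Same row-permutation idea as shuffle() above, but seeded from
    // std::random_device and shuffled with std::mt19937 instead of the
    // deprecated std::random_shuffle.
    void shuffle_rows(std::vector<float>& x, std::vector<float>& y,
                      size_t dimx, size_t dimy) {
      std::vector<size_t> ind(x.size() / dimx);
      std::iota(ind.begin(), ind.end(), 0);    // 0, 1, ..., rows-1
      std::mt19937 gen(std::random_device{}());
      std::shuffle(ind.begin(), ind.end(), gen);

      std::vector<float> xs(x.size()), ys(y.size());
      size_t j = 0;
      for(size_t i : ind) {                    // source row j goes to target row i
        std::copy(x.begin() + j * dimx, x.begin() + (j + 1) * dimx, xs.begin() + i * dimx);
        std::copy(y.begin() + j * dimy, y.begin() + (j + 1) * dimy, ys.begin() + i * dimy);
        ++j;
      }
      x = std::move(xs);
      y = std::move(ys);
    }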
diff --git a/src/node_operators.h b/src/node_operators.h
index c444f24f..db2031e9 100644
--- a/src/node_operators.h
+++ b/src/node_operators.h
@@ -114,7 +114,7 @@ struct LogitNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -138,7 +138,7 @@ struct TanhNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"tanh\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -180,7 +180,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"softmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -209,7 +209,7 @@ struct ArgmaxNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -232,7 +232,7 @@ struct LogNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"log\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -255,7 +255,7 @@ struct ExpNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"exp\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -277,7 +277,7 @@ struct NegNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -330,8 +330,8 @@ struct DotNodeOp : public BinaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"×\", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
-    ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
+    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -357,8 +357,8 @@ struct PlusNodeOp : public BinaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"+\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
-    ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
+    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -384,8 +384,8 @@ struct MinusNodeOp : public BinaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
-    ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
+    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -411,8 +411,8 @@ struct MultNodeOp : public BinaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"•\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
-    ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
+    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
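Each graphviz() method above emits one DOT node plus the edges from its inputs; the per-node fragments are stitched into a full digraph by ExpressionGraph::graphviz(), the call left commented out in validate_mnist.cu below. A hedged usage sketch, assuming only the graph API already shown elsewhere in this patch:

    // Sketch: dump an expression graph as GraphViz DOT, then render offline
    // with e.g.  dot -Tpng graph.dot -o graph.png
    #include <fstream>
    #include "marian.h"

    using namespace marian;
    using namespace keywords;

    int main() {
      ExpressionGraph g;
      auto x = named(g.input(shape={whatevs, 784}), "x");
      auto w = g.param(shape={784, 10}, init=normal());
      auto b = g.param(shape={1, 10}, init=normal());
      auto probs = named(softmax(dot(x, w) + b), "probs");

      std::ofstream out("graph.dot");
      out << g.graphviz() << std::endl;  // same call as the commented-out line in validate_mnist.cu
      return 0;
    }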
diff --git a/src/optimizers.h b/src/optimizers.h
index 184b063f..4e82e6d7 100644
--- a/src/optimizers.h
+++ b/src/optimizers.h
@@ -6,9 +6,15 @@
 
 namespace marian {
 
+// @TODO: modify computation graph to group all parameters in a single matrix object.
+// This will allow performing a single large SGD update per batch. Currently there
+// are as many updates as there are parameters.
+
+// @TODO: Implement Element(...) with multiple functors for compacting of calls.
+
 class Sgd {
   public:
-    Sgd(float eta=0.001) : eta_(eta) {}
+    Sgd(float eta=0.01) : eta_(eta) {}
 
     void operator()(ExpressionGraph& graph, int batchSize) {
       graph.backprop(batchSize);
@@ -25,7 +31,7 @@ class Sgd {
 // @TODO: Add serialization for historic gradients and parameters
 class Adagrad {
   public:
-    Adagrad(float eta=0.001, float eps=10e-8)
+    Adagrad(float eta=0.01, float eps=1e-8)
       : eta_(eta), eps_(eps) {}
 
     void operator()(ExpressionGraph& graph, int batchSize) {
@@ -55,7 +61,7 @@ class Adagrad {
 // https://arxiv.org/pdf/1412.6980v8.pdf
 class Adam {
   public:
-    Adam(float eta=0.001, float beta1=0.999, float beta2=0.999, float eps=10e-8)
+    Adam(float eta=0.001, float beta1=0.9, float beta2=0.999, float eps=1e-8)
       : eta_(eta), beta1_(beta1), beta2_(beta2), eps_(eps), t_(0) {}
 
     void operator()(ExpressionGraph& graph, int batchSize) {
@@ -95,4 +101,4 @@ class Adam {
     std::vector<float> vt_;
 };
 
-}
+}
\ No newline at end of file
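The new Adam defaults (beta1 = 0.9, beta2 = 0.999, eps = 1e-8) now match the cited paper; the old beta1 = 0.999 and eps = 10e-8 (which is 1e-7, not 1e-8) appear to have been transcription slips. For reference, here is a scalar sketch of the bias-corrected update these constants enter, per Kingma & Ba (arXiv:1412.6980); marian runs the same math through fused GPU Element(...) kernels rather than a host loop like this:

    #include <cmath>
    #include <vector>

    // One Adam step over a flat parameter vector p with gradient grad.
    // mt/vt are the running first/second moment estimates; t is 1-based.
    void adam_step(std::vector<float>& p, const std::vector<float>& grad,
                   std::vector<float>& mt, std::vector<float>& vt, int t,
                   float eta = 0.001f, float beta1 = 0.9f,
                   float beta2 = 0.999f, float eps = 1e-8f) {
      for(size_t i = 0; i < p.size(); ++i) {
        mt[i] = beta1 * mt[i] + (1 - beta1) * grad[i];            // first moment
        vt[i] = beta2 * vt[i] + (1 - beta2) * grad[i] * grad[i];  // second moment
        float mhat = mt[i] / (1 - std::pow(beta1, t));            // bias correction
        float vhat = vt[i] / (1 - std::pow(beta2, t));
        p[i] -= eta * mhat / (std::sqrt(vhat) + eps);
      }
    }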
diff --git a/src/param_initializers.h b/src/param_initializers.h
index 084e829c..4d95a35c 100644
--- a/src/param_initializers.h
+++ b/src/param_initializers.h
@@ -33,13 +33,13 @@ void distribution(Tensor t, float a, float b) {
   t << vals;
 }
 
-std::function<void(Tensor)> normal(float mean = 0.0, float std = 0.1) {
+std::function<void(Tensor)> normal(float mean = 0.0, float std = 0.05) {
   return [mean, std](Tensor t) {
     distribution<std::normal_distribution<float>>(t, mean, std);
   };
 }
 
-std::function<void(Tensor)> uniform(float a = 0.0, float b = 0.1) {
+std::function<void(Tensor)> uniform(float a = 0.0, float b = 0.05) {
   return [a, b](Tensor t) {
     distribution<std::uniform_real_distribution<float>>(t, a, b);
   };
diff --git a/src/tensor_operators.cu b/src/tensor_operators.cu
index 34ab874a..ad30d051 100644
--- a/src/tensor_operators.cu
+++ b/src/tensor_operators.cu
@@ -4,6 +4,14 @@
 using namespace std;
 
 namespace marian {
 
+// @TODO: handle this better, maybe per thread?
+static cublasHandle_t create_handle() {
+  cublasHandle_t cublasHandle;
+  cublasCreate(&cublasHandle);
+  return cublasHandle;
+}
+cublasHandle_t cublasHandle = create_handle();
+
 __global__ void gSubtractMean(float* out, float* weights,
                               size_t rows, size_t cols) {
   for(int bid = 0; bid < rows; bid += gridDim.x) {
@@ -212,10 +220,7 @@ Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
 
 Tensor Prod(Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta) {
-  cublasHandle_t cublasHandle;
-  cublasCreate(&cublasHandle);
   Tensor temp = Prod(cublasHandle, C, A, B, transA, transB, beta);
-  cublasDestroy(cublasHandle);
   return temp;
 }
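On the cublasHandle change above: Prod() previously created and destroyed a cuBLAS handle on every call, which is expensive on the hot path of every matrix product; the patch replaces that with one process-wide handle built at static-initialization time. The @TODO asks whether it should be per thread instead. A hypothetical thread_local variant, sketch only and not part of this patch:

    #include <cublas_v2.h>

    // Hypothetical per-thread alternative to create_handle() above: each host
    // thread that calls Prod() lazily creates and then reuses its own handle,
    // so concurrent threads never share cuBLAS state. (Handles are never
    // destroyed here; a real version would register a cleanup hook.)
    static cublasHandle_t thread_cublas_handle() {
      thread_local cublasHandle_t handle = [] {
        cublasHandle_t h;
        cublasCreate(&h);
        return h;
      }();
      return handle;
    }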
diff --git a/src/train_mnist.cu b/src/train_mnist.cu
deleted file mode 100644
index 5b32cc9f..00000000
--- a/src/train_mnist.cu
+++ /dev/null
@@ -1,34 +0,0 @@
-
-#include "marian.h"
-#include "mnist.h"
-#include "optimizers.h"
-
-int main(int argc, char** argv) {
-  const size_t IMAGE_SIZE = 784;
-  const size_t LABEL_SIZE = 10;
-  int numofdata;
-
-  std::vector<float> trainImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
-  std::vector<float> trainLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
-
-  using namespace marian;
-  using namespace keywords;
-
-  ExpressionGraph g;
-
-  Expr x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
-  Expr y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
-
-  Expr w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE}), "w");
-  Expr b = named(g.param(shape={1, LABEL_SIZE}), "b");
-
-  auto scores = dot(x, w) + b;
-  auto lr = softmax(scores);
-  auto cost = named(-mean(sum(y * log(lr), axis=1), axis=0), "cost");
-  std::cerr << "lr=" << lr.Debug() << std::endl;
-
-  Adagrad opt;
-  opt(g, 300);
-
-  return 0;
-}
diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu
deleted file mode 100644
index 34364929..00000000
--- a/src/validate_mnist.cu
+++ /dev/null
@@ -1,90 +0,0 @@
-
-#include "marian.h"
-#include "mnist.h"
-#include "npz_converter.h"
-#include "optimizers.h"
-
-using namespace marian;
-using namespace keywords;
-
-const size_t IMAGE_SIZE = 784;
-const size_t LABEL_SIZE = 10;
-int BATCH_SIZE = 10000;
-
-ExpressionGraph build_graph() {
-  std::cerr << "Loading model params...";
-  NpzConverter converter("../scripts/test_model_single/model.npz");
-
-  std::vector<float> wData, bData;
-  Shape wShape, bShape;
-  converter.Load("weights", wData, wShape);
-  converter.Load("bias", bData, bShape);
-  std::cerr << "Done." << std::endl;
-
-  std::cerr << "Building model...";
-
-  ExpressionGraph g;
-  auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
-  auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
-
-  auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
-                         init=from_vector(wData)), "w");
-  auto b = named(g.param(shape={1, LABEL_SIZE},
-                         init=from_vector(bData)), "b");
-
-
-  auto probs = named(
-    softmax(dot(x, w) + b),
-    "probs"
-  );
-
-  auto cost = named(
-    -mean(sum(y * log(probs), axis=1), axis=0),
-    "cost"
-  );
-
-  std::cerr << "Done." << std::endl;
-  return g;
-}
-
-int main(int argc, char** argv) {
-
-  std::cerr << "Loading test set...";
-  std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE);
-  std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
-  std::cerr << "Done." << std::endl;
-
-  ExpressionGraph g = build_graph();
-
-  Tensor xt({BATCH_SIZE, IMAGE_SIZE});
-  Tensor yt({BATCH_SIZE, LABEL_SIZE});
-
-  g["x"] = (xt << testImages);
-  g["y"] = (yt << testLabels);
-
-  Adam opt;
-  for(size_t j = 0; j < 20; ++j) {
-    for(size_t i = 0; i < 60; ++i) {
-      opt(g, BATCH_SIZE);
-    }
-    std::cerr << g["cost"].val()[0] << std::endl;
-  }
-
-  //std::cout << g.graphviz() << std::endl;
-
-  std::vector<float> results;
-  results << g["probs"].val();
-
-  size_t acc = 0;
-  for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
-    size_t correct = 0;
-    size_t proposed = 0;
-    for (size_t j = 0; j < LABEL_SIZE; ++j) {
-      if (testLabels[i+j]) correct = j;
-      if (results[i + j] > results[i + proposed]) proposed = j;
-    }
-    acc += (correct == proposed);
-  }
-  std::cerr << "Cost: " << g["cost"].val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
-  return 0;
-}