diff --git a/.gitignore b/.gitignore
index 4dfd397b..53468680 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,3 +39,4 @@ build
 
 # Examples
 examples/*/*.gz
+examples/mnist/*ubyte
diff --git a/examples/mnist/Makefile b/examples/mnist/Makefile
index 7e4e812f..051d5a60 100644
--- a/examples/mnist/Makefile
+++ b/examples/mnist/Makefile
@@ -2,10 +2,13 @@
 
 all: download
 
-download: train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz t10k-images-idx3-ubyte.gz t10k-labels-idx3-ubyte.gz
+download: train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
 
-%.gz:
-	wget http://yann.lecun.com/exdb/mnist/$*.gz -O $@
+%-ubyte: %-ubyte.gz
+	gzip -d < $^ > $@
+
+%-ubyte.gz:
+	wget http://yann.lecun.com/exdb/mnist/$*-ubyte.gz -O $@
 
 clean:
-	rm -f *.gz
+	rm -f *.gz *-ubyte
diff --git a/src/expression_operators.h b/src/expression_operators.h
index 3c3dc031..8eabbd04 100644
--- a/src/expression_operators.h
+++ b/src/expression_operators.h
@@ -134,6 +134,7 @@ inline Expr sum(Expr a, Args ...args) {
   else if(ax == 1) {
     auto lshape = [n]() -> Shape {
       int cols = n->val().shape()[1];
+      //std::cerr << "Shape will be " << cols << " by 1." << std::endl;
       return {cols, 1};
     };
     Expr one = ones(shape={n->shape()[1], 1},
@@ -153,6 +154,20 @@ inline Expr sum(Expr a, Args ...args) {
 template <typename ...Args>
 inline Expr softmax(Expr a, Args ...args) {
   Expr e = exp(a);
+#if 0
+  ChainPtr n = a.node();
+  auto print_shape = [n]() -> Shape {
+    std::cerr << "Shape: ";
+    for (auto val : n->val().shape()) {
+      std::cerr << val << " ";
+    }
+    std::cerr << std::endl;
+    return {1,1};
+  };
+  using namespace keywords;
+  Expr one = ones(shape={1, 1}, lazy_shape=print_shape);
+#endif
+
   return e / sum(e, args...);
 }
 
@@ -187,4 +202,4 @@ inline Expr mean(Expr a, Args ...args) {
   }
 }
 
-}
\ No newline at end of file
+}
diff --git a/src/graph_operators.h b/src/graph_operators.h
index d07c4b38..30456153 100644
--- a/src/graph_operators.h
+++ b/src/graph_operators.h
@@ -118,9 +118,15 @@ struct LogNodeOp : public UnaryNodeOp {
 
 struct ExpNodeOp : public UnaryNodeOp {
   template <typename ...Args>
-  ExpNodeOp(Args ...args)
-  : UnaryNodeOp(args...) { }
+  ExpNodeOp(ChainPtr a, Args ...args)
+  : UnaryNodeOp(a, keywords::shape=newShape(a),
+                args...) { }
 
+  Shape newShape(ChainPtr a) {
+    Shape shape = a->shape();
+    return shape;
+  }
+
   void forward() {
     Element(_1 = Exp(_2), val_, a_->val());
   }
@@ -289,4 +295,4 @@ struct DivNodeOp : public BroadcastingNodeOp {
   }
 };
 
-}
\ No newline at end of file
+}
diff --git a/src/mnist.h b/src/mnist.h
new file mode 100644
index 00000000..7727bacc
--- /dev/null
+++ b/src/mnist.h
@@ -0,0 +1,94 @@
+#pragma once
+
+#include <fstream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+namespace datasets {
+namespace mnist {
+
+typedef unsigned char uchar;
+
+auto reverseInt = [](int i) {
+  unsigned char c1, c2, c3, c4;
+  c1 = i & 255, c2 = (i >> 8) & 255, c3 = (i >> 16) & 255, c4 = (i >> 24) & 255;
+  return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
+};
+
+std::vector<std::vector<float>> ReadImages(const std::string& full_path) {
+  std::ifstream file(full_path);
+
+  if (! file.is_open())
+    throw std::runtime_error("Cannot open file `" + full_path + "`!");
+
+  int magic_number = 0, n_rows = 0, n_cols = 0;
+
+  file.read((char *)&magic_number, sizeof(magic_number));
+  magic_number = reverseInt(magic_number);
+
+  if (magic_number != 2051)
+    throw std::runtime_error("Invalid MNIST image file!");
+
+  int number_of_images = 0;
+  file.read((char *)&number_of_images, sizeof(number_of_images)), number_of_images = reverseInt(number_of_images);
+  file.read((char *)&n_rows, sizeof(n_rows)), n_rows = reverseInt(n_rows);
+  file.read((char *)&n_cols, sizeof(n_cols)), n_cols = reverseInt(n_cols);
+
+  int image_size = n_rows * n_cols;
+  std::vector<std::vector<float>> _dataset(number_of_images, std::vector<float>(image_size));
+  unsigned char pixel = 0;
+
+  for (int i = 0; i < number_of_images; i++) {
+    for (int j = 0; j < image_size; j++) {
+      file.read((char*)&pixel, sizeof(pixel));
+      _dataset[i][j] = pixel / 255.0f;
+    }
+  }
+  return _dataset;
+}
+
+std::vector<uchar> ReadLabels(const std::string& full_path) {
+  std::ifstream file(full_path);
+
+  if (! file.is_open())
+    throw std::runtime_error("Cannot open file `" + full_path + "`!");
+
+  int magic_number = 0;
+  file.read((char *)&magic_number, sizeof(magic_number));
+  magic_number = reverseInt(magic_number);
+
+  if (magic_number != 2049)
+    throw std::runtime_error("Invalid MNIST label file!");
+
+  int number_of_labels = 0;
+  file.read((char *)&number_of_labels, sizeof(number_of_labels)), number_of_labels = reverseInt(number_of_labels);
+
+  std::vector<uchar> _dataset(number_of_labels);
+  for (int i = 0; i < number_of_labels; i++) {
+    file.read((char*)&_dataset[i], 1);
+  }
+
+  return _dataset;
+}
+
+} // namespace mnist
+} // namespace datasets
+
+
+//int main(int argc, const char *argv[]) {
+  //auto images = datasets::mnist::ReadImages("t10k-images-idx3-ubyte");
+  //auto labels = datasets::mnist::ReadLabels("t10k-labels-idx1-ubyte");
+
+  //std::cout
+    //<< "Number of images: " << images.size() << std::endl
+    //<< "Image size: " << images[0].size() << std::endl;
+
+  //for (int i = 0; i < 3; i++) {
+    //for (int j = 0; j < images[i].size(); j++) {
+      //std::cout << images[i][j] << ",";
+    //}
+    //std::cout << " label=" << (int)labels[i] << std::endl;
+  //}
+  //return 0;
+//}
diff --git a/src/test.cu b/src/test.cu
index 4e382141..c2b0d62e 100644
--- a/src/test.cu
+++ b/src/test.cu
@@ -1,26 +1,28 @@
 #include "marian.h"
+#include "mnist.h"
 
 using namespace std;
 
 int main(int argc, char** argv) {
+  /*auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte");*/
+  /*auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");*/
+  /*std::cerr << images.size() << " " << images[0].size() << std::endl;*/
 
   using namespace marian;
   using namespace keywords;
 
-  /*
   Expr x = input(shape={whatevs, 784}, name="X");
   Expr y = input(shape={whatevs, 10}, name="Y");
 
   Expr w = param(shape={784, 10}, name="W0");
   Expr b = param(shape={1, 10}, name="b0");
 
-  Expr n5 = dot(x, w);
-  Expr n6 = n5 + b;
-  Expr lr = softmax(n6, axis=1, name="pred");
+  auto scores = dot(x, w) + b;
+  auto lr = softmax(scores, axis=1, name="pred");
+  auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
 
   cerr << "lr=" << lr.Debug() << endl;
 
-  Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
 
   Tensor tx({500, 784}, 1);
   Tensor ty({500, 10}, 1);
@@ -31,51 +33,47 @@ int main(int argc, char** argv) {
   y = ty;
 
   graph.forward(500);
+
+  std::cerr << "Result: ";
+  for (auto val : scores.val().shape()) {
+    std::cerr << val << " ";
+  }
+  std::cerr << std::endl;
+  std::cerr << "Result: ";
+  for (auto val : lr.val().shape()) {
+    std::cerr << val << " ";
+  }
+  std::cerr << std::endl;
+  std::cerr << "Log-likelihood: ";
+  for (auto val : graph.val().shape()) {
+    std::cerr << val << " ";
+  }
+  std::cerr << std::endl;
+
+  graph.backward();
+
   //std::cerr << graph["pred"].val()[0] << std::endl;
 
-  */
-
-  Expr x = input(shape={whatevs, 2}, name="X");
-  Expr y = input(shape={whatevs, 2}, name="Y");
-
-  Expr w = param(shape={2, 1}, name="W0");
-  Expr b = param(shape={1, 1}, name="b0");
-
-  Expr n5 = dot(x, w);
-  Expr n6 = n5 + b;
-  Expr lr = softmax(n6, axis=1, name="pred");
-  cerr << "lr=" << lr.Debug() << endl;
-
-  Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
-
-  Tensor tx({4, 2}, 1);
-  Tensor ty({4, 1}, 1);
-  cerr << "tx=" << tx.Debug() << endl;
-  cerr << "ty=" << ty.Debug() << endl;
-
-  tx.Load("../examples/xor/train.txt");
-  ty.Load("../examples/xor/label.txt");
-
-
-  //hook0(graph);
-  //graph.autodiff();
-  //std::cerr << graph["cost"].val()[0] << std::endl;
+#if 0
+  hook0(graph);
+  graph.autodiff();
+  std::cerr << graph["cost"].val()[0] << std::endl;
   //hook1(graph);
-  //for(auto p : graph.params()) {
-  //  auto update = _1 = _1 - alpha * _2;
-  //  Element(update, p.val(), p.grad());
-  //}
-  //hook2(graph);
-  //
-  //auto opt = adadelta(cost_function=cost,
-  //           eta=0.9, gamma=0.1,
-  //           set_batch=set,
-  //           before_update=before,
-  //           after_update=after,
-  //           set_valid=valid,
-  //           validation_freq=100,
-  //           verbose=1, epochs=3, early_stopping=10);
-  //opt.run();
-
+  for(auto p : graph.params()) {
+    auto update = _1 = _1 - alpha * _2;
+    Element(update, p.val(), p.grad());
+  }
+  hook2(graph);
+
+  auto opt = adadelta(cost_function=cost,
+             eta=0.9, gamma=0.1,
+             set_batch=set,
+             before_update=before,
+             after_update=after,
+             set_valid=valid,
+             validation_freq=100,
+             verbose=1, epochs=3, early_stopping=10);
+  opt.run();
+#endif
   return 0;
 }