diff --git a/.gitignore b/.gitignore
index 4dfd397b..53468680 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,3 +39,4 @@ build
 
 # Examples
 examples/*/*.gz
+examples/mnist/*ubyte
diff --git a/examples/mnist/Makefile b/examples/mnist/Makefile
index 7e4e812f..26f65554 100644
--- a/examples/mnist/Makefile
+++ b/examples/mnist/Makefile
@@ -2,9 +2,12 @@
 
 all: download
 
-download: train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz t10k-images-idx3-ubyte.gz t10k-labels-idx3-ubyte.gz
+download: train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
 
-%.gz:
-	wget http://yann.lecun.com/exdb/mnist/$*.gz -O $@
+%-ubyte: %-ubyte.gz
+	gzip -d < $^ > $@
+
+%-ubyte.gz:
+	wget http://yann.lecun.com/exdb/mnist/$@ -O $@
 
 clean:
diff --git a/src/mnist.h b/src/mnist.h
new file mode 100644
index 00000000..7727bacc
--- /dev/null
+++ b/src/mnist.h
@@ -0,0 +1,139 @@
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <fstream>
+#include <iostream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+namespace datasets {
+namespace mnist {
+
+typedef unsigned char uchar;
+
+/// Returns `i` with its four bytes in reverse order.
+///
+/// This is a plain `inline` function rather than a namespace-scope lambda
+/// variable so that the header can be included from more than one translation
+/// unit, and it swaps on an unsigned value so that no set bit is ever shifted
+/// into the sign bit of an `int`.
+inline int reverseInt(int i) {
+  const std::uint32_t u = static_cast<std::uint32_t>(i);
+  const std::uint32_t swapped = ((u & 0x000000FFu) << 24) | ((u & 0x0000FF00u) << 8)
+                              | ((u & 0x00FF0000u) >> 8) | ((u & 0xFF000000u) >> 24);
+  return static_cast<int>(swapped);
+}
+
+namespace detail {
+
+/// Reads one 32-bit big-endian header field from `file`.
+///
+/// The value is assembled byte by byte, so the result does not depend on the
+/// byte order of the host.
+///
+/// @throws std::runtime_error if four bytes cannot be read.
+inline int ReadHeaderField(std::istream& file, const std::string& full_path) {
+  unsigned char bytes[4];
+  if (!file.read(reinterpret_cast<char*>(bytes), sizeof(bytes)))
+    throw std::runtime_error("Unexpected end of file `" + full_path + "`!");
+
+  const std::uint32_t value = (static_cast<std::uint32_t>(bytes[0]) << 24)
+                            | (static_cast<std::uint32_t>(bytes[1]) << 16)
+                            | (static_cast<std::uint32_t>(bytes[2]) << 8)
+                            | static_cast<std::uint32_t>(bytes[3]);
+  return static_cast<int>(value);
+}
+
+} // namespace detail
+
+/// Loads every image from an uncompressed MNIST image file.
+///
+/// @param full_path path of the `*-images-idx3-ubyte` file to read
+/// @return one vector per image with rows * cols values, each pixel byte
+///         divided by 255
+/// @throws std::runtime_error if the file cannot be opened, its magic number is
+///         not 2051, a size field is negative, or the pixel data is truncated
+inline std::vector<std::vector<float>> ReadImages(const std::string& full_path) {
+  // Binary mode: the payload is raw bytes and must not go through the newline
+  // translation some platforms apply to text streams.
+  std::ifstream file(full_path, std::ios::in | std::ios::binary);
+
+  if (!file.is_open())
+    throw std::runtime_error("Cannot open file `" + full_path + "`!");
+
+  if (detail::ReadHeaderField(file, full_path) != 2051)
+    throw std::runtime_error("Invalid MNIST image file!");
+
+  const int number_of_images = detail::ReadHeaderField(file, full_path);
+  const int n_rows = detail::ReadHeaderField(file, full_path);
+  const int n_cols = detail::ReadHeaderField(file, full_path);
+
+  if (number_of_images < 0 || n_rows < 0 || n_cols < 0)
+    throw std::runtime_error("Invalid MNIST image file!");
+
+  const std::size_t image_size = static_cast<std::size_t>(n_rows) * static_cast<std::size_t>(n_cols);
+  std::vector<std::vector<float>> _dataset(static_cast<std::size_t>(number_of_images),
+                                           std::vector<float>(image_size));
+
+  // One read per image instead of one read per pixel.
+  std::vector<uchar> pixels(image_size);
+  for (auto& image : _dataset) {
+    file.read(reinterpret_cast<char*>(pixels.data()), static_cast<std::streamsize>(pixels.size()));
+    if (!file)
+      throw std::runtime_error("Truncated MNIST image file `" + full_path + "`!");
+
+    for (std::size_t j = 0; j < image_size; j++)
+      image[j] = pixels[j] / 255.0f;
+  }
+  return _dataset;
+}
+
+/// Loads every label from an uncompressed MNIST label file.
+///
+/// @param full_path path of the `*-labels-idx1-ubyte` file to read
+/// @return the label bytes in file order
+/// @throws std::runtime_error if the file cannot be opened, its magic number is
+///         not 2049, the label count is negative, or the label data is truncated
+inline std::vector<uchar> ReadLabels(const std::string& full_path) {
+  std::ifstream file(full_path, std::ios::in | std::ios::binary);
+
+  if (!file.is_open())
+    throw std::runtime_error("Cannot open file `" + full_path + "`!");
+
+  if (detail::ReadHeaderField(file, full_path) != 2049)
+    throw std::runtime_error("Invalid MNIST label file!");
+
+  const int number_of_labels = detail::ReadHeaderField(file, full_path);
+  if (number_of_labels < 0)
+    throw std::runtime_error("Invalid MNIST label file!");
+
+  std::vector<uchar> _dataset(static_cast<std::size_t>(number_of_labels));
+  file.read(reinterpret_cast<char*>(_dataset.data()), static_cast<std::streamsize>(_dataset.size()));
+  if (!file)
+    throw std::runtime_error("Truncated MNIST label file `" + full_path + "`!");
+
+  return _dataset;
+}
+
+} // namespace mnist
+} // namespace datasets
+
+
+//int main(int argc, const char *argv[]) {
+  //auto images = datasets::mnist::ReadImages("t10k-images-idx3-ubyte");
+  //auto labels = datasets::mnist::ReadLabels("t10k-labels-idx1-ubyte");
+
+  //std::cout
+    //<< "Number of images: " << images.size() << std::endl
+    //<< "Image size: " << images[0].size() << std::endl;
+
+  //for (int i = 0; i < 3; i++) {
+    //for (int j = 0; j < images[i].size(); j++) {
+      //std::cout << images[i][j] << ",";
+    //}
+    //std::cout << " label=" << (int)labels[i] << std::endl;
+  //}
+  //return 0;
+//}
diff --git a/src/test.cu b/src/test.cu
index 4a2445fd..c2b0d62e 100644
--- a/src/test.cu
+++ b/src/test.cu
@@ -1,9 +1,13 @@
 
 #include "marian.h"
+#include "mnist.h"
 
 using namespace std;
 
 int main(int argc, char** argv) {
+  /*auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte");*/
+  /*auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");*/
+  /*std::cerr << images.size() << " " << images[0].size() << std::endl;*/
   using namespace marian;
   using namespace keywords;
 