diff --git a/src/mnist.h b/src/mnist.h index 9f867ee8..43931ed2 100644 --- a/src/mnist.h +++ b/src/mnist.h @@ -4,49 +4,24 @@ #include #include #include +#include namespace datasets { namespace mnist { typedef unsigned char uchar; + +const size_t IMAGE_MAGIC_NUMBER = 2051; +const size_t LABEL_MAGIC_NUMBER = 2049; + auto reverseInt = [](int i) { unsigned char c1, c2, c3, c4; c1 = i & 255, c2 = (i >> 8) & 255, c3 = (i >> 16) & 255, c4 = (i >> 24) & 255; return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4; }; -std::vector ReadImages(const std::string& full_path, int& number_of_images, int& image_size) { - std::ifstream file(full_path); - - if (! file.is_open()) - throw std::runtime_error("Cannot open file `" + full_path + "`!"); - - int magic_number = 0, n_rows = 0, n_cols = 0; - - file.read((char *)&magic_number, sizeof(magic_number)); - magic_number = reverseInt(magic_number); - - if (magic_number != 2051) - throw std::runtime_error("Invalid MNIST image file!"); - - file.read((char *)&number_of_images, sizeof(number_of_images)), number_of_images = reverseInt(number_of_images); - file.read((char *)&n_rows, sizeof(n_rows)), n_rows = reverseInt(n_rows); - file.read((char *)&n_cols, sizeof(n_cols)), n_cols = reverseInt(n_cols); - - image_size = n_rows * n_cols; - int n = number_of_images * image_size; - std::vector _dataset(n); - unsigned char pixel = 0; - - for (int i = 0; i < n; i++) { - file.read((char*)&pixel, sizeof(pixel)); - _dataset[i] = pixel / 255.0f; - } - return _dataset; -} - -std::vector ReadLabels(const std::string& full_path) { +std::vector ReadImages(const std::string& full_path, int& number_of_images, int imgSize) { std::ifstream file(full_path); if (! file.is_open()) @@ -56,17 +31,50 @@ std::vector ReadLabels(const std::string& full_path) { file.read((char *)&magic_number, sizeof(magic_number)); magic_number = reverseInt(magic_number); - if (magic_number != 2049) + if (magic_number != IMAGE_MAGIC_NUMBER) + throw std::runtime_error("Invalid MNIST image file!"); + + int n_rows = 0; + int n_cols = 0; + file.read((char *)&number_of_images, sizeof(number_of_images)), number_of_images = reverseInt(number_of_images); + file.read((char *)&n_rows, sizeof(n_rows)), n_rows = reverseInt(n_rows); + file.read((char *)&n_cols, sizeof(n_cols)), n_cols = reverseInt(n_cols); + + assert(n_rows * n_cols == imgSize); + + int n = number_of_images * imgSize; + std::vector _dataset(n); + + for (int i = 0; i < n; i++) { + unsigned char pixel = 0; + file.read((char*)&pixel, sizeof(pixel)); + _dataset[i] = pixel / 255.0f; + } + return _dataset; +} + +std::vector ReadLabels(const std::string& full_path, int& number_of_labels, int label_size) { + std::ifstream file(full_path); + + if (! file.is_open()) + throw std::runtime_error("Cannot open file `" + full_path + "`!"); + + int magic_number = 0; + file.read((char *)&magic_number, sizeof(magic_number)); + magic_number = reverseInt(magic_number); + + if (magic_number != LABEL_MAGIC_NUMBER) throw std::runtime_error("Invalid MNIST label file!"); - int number_of_labels = 0; file.read((char *)&number_of_labels, sizeof(number_of_labels)), number_of_labels = reverseInt(number_of_labels); - std::vector _dataset(number_of_labels); + int n = number_of_labels * label_size; + std::vector _dataset(n, 0.0f); + for (int i = 0; i < number_of_labels; i++) { - int label; + unsigned char label; file.read((char*)&label, 1); - _dataset[i] = label; + _dataset[(i * 10) + (int)(label)] = 1.0f; } return _dataset; @@ -77,19 +85,21 @@ std::vector ReadLabels(const std::string& full_path) { //int main(int argc, const char *argv[]) { - //int numImg, imgSize; - //auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg, imgSize); - //auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte"); + //int numImg = 0; + //auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg); + //auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numImg); - //std::cout - //<< "Number of images: " << numImg << std::endl - //<< "Image size: " << imgSize << std::endl; + //std::cout << "Number of images: " << numImg << std::endl; //for (int i = 0; i < 3; i++) { - //for (int j = 0; j < imgSize; j++) { - //std::cout << images[(i * imgSize) + j] << ","; + //for (int j = 0; j < datasets::mnist::IMAGE_SIZE; j++) { + //std::cout << images[(i * datasets::mnist::IMAGE_SIZE) + j] << ","; //} - //std::cout << " label=" << (int)labels[i] << std::endl; + //std::cout << "\nlabels= "; + //for (int k = 0; k < 10; k++) { + //std::cout << labels[(i * 10) + k] << ","; + //} + //std::cout << std::endl; //} //return 0; //} diff --git a/src/test.cu b/src/test.cu index 4ef9953f..90dffa04 100644 --- a/src/test.cu +++ b/src/test.cu @@ -5,32 +5,35 @@ using namespace std; int main(int argc, char** argv) { - /*auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte");*/ - /*auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");*/ - /*std::cerr << images.size() << " " << images[0].size() << std::endl;*/ + /*int numImg = 0;*/ + /*auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg);*/ + /*auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numImg);*/ using namespace marian; using namespace keywords; + const size_t IMAGE_SIZE = 784; + const size_t LABEL_SIZE = 10; - Expr x = input(shape={whatevs, 784}, name="X"); - Expr y = input(shape={whatevs, 10}, name="Y"); + Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X"); + Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y"); - Expr w = param(shape={784, 10}, name="W0"); - Expr b = param(shape={1, 10}, name="b0"); + Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0"); + Expr b = param(shape={1, LABEL_SIZE}, name="b0"); auto scores = dot(x, w) + b; auto lr = softmax(scores, axis=1, name="pred"); auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost"); cerr << "lr=" << lr.Debug() << endl; - int numImg, imgSize; - vector images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg, imgSize); - vector labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte"); + int numofdata; + vector images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE); + vector labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE); cerr << "images=" << images.size() << " labels=" << labels.size() << endl; + cerr << "numofdata=" << numofdata << endl; - Tensor tx({numImg, 784}, 1); - Tensor ty({numImg, 10}, 1); + Tensor tx({numofdata, IMAGE_SIZE}, 1); + Tensor ty({numofdata, LABEL_SIZE}, 1); tx.Load(images); ty.Load(labels);