diff --git a/src/mnist.h b/src/mnist.h index 7727bacc..8e94931f 100644 --- a/src/mnist.h +++ b/src/mnist.h @@ -1,4 +1,4 @@ -#pragma once +//#pragma once #include #include @@ -16,7 +16,7 @@ auto reverseInt = [](int i) { return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4; }; -std::vector> ReadImages(const std::string& full_path) { +std::vector ReadImages(const std::string& full_path, int& number_of_images, int& image_size) { std::ifstream file(full_path); if (! file.is_open()) @@ -30,20 +30,18 @@ std::vector> ReadImages(const std::string& full_path) { if (magic_number != 2051) throw std::runtime_error("Invalid MNIST image file!"); - int number_of_images = 0; file.read((char *)&number_of_images, sizeof(number_of_images)), number_of_images = reverseInt(number_of_images); file.read((char *)&n_rows, sizeof(n_rows)), n_rows = reverseInt(n_rows); file.read((char *)&n_cols, sizeof(n_cols)), n_cols = reverseInt(n_cols); - int image_size = n_rows * n_cols; - std::vector> _dataset(number_of_images, std::vector(image_size)); + image_size = n_rows * n_cols; + int n = number_of_images * image_size; + std::vector _dataset(n); unsigned char pixel = 0; - for (int i = 0; i < number_of_images; i++) { - for (int j = 0; j < image_size; j++) { - file.read((char*)&pixel, sizeof(pixel)); - _dataset[i][j] = pixel / 255.0f; - } + for (int i = 0; i < n; i++) { + file.read((char*)&pixel, sizeof(pixel)); + _dataset[i] = pixel / 255.0f; } return _dataset; } @@ -77,16 +75,17 @@ std::vector ReadLabels(const std::string& full_path) { //int main(int argc, const char *argv[]) { - //auto images = datasets::mnist::ReadImages("t10k-images-idx3-ubyte"); - //auto labels = datasets::mnist::ReadLabels("t10k-labels-idx1-ubyte"); + //int numImg, imgSize; + //auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg, imgSize); + //auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte"); //std::cout - //<< "Number of images: " << images.size() << std::endl - //<< "Image size: " << images[0].size() << std::endl; + //<< "Number of images: " << numImg << std::endl + //<< "Image size: " << imgSize << std::endl; //for (int i = 0; i < 3; i++) { - //for (int j = 0; j < images[i].size(); j++) { - //std::cout << images[i][j] << ","; + //for (int j = 0; j < imgSize; j++) { + //std::cout << images[(i * imgSize) + j] << ","; //} //std::cout << " label=" << (int)labels[i] << std::endl; //}