mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
add functions loading MNIST dataset
This commit is contained in:
parent
803a562d4b
commit
cc7a48310f
1
.gitignore
vendored
1
.gitignore
vendored
@ -39,3 +39,4 @@ build
|
|||||||
|
|
||||||
# Examples
|
# Examples
|
||||||
examples/*/*.gz
|
examples/*/*.gz
|
||||||
|
examples/mnist/*ubyte
|
||||||
|
@ -2,9 +2,12 @@
|
|||||||
|
|
||||||
all: download
|
all: download
|
||||||
|
|
||||||
download: train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz t10k-images-idx3-ubyte.gz t10k-labels-idx3-ubyte.gz
|
download: train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
|
||||||
|
|
||||||
%.gz:
|
%-ubyte: %-ubyte.gz
|
||||||
|
gzip -d < $^ > $@
|
||||||
|
|
||||||
|
%-ubyte.gz:
|
||||||
wget http://yann.lecun.com/exdb/mnist/$*.gz -O $@
|
wget http://yann.lecun.com/exdb/mnist/$*.gz -O $@
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
|
94
src/mnist.h
Normal file
94
src/mnist.h
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstddef>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
|
||||||
|
|
||||||
|
namespace datasets {
|
||||||
|
namespace mnist {
|
||||||
|
|
||||||
|
typedef unsigned char uchar;
|
||||||
|
|
||||||
|
/// Reverses the byte order of the low 32 bits of `i`
/// (byte 0 <-> byte 3, byte 1 <-> byte 2).
///
/// Declared as an `inline` function instead of a namespace-scope lambda
/// variable: a non-const variable defined in a header is emitted by every
/// translation unit that includes it and breaks the link as soon as the
/// header is included twice.
///
/// @param i value whose bytes are to be swapped
/// @return `i` with its four low bytes in reverse order
inline int reverseInt(int i) {
  // Shift on an unsigned copy: moving a byte >= 0x80 into the top byte of a
  // signed int overflows, whereas unsigned shifts are always well defined.
  const unsigned int bits = static_cast<unsigned int>(i);
  const unsigned int swapped = ((bits & 0x000000FFu) << 24)
                             | ((bits & 0x0000FF00u) << 8)
                             | ((bits >> 8) & 0x0000FF00u)
                             | ((bits >> 24) & 0x000000FFu);
  return static_cast<int>(swapped);
}
|
||||||
|
|
||||||
|
/// Loads an MNIST image file (magic number 2051) into memory.
///
/// The file starts with four 32-bit header fields (magic number, image
/// count, rows, columns) followed by one unsigned byte per pixel. Each pixel
/// is scaled to a float by dividing by 255.
///
/// @param full_path path of the uncompressed image file
/// @return one vector of `rows * cols` floats per image, in file order
/// @throws std::runtime_error if the file cannot be opened, its header is
///         missing/invalid, or the pixel data ends early
inline std::vector<std::vector<float>> ReadImages(const std::string& full_path) {
  // Binary mode: the payload is raw bytes, so no newline/EOF translation may
  // be applied by the stream.
  std::ifstream file(full_path, std::ios::binary);

  if (!file.is_open())
    throw std::runtime_error("Cannot open file `" + full_path + "`!");

  // Header fields are stored most-significant byte first. Assembling them
  // byte by byte yields the same value regardless of the host byte order.
  const auto readHeaderField = [&file]() -> unsigned long {
    unsigned char bytes[4] = {0, 0, 0, 0};
    if (!file.read(reinterpret_cast<char*>(bytes), sizeof(bytes)))
      throw std::runtime_error("Invalid MNIST image file!");
    return (static_cast<unsigned long>(bytes[0]) << 24)
         | (static_cast<unsigned long>(bytes[1]) << 16)
         | (static_cast<unsigned long>(bytes[2]) << 8)
         | static_cast<unsigned long>(bytes[3]);
  };

  if (readHeaderField() != 2051ul)
    throw std::runtime_error("Invalid MNIST image file!");

  const unsigned long number_of_images = readHeaderField();
  const unsigned long n_rows = readHeaderField();
  const unsigned long n_cols = readHeaderField();

  // The fields are signed 32-bit values; anything with the top bit set would
  // be a negative count and can only come from a corrupt file.
  const unsigned long max_field = 0x7FFFFFFFul;
  if (number_of_images > max_field || n_rows > max_field || n_cols > max_field)
    throw std::runtime_error("Invalid MNIST image file!");

  const std::size_t rows = static_cast<std::size_t>(n_rows);
  const std::size_t cols = static_cast<std::size_t>(n_cols);
  // Guard the multiplication on platforms where size_t is only 32 bits wide.
  if (cols != 0 && rows > static_cast<std::size_t>(-1) / cols)
    throw std::runtime_error("Invalid MNIST image file!");
  const std::size_t image_size = rows * cols;

  std::vector<std::vector<float>> _dataset(
      static_cast<std::size_t>(number_of_images), std::vector<float>(image_size));

  // Read one whole image per call instead of one byte per call.
  std::vector<unsigned char> raw(image_size);
  for (auto& image : _dataset) {
    if (image_size != 0
        && !file.read(reinterpret_cast<char*>(raw.data()),
                      static_cast<std::streamsize>(raw.size())))
      throw std::runtime_error("Truncated MNIST image file!");

    for (std::size_t j = 0; j < image_size; ++j)
      image[j] = raw[j] / 255.0f;
  }

  return _dataset;
}
|
||||||
|
|
||||||
|
/// Loads an MNIST label file (magic number 2049) into memory.
///
/// The file starts with two 32-bit header fields (magic number, label
/// count) followed by one unsigned byte per label.
///
/// @param full_path path of the uncompressed label file
/// @return the labels as ints, in file order
/// @throws std::runtime_error if the file cannot be opened, its header is
///         missing/invalid, or the label data ends early
inline std::vector<int> ReadLabels(const std::string& full_path) {
  // Binary mode: the payload is raw bytes, so no newline/EOF translation may
  // be applied by the stream.
  std::ifstream file(full_path, std::ios::binary);

  if (!file.is_open())
    throw std::runtime_error("Cannot open file `" + full_path + "`!");

  // Header fields are stored most-significant byte first. Assembling them
  // byte by byte yields the same value regardless of the host byte order.
  const auto readHeaderField = [&file]() -> unsigned long {
    unsigned char bytes[4] = {0, 0, 0, 0};
    if (!file.read(reinterpret_cast<char*>(bytes), sizeof(bytes)))
      throw std::runtime_error("Invalid MNIST label file!");
    return (static_cast<unsigned long>(bytes[0]) << 24)
         | (static_cast<unsigned long>(bytes[1]) << 16)
         | (static_cast<unsigned long>(bytes[2]) << 8)
         | static_cast<unsigned long>(bytes[3]);
  };

  if (readHeaderField() != 2049ul)
    throw std::runtime_error("Invalid MNIST label file!");

  // The count is a signed 32-bit value; a set top bit would be a negative
  // count and can only come from a corrupt file.
  const unsigned long number_of_labels = readHeaderField();
  if (number_of_labels > 0x7FFFFFFFul)
    throw std::runtime_error("Invalid MNIST label file!");

  // Read the label bytes into a byte buffer and widen afterwards. Writing a
  // single byte straight into an int only lands in the low-order byte on
  // little-endian hosts.
  std::vector<unsigned char> raw(static_cast<std::size_t>(number_of_labels));
  if (!raw.empty()
      && !file.read(reinterpret_cast<char*>(raw.data()),
                    static_cast<std::streamsize>(raw.size())))
    throw std::runtime_error("Truncated MNIST label file!");

  return std::vector<int>(raw.begin(), raw.end());
}
|
||||||
|
|
||||||
|
} // namespace mnist
|
||||||
|
} // namespace datasets
|
||||||
|
|
||||||
|
|
||||||
|
//int main(int argc, const char *argv[]) {
|
||||||
|
//auto images = datasets::mnist::ReadImages("t10k-images-idx3-ubyte");
|
||||||
|
//auto labels = datasets::mnist::ReadLabels("t10k-labels-idx1-ubyte");
|
||||||
|
|
||||||
|
//std::cout
|
||||||
|
//<< "Number of images: " << images.size() << std::endl
|
||||||
|
//<< "Image size: " << images[0].size() << std::endl;
|
||||||
|
|
||||||
|
//for (int i = 0; i < 3; i++) {
|
||||||
|
//for (int j = 0; j < images[i].size(); j++) {
|
||||||
|
//std::cout << images[i][j] << ",";
|
||||||
|
//}
|
||||||
|
//std::cout << " label=" << (int)labels[i] << std::endl;
|
||||||
|
//}
|
||||||
|
//return 0;
|
||||||
|
//}
|
@ -1,9 +1,13 @@
|
|||||||
|
|
||||||
#include "marian.h"
|
#include "marian.h"
|
||||||
|
#include "mnist.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
|
/*auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte");*/
|
||||||
|
/*auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");*/
|
||||||
|
/*std::cerr << images.size() << " " << images[0].size() << std::endl;*/
|
||||||
|
|
||||||
using namespace marian;
|
using namespace marian;
|
||||||
using namespace keywords;
|
using namespace keywords;
|
||||||
|
Loading…
Reference in New Issue
Block a user