load data

This commit is contained in:
Hieu Hoang 2016-09-14 09:34:43 +02:00
commit 6c4d045f64
2 changed files with 70 additions and 57 deletions

View File

@ -4,49 +4,24 @@
#include <cassert>
#include <fstream>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>
namespace datasets {
namespace mnist {
// Shorthand for a single raw byte.
using uchar = unsigned char;

// Expected value of the first header field (after byte reversal) of an
// image file and of a label file respectively; the readers in this file
// reject any file whose first field does not match.
const size_t IMAGE_MAGIC_NUMBER{2051};
const size_t LABEL_MAGIC_NUMBER{2049};
// Reverses the byte order of a 32-bit integer (0xAABBCCDD -> 0xDDCCBBAA).
//
// The readers in this file apply it to every header field before use.
// The swap is unconditional, so it is only correct on hosts whose byte order
// is the opposite of the file's (NOTE(review): i.e. little-endian hosts,
// assuming the MNIST files are big-endian - confirm if porting).
//
// All shifting is done on an unsigned value: shifting a byte >= 0x80 into the
// top of a signed int overflows, and right-shifting a negative int is
// implementation-defined, so the signed formulation was not portable.
// Like the original, this assumes a 32-bit int.
auto reverseInt = [](int i) {
  const unsigned int bits = static_cast<unsigned int>(i);
  const unsigned int swapped = ((bits & 0x000000FFu) << 24)
                             | ((bits & 0x0000FF00u) << 8)
                             | ((bits >> 8) & 0x0000FF00u)
                             | ((bits >> 24) & 0x000000FFu);
  return static_cast<int>(swapped);
};
/// Loads an MNIST image file and returns all pixels as floats in [0, 1].
///
/// The file starts with four 32-bit header fields (magic number, image count,
/// rows, columns), each passed through reverseInt(), followed by one unsigned
/// byte per pixel. Pixels are returned in file order, so image `k` occupies
/// elements [k * image_size, (k + 1) * image_size).
///
/// @param full_path         Path of the image file to read.
/// @param number_of_images  Out: image count taken from the header.
/// @param image_size        Out: pixels per image (rows * columns).
/// @return Flat vector of number_of_images * image_size values, each byte / 255.
/// @throws std::runtime_error if the file cannot be opened, the header is
///         missing, has the wrong magic number or describes impossible
///         dimensions, or the pixel data is shorter than the header promises.
///         The out-parameters are left untouched when an exception is thrown.
std::vector<float> ReadImages(const std::string& full_path, int& number_of_images, int& image_size) {
  // Binary mode matters: in text mode some platforms translate line endings
  // or stop at a 0x1A byte, which corrupts raw pixel data.
  std::ifstream file(full_path, std::ios::in | std::ios::binary);
  if (!file.is_open())
    throw std::runtime_error("Cannot open file `" + full_path + "`!");

  // Header fields in file order: magic, image count, rows, columns.
  int header[4] = {0, 0, 0, 0};
  for (int& field : header) {
    file.read(reinterpret_cast<char*>(&field), sizeof(field));
    field = reverseInt(field);
  }
  const int magic_number = header[0];
  const int count = header[1];
  const int n_rows = header[2];
  const int n_cols = header[3];

  // 2051 is the magic number identifying an image file.
  if (!file || magic_number != 2051)
    throw std::runtime_error("Invalid MNIST image file!");

  // Reject dimensions that are negative or whose product would overflow,
  // instead of feeding them into signed multiplication and a vector size.
  const int int_max = std::numeric_limits<int>::max();
  if (count < 0 || n_rows < 0 || n_cols < 0 ||
      (n_rows != 0 && n_cols > int_max / n_rows))
    throw std::runtime_error("Invalid MNIST image file!");

  const int pixels_per_image = n_rows * n_cols;
  const std::size_t size_max = std::numeric_limits<std::size_t>::max();
  if (pixels_per_image != 0 &&
      static_cast<std::size_t>(count) > size_max / static_cast<std::size_t>(pixels_per_image))
    throw std::runtime_error("Invalid MNIST image file!");

  const std::size_t total =
      static_cast<std::size_t>(count) * static_cast<std::size_t>(pixels_per_image);

  // One bulk read instead of one stream call per pixel.
  std::vector<unsigned char> raw(total);
  if (total > 0)
    file.read(reinterpret_cast<char*>(raw.data()), static_cast<std::streamsize>(total));
  if (!file)
    throw std::runtime_error("Truncated MNIST image file `" + full_path + "`!");

  // Scale each byte from [0, 255] to [0, 1].
  std::vector<float> _dataset(total);
  for (std::size_t i = 0; i < total; ++i)
    _dataset[i] = raw[i] / 255.0f;

  number_of_images = count;
  image_size = pixels_per_image;
  return _dataset;
}
std::vector<float> ReadLabels(const std::string& full_path) {
std::vector<float> ReadImages(const std::string& full_path, int& number_of_images, int imgSize) {
std::ifstream file(full_path);
if (! file.is_open())
@ -56,17 +31,50 @@ std::vector<float> ReadLabels(const std::string& full_path) {
file.read((char *)&magic_number, sizeof(magic_number));
magic_number = reverseInt(magic_number);
if (magic_number != 2049)
if (magic_number != IMAGE_MAGIC_NUMBER)
throw std::runtime_error("Invalid MNIST image file!");
int n_rows = 0;
int n_cols = 0;
file.read((char *)&number_of_images, sizeof(number_of_images)), number_of_images = reverseInt(number_of_images);
file.read((char *)&n_rows, sizeof(n_rows)), n_rows = reverseInt(n_rows);
file.read((char *)&n_cols, sizeof(n_cols)), n_cols = reverseInt(n_cols);
assert(n_rows * n_cols == imgSize);
int n = number_of_images * imgSize;
std::vector<float> _dataset(n);
for (int i = 0; i < n; i++) {
unsigned char pixel = 0;
file.read((char*)&pixel, sizeof(pixel));
_dataset[i] = pixel / 255.0f;
}
return _dataset;
}
std::vector<float> ReadLabels(const std::string& full_path, int& number_of_labels, int label_size) {
std::ifstream file(full_path);
if (! file.is_open())
throw std::runtime_error("Cannot open file `" + full_path + "`!");
int magic_number = 0;
file.read((char *)&magic_number, sizeof(magic_number));
magic_number = reverseInt(magic_number);
if (magic_number != LABEL_MAGIC_NUMBER)
throw std::runtime_error("Invalid MNIST label file!");
int number_of_labels = 0;
file.read((char *)&number_of_labels, sizeof(number_of_labels)), number_of_labels = reverseInt(number_of_labels);
std::vector<float> _dataset(number_of_labels);
int n = number_of_labels * label_size;
std::vector<float> _dataset(n, 0.0f);
for (int i = 0; i < number_of_labels; i++) {
int label;
unsigned char label;
file.read((char*)&label, 1);
_dataset[i] = label;
_dataset[(i * 10) + (int)(label)] = 1.0f;
}
return _dataset;
@ -77,19 +85,21 @@ std::vector<float> ReadLabels(const std::string& full_path) {
//int main(int argc, const char *argv[]) {
//int numImg, imgSize;
//auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg, imgSize);
//auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");
//int numImg = 0;
//auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg);
//auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numImg);
//std::cout
//<< "Number of images: " << numImg << std::endl
//<< "Image size: " << imgSize << std::endl;
//std::cout << "Number of images: " << numImg << std::endl;
//for (int i = 0; i < 3; i++) {
//for (int j = 0; j < imgSize; j++) {
//std::cout << images[(i * imgSize) + j] << ",";
//for (int j = 0; j < datasets::mnist::IMAGE_SIZE; j++) {
//std::cout << images[(i * datasets::mnist::IMAGE_SIZE) + j] << ",";
//}
//std::cout << " label=" << (int)labels[i] << std::endl;
//std::cout << "\nlabels= ";
//for (int k = 0; k < 10; k++) {
//std::cout << labels[(i * 10) + k] << ",";
//}
//std::cout << std::endl;
//}
//return 0;
//}

View File

@ -5,32 +5,35 @@
using namespace std;
int main(int argc, char** argv) {
/*auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte");*/
/*auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");*/
/*std::cerr << images.size() << " " << images[0].size() << std::endl;*/
/*int numImg = 0;*/
/*auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg);*/
/*auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numImg);*/
using namespace marian;
using namespace keywords;
const size_t IMAGE_SIZE = 784;
const size_t LABEL_SIZE = 10;
Expr x = input(shape={whatevs, 784}, name="X");
Expr y = input(shape={whatevs, 10}, name="Y");
Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X");
Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y");
Expr w = param(shape={784, 10}, name="W0");
Expr b = param(shape={1, 10}, name="b0");
Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0");
Expr b = param(shape={1, LABEL_SIZE}, name="b0");
auto scores = dot(x, w) + b;
auto lr = softmax(scores, axis=1, name="pred");
auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
cerr << "lr=" << lr.Debug() << endl;
int numImg, imgSize;
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg, imgSize);
vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");
int numofdata;
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
cerr << "images=" << images.size() << " labels=" << labels.size() << endl;
cerr << "numofdata=" << numofdata << endl;
Tensor tx({numImg, 784}, 1);
Tensor ty({numImg, 10}, 1);
Tensor tx({numofdata, IMAGE_SIZE}, 1);
Tensor ty({numofdata, LABEL_SIZE}, 1);
tx.Load(images);
ty.Load(labels);