mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
merge
This commit is contained in:
commit
8f25a1d4bf
1
.gitignore
vendored
1
.gitignore
vendored
@ -39,3 +39,4 @@ build
|
||||
|
||||
# Examples
|
||||
examples/*/*.gz
|
||||
examples/mnist/*ubyte
|
||||
|
@ -2,10 +2,13 @@
|
||||
|
||||
all: download
|
||||
|
||||
download: train-images-idx3-ubyte.gz train-labels-idx1-ubyte.gz t10k-images-idx3-ubyte.gz t10k-labels-idx3-ubyte.gz
|
||||
download: train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
|
||||
|
||||
%.gz:
|
||||
wget http://yann.lecun.com/exdb/mnist/$*.gz -O $@
|
||||
%-ubyte: %-ubyte.gz
|
||||
gzip -d < $^ > $@
|
||||
|
||||
%-ubyte.gz:
|
||||
wget http://yann.lecun.com/exdb/mnist/$*-ubyte.gz -O $@
|
||||
|
||||
clean:
|
||||
rm -f *.gz
|
||||
rm -f *.gz *-ubyte
|
||||
|
@ -134,6 +134,7 @@ inline Expr sum(Expr a, Args ...args) {
|
||||
else if(ax == 1) {
|
||||
auto lshape = [n]() -> Shape {
|
||||
int cols = n->val().shape()[1];
|
||||
//std::cerr << "Shape will be " << cols << " by 1." << std::endl;
|
||||
return {cols, 1};
|
||||
};
|
||||
Expr one = ones(shape={n->shape()[1], 1},
|
||||
@ -153,6 +154,20 @@ inline Expr sum(Expr a, Args ...args) {
|
||||
template <typename ...Args>
|
||||
inline Expr softmax(Expr a, Args ...args) {
|
||||
Expr e = exp(a);
|
||||
#if 0
|
||||
ChainPtr n = a.node();
|
||||
auto print_shape = [n]() -> Shape {
|
||||
std::cerr << "Shape: ";
|
||||
for (auto val : n->val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
return {1,1};
|
||||
};
|
||||
using namespace keywords;
|
||||
Expr one = ones(shape={1, 1}, lazy_shape=print_shape);
|
||||
#endif
|
||||
|
||||
return e / sum(e, args...);
|
||||
}
|
||||
|
||||
|
@ -118,8 +118,14 @@ struct LogNodeOp : public UnaryNodeOp {
|
||||
|
||||
struct ExpNodeOp : public UnaryNodeOp {
|
||||
template <typename ...Args>
|
||||
ExpNodeOp(Args ...args)
|
||||
: UnaryNodeOp(args...) { }
|
||||
ExpNodeOp(ChainPtr a, Args ...args)
|
||||
: UnaryNodeOp(a, keywords::shape=newShape(a),
|
||||
args...) { }
|
||||
|
||||
Shape newShape(ChainPtr a) {
|
||||
Shape shape = a->shape();
|
||||
return shape;
|
||||
}
|
||||
|
||||
void forward() {
|
||||
Element(_1 = Exp(_2), val_, a_->val());
|
||||
|
94
src/mnist.h
Normal file
94
src/mnist.h
Normal file
@ -0,0 +1,94 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
namespace datasets {
|
||||
namespace mnist {
|
||||
|
||||
typedef unsigned char uchar;
|
||||
|
||||
auto reverseInt = [](int i) {
|
||||
unsigned char c1, c2, c3, c4;
|
||||
c1 = i & 255, c2 = (i >> 8) & 255, c3 = (i >> 16) & 255, c4 = (i >> 24) & 255;
|
||||
return ((int)c1 << 24) + ((int)c2 << 16) + ((int)c3 << 8) + c4;
|
||||
};
|
||||
|
||||
std::vector<std::vector<float>> ReadImages(const std::string& full_path) {
|
||||
std::ifstream file(full_path);
|
||||
|
||||
if (! file.is_open())
|
||||
throw std::runtime_error("Cannot open file `" + full_path + "`!");
|
||||
|
||||
int magic_number = 0, n_rows = 0, n_cols = 0;
|
||||
|
||||
file.read((char *)&magic_number, sizeof(magic_number));
|
||||
magic_number = reverseInt(magic_number);
|
||||
|
||||
if (magic_number != 2051)
|
||||
throw std::runtime_error("Invalid MNIST image file!");
|
||||
|
||||
int number_of_images = 0;
|
||||
file.read((char *)&number_of_images, sizeof(number_of_images)), number_of_images = reverseInt(number_of_images);
|
||||
file.read((char *)&n_rows, sizeof(n_rows)), n_rows = reverseInt(n_rows);
|
||||
file.read((char *)&n_cols, sizeof(n_cols)), n_cols = reverseInt(n_cols);
|
||||
|
||||
int image_size = n_rows * n_cols;
|
||||
std::vector<std::vector<float>> _dataset(number_of_images, std::vector<float>(image_size));
|
||||
unsigned char pixel = 0;
|
||||
|
||||
for (int i = 0; i < number_of_images; i++) {
|
||||
for (int j = 0; j < image_size; j++) {
|
||||
file.read((char*)&pixel, sizeof(pixel));
|
||||
_dataset[i][j] = pixel / 255.0f;
|
||||
}
|
||||
}
|
||||
return _dataset;
|
||||
}
|
||||
|
||||
std::vector<int> ReadLabels(const std::string& full_path) {
|
||||
std::ifstream file(full_path);
|
||||
|
||||
if (! file.is_open())
|
||||
throw std::runtime_error("Cannot open file `" + full_path + "`!");
|
||||
|
||||
int magic_number = 0;
|
||||
file.read((char *)&magic_number, sizeof(magic_number));
|
||||
magic_number = reverseInt(magic_number);
|
||||
|
||||
if (magic_number != 2049)
|
||||
throw std::runtime_error("Invalid MNIST label file!");
|
||||
|
||||
int number_of_labels = 0;
|
||||
file.read((char *)&number_of_labels, sizeof(number_of_labels)), number_of_labels = reverseInt(number_of_labels);
|
||||
|
||||
std::vector<int> _dataset(number_of_labels);
|
||||
for (int i = 0; i < number_of_labels; i++) {
|
||||
file.read((char*)&_dataset[i], 1);
|
||||
}
|
||||
|
||||
return _dataset;
|
||||
}
|
||||
|
||||
} // namespace mnist
|
||||
} // namespace datasets
|
||||
|
||||
|
||||
//int main(int argc, const char *argv[]) {
|
||||
//auto images = datasets::mnist::ReadImages("t10k-images-idx3-ubyte");
|
||||
//auto labels = datasets::mnist::ReadLabels("t10k-labels-idx1-ubyte");
|
||||
|
||||
//std::cout
|
||||
//<< "Number of images: " << images.size() << std::endl
|
||||
//<< "Image size: " << images[0].size() << std::endl;
|
||||
|
||||
//for (int i = 0; i < 3; i++) {
|
||||
//for (int j = 0; j < images[i].size(); j++) {
|
||||
//std::cout << images[i][j] << ",";
|
||||
//}
|
||||
//std::cout << " label=" << (int)labels[i] << std::endl;
|
||||
//}
|
||||
//return 0;
|
||||
//}
|
92
src/test.cu
92
src/test.cu
@ -1,26 +1,28 @@
|
||||
|
||||
#include "marian.h"
|
||||
#include "mnist.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
/*auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte");*/
|
||||
/*auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");*/
|
||||
/*std::cerr << images.size() << " " << images[0].size() << std::endl;*/
|
||||
|
||||
using namespace marian;
|
||||
using namespace keywords;
|
||||
|
||||
/*
|
||||
Expr x = input(shape={whatevs, 784}, name="X");
|
||||
Expr y = input(shape={whatevs, 10}, name="Y");
|
||||
|
||||
Expr w = param(shape={784, 10}, name="W0");
|
||||
Expr b = param(shape={1, 10}, name="b0");
|
||||
|
||||
Expr n5 = dot(x, w);
|
||||
Expr n6 = n5 + b;
|
||||
Expr lr = softmax(n6, axis=1, name="pred");
|
||||
auto scores = dot(x, w) + b;
|
||||
auto lr = softmax(scores, axis=1, name="pred");
|
||||
auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
||||
cerr << "lr=" << lr.Debug() << endl;
|
||||
|
||||
Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
||||
|
||||
Tensor tx({500, 784}, 1);
|
||||
Tensor ty({500, 10}, 1);
|
||||
@ -31,51 +33,47 @@ int main(int argc, char** argv) {
|
||||
y = ty;
|
||||
|
||||
graph.forward(500);
|
||||
|
||||
std::cerr << "Result: ";
|
||||
for (auto val : scores.val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
std::cerr << "Result: ";
|
||||
for (auto val : lr.val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
std::cerr << "Log-likelihood: ";
|
||||
for (auto val : graph.val().shape()) {
|
||||
std::cerr << val << " ";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
|
||||
graph.backward();
|
||||
|
||||
//std::cerr << graph["pred"].val()[0] << std::endl;
|
||||
|
||||
*/
|
||||
|
||||
Expr x = input(shape={whatevs, 2}, name="X");
|
||||
Expr y = input(shape={whatevs, 2}, name="Y");
|
||||
|
||||
Expr w = param(shape={2, 1}, name="W0");
|
||||
Expr b = param(shape={1, 1}, name="b0");
|
||||
|
||||
Expr n5 = dot(x, w);
|
||||
Expr n6 = n5 + b;
|
||||
Expr lr = softmax(n6, axis=1, name="pred");
|
||||
cerr << "lr=" << lr.Debug() << endl;
|
||||
|
||||
Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
||||
|
||||
Tensor tx({4, 2}, 1);
|
||||
Tensor ty({4, 1}, 1);
|
||||
cerr << "tx=" << tx.Debug() << endl;
|
||||
cerr << "ty=" << ty.Debug() << endl;
|
||||
|
||||
tx.Load("../examples/xor/train.txt");
|
||||
ty.Load("../examples/xor/label.txt");
|
||||
|
||||
|
||||
//hook0(graph);
|
||||
//graph.autodiff();
|
||||
//std::cerr << graph["cost"].val()[0] << std::endl;
|
||||
#if 0
|
||||
hook0(graph);
|
||||
graph.autodiff();
|
||||
std::cerr << graph["cost"].val()[0] << std::endl;
|
||||
//hook1(graph);
|
||||
//for(auto p : graph.params()) {
|
||||
// auto update = _1 = _1 - alpha * _2;
|
||||
// Element(update, p.val(), p.grad());
|
||||
//}
|
||||
//hook2(graph);
|
||||
//
|
||||
//auto opt = adadelta(cost_function=cost,
|
||||
// eta=0.9, gamma=0.1,
|
||||
// set_batch=set,
|
||||
// before_update=before,
|
||||
// after_update=after,
|
||||
// set_valid=valid,
|
||||
// validation_freq=100,
|
||||
// verbose=1, epochs=3, early_stopping=10);
|
||||
//opt.run();
|
||||
for(auto p : graph.params()) {
|
||||
auto update = _1 = _1 - alpha * _2;
|
||||
Element(update, p.val(), p.grad());
|
||||
}
|
||||
hook2(graph);
|
||||
|
||||
auto opt = adadelta(cost_function=cost,
|
||||
eta=0.9, gamma=0.1,
|
||||
set_batch=set,
|
||||
before_update=before,
|
||||
after_update=after,
|
||||
set_valid=valid,
|
||||
validation_freq=100,
|
||||
verbose=1, epochs=3, early_stopping=10);
|
||||
opt.run();
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user