Merge ../Marian

This commit is contained in:
Hieu Hoang 2016-09-15 11:02:16 +02:00
commit aa5a10bfcd
5 changed files with 23 additions and 23 deletions

View File

@@ -32,8 +32,8 @@ inline Expr zeroes(Args ...args) {
/*********************************************************/ /*********************************************************/
inline Expr sigmoid(Expr a) { inline Expr logit(Expr a) {
return Expr(new SigmoidNodeOp(a)); return Expr(new LogitNodeOp(a));
} }
inline Expr tanh(Expr a) { inline Expr tanh(Expr a) {

View File

@@ -76,9 +76,9 @@ struct UnaryNodeOp : public Node {
a_(a) {} a_(a) {}
}; };
struct SigmoidNodeOp : public UnaryNodeOp { struct LogitNodeOp : public UnaryNodeOp {
template <typename ...Args> template <typename ...Args>
SigmoidNodeOp(Args ...args) LogitNodeOp(Args ...args)
: UnaryNodeOp(args...) { } : UnaryNodeOp(args...) { }
void forward() { void forward() {

View File

@@ -22,7 +22,7 @@ void ones(Tensor t) {
void randreal(Tensor t) { void randreal(Tensor t) {
std::random_device device; std::random_device device;
std::default_random_engine engine(device()); std::default_random_engine engine(device());
std::uniform_real_distribution<> dist(0, 0.1); std::uniform_real_distribution<> dist(0, 0.01);
auto gen = std::bind(dist, engine); auto gen = std::bind(dist, engine);
std::vector<float> vals(t.size()); std::vector<float> vals(t.size());

View File

@@ -221,6 +221,7 @@ class Tensor {
} }
void get(std::vector<float> &vout) const { void get(std::vector<float> &vout) const {
vout.resize(size());
pimpl_->get(vout.begin()); pimpl_->get(vout.begin());
} }
}; };

View File

@@ -41,7 +41,7 @@ int main(int argc, char** argv) {
init=[bData](Tensor t) { t.set(bData); }); init=[bData](Tensor t) { t.set(bData); });
auto probs = softmax(dot(x, w) + b, axis=1); auto probs = softmax(dot(x, w) + b, axis=1);
auto graph = -mean(sum(y * log(probs), axis=1), axis=0); auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
std::cerr << "Done." << std::endl; std::cerr << "Done." << std::endl;
@@ -51,50 +51,49 @@ int main(int argc, char** argv) {
x = xt << testImages; x = xt << testImages;
y = yt << testLabels; y = yt << testLabels;
graph.forward(BATCH_SIZE); cost.forward(BATCH_SIZE);
auto results = probs.val();
std::vector<float> resultsv(results.size()); std::vector<float> results;
resultsv << results; results << probs.val();
size_t acc = 0; size_t acc = 0;
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) { for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
size_t correct = 0; size_t correct = 0;
size_t probsed = 0; size_t proposed = 0;
for (size_t j = 0; j < LABEL_SIZE; ++j) { for (size_t j = 0; j < LABEL_SIZE; ++j) {
if (testLabels[i+j]) correct = j; if (testLabels[i+j]) correct = j;
if (resultsv[i + j] > resultsv[i + probsed]) probsed = j; if (results[i + j] > results[i + proposed]) proposed = j;
} }
acc += (correct == probsed); acc += (correct == proposed);
} }
std::cerr << "Cost: " << graph.val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl; std::cerr << "Cost: " << cost.val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
float eta = 0.1; float eta = 0.1;
for (size_t j = 0; j < 10; ++j) { for (size_t j = 0; j < 10; ++j) {
for(size_t i = 0; i < 60; ++i) { for(size_t i = 0; i < 60; ++i) {
graph.backward(); cost.backward();
auto update_rule = _1 -= eta * _2; auto update_rule = _1 -= eta * _2;
Element(update_rule, w.val(), w.grad()); Element(update_rule, w.val(), w.grad());
Element(update_rule, b.val(), b.grad()); Element(update_rule, b.val(), b.grad());
graph.forward(BATCH_SIZE); cost.forward(BATCH_SIZE);
} }
std::cerr << "Epoch: " << j << std::endl; std::cerr << "Epoch: " << j << std::endl;
auto results = probs.val(); std::vector<float> results;
std::vector<float> resultsv(results.size()); results << probs.val();
resultsv << results;
size_t acc = 0; size_t acc = 0;
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) { for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
size_t correct = 0; size_t correct = 0;
size_t probsed = 0; size_t proposed = 0;
for (size_t j = 0; j < LABEL_SIZE; ++j) { for (size_t j = 0; j < LABEL_SIZE; ++j) {
if (testLabels[i+j]) correct = j; if (testLabels[i+j]) correct = j;
if (resultsv[i + j] > resultsv[i + probsed]) probsed = j; if (results[i + j] > results[i + proposed]) proposed = j;
} }
acc += (correct == probsed); acc += (correct == proposed);
} }
std::cerr << "Cost: " << graph.val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl; std::cerr << "Cost: " << cost.val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
} }
return 0; return 0;
} }