diff --git a/src/sgd.h b/src/sgd.h index 65b3c6f1..a977d7f8 100644 --- a/src/sgd.h +++ b/src/sgd.h @@ -8,7 +8,7 @@ namespace marian { class Sgd { public: - Sgd(float eta=0.01) : eta_(eta) {} + Sgd(float eta=0.001) : eta_(eta) {} void operator()(ExpressionGraph& graph, int batchSize) { graph.backprop(batchSize); @@ -25,7 +25,7 @@ class Sgd { // @TODO: Add serialization for historic gradients and parameters class Adagrad { public: - Adagrad(float eta=0.01, float eps=10e-8) + Adagrad(float eta=0.001, float eps=10e-8) : eta_(eta), eps_(eps) {} void operator()(ExpressionGraph& graph, int batchSize) { @@ -37,10 +37,11 @@ class Adagrad { auto gtIt = gt_.begin(); for(auto& param : graph.params()) { - Element(_1 += _2 * _2, *gtIt, param.grad()); + Element(_1 += _2 * _2, + *gtIt, param.grad()); Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3, param.val(), *gtIt, param.grad()); - it++; + gtIt++; } } @@ -51,9 +52,10 @@ class Adagrad { }; // @TODO: Add serialization for historic gradients and parameters +// https://arxiv.org/pdf/1412.6980v8.pdf class Adam { public: - Adam(float eta=0.01, float beta1=0.999, float beta2=0.999, float eps=10e-8) + Adam(float eta=0.001, float beta1=0.9, float beta2=0.999, float eps=1e-8) : eta_(eta), beta1_(beta1), beta2_(beta2), eps_(eps), t_(0) {} void operator()(ExpressionGraph& graph, int batchSize) { diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index 94107cc2..ceb262bd 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -27,23 +27,14 @@ ExpressionGraph build_graph() { auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x"); auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y"); - //auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE}, - // init=from_vector(wData)), "w"); - //auto b = named(g.param(shape={1, LABEL_SIZE}, - // init=from_vector(bData)), "b"); - auto w1 = named(g.param(shape={IMAGE_SIZE, 100}, - init=uniform()), "w1"); - auto b1 = named(g.param(shape={1, 100}, - init=uniform()), "b1"); - auto w2 = 
named(g.param(shape={100, LABEL_SIZE}, - init=uniform()), "w2"); - auto b2 = named(g.param(shape={1, LABEL_SIZE}, - init=uniform()), "b2"); - + auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE}, + init=from_vector(wData)), "w"); + auto b = named(g.param(shape={1, LABEL_SIZE}, + init=from_vector(bData)), "b"); + - auto lr = tanh(dot(x, w1) + b1); auto probs = named( - softmax(dot(lr, w2) + b2), //, axis=1), + softmax(dot(x, w) + b), "probs" ); @@ -72,7 +63,7 @@ int main(int argc, char** argv) { g["y"] = (yt << testLabels); Adam opt; - for(size_t j = 0; j < 10; ++j) { + for(size_t j = 0; j < 20; ++j) { for(size_t i = 0; i < 60; ++i) { opt(g, BATCH_SIZE); }