From c46012a6c1777d36341260e3f71207aabe2b7981 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Fri, 16 Sep 2016 23:26:54 +0200 Subject: [PATCH] Adam optimizer --- src/sgd.h | 53 +++++++++++++++++++++++++++++++++---- src/test.cu | 12 ++++----- src/validate_mnist.cu | 61 ++++++++++++++++++------------------------- 3 files changed, 79 insertions(+), 47 deletions(-) diff --git a/src/sgd.h b/src/sgd.h index fe0470b1..dbf14552 100644 --- a/src/sgd.h +++ b/src/sgd.h @@ -8,7 +8,7 @@ namespace marian { class Sgd { public: - Sgd(float eta=0.1) : eta_(eta) {} + Sgd(float eta=0.01) : eta_(eta) {} void operator()(ExpressionGraph& graph, int batchSize) { graph.backprop(batchSize); @@ -24,11 +24,11 @@ class Sgd { class Adagrad { public: - Adagrad(float eta=0.1) : eta_(eta) {} + Adagrad(float eta=0.01, float eps=1e-8) + : eta_(eta), eps_(eps) {} void operator()(ExpressionGraph& graph, int batchSize) { - float fudgeFactor = 1e-6; - graph.backprop(batchSize); + graph.backprop(batchSize); if(history_.size() < graph.params().size()) for(auto& param : graph.params()) @@ -37,7 +37,7 @@ class Adagrad { auto it = history_.begin(); for(auto& param : graph.params()) { Element(_1 += _2 * _2, *it, param.grad()); - Element(_1 -= eta_ / (fudgeFactor + Sqrt(_2)) * _3, + Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3, param.val(), *it, param.grad()); it++; } @@ -45,7 +45,50 @@ class Adagrad { private: float eta_; + float eps_; std::vector history_; }; +class Adam { + public: + Adam(float eta=0.01, float beta1=0.9, float beta2=0.999, float eps=1e-8) + : eta_(eta), beta1_(beta1), beta2_(beta2), eps_(eps), t_(0) {} + + void operator()(ExpressionGraph& graph, int batchSize) { + graph.backprop(batchSize); + + if(mt_.size() < graph.params().size()) { + for(auto& param : graph.params()) { + mt_.emplace_back(Tensor(param.grad().shape(), 0)); + vt_.emplace_back(Tensor(param.grad().shape(), 0)); + } + } + + t_++; + float denom1 = 1 - pow(beta1_, t_); + float denom2 = 1 - pow(beta2_, t_); + 
auto mtIt = mt_.begin(); + auto vtIt = vt_.begin(); + for(auto& param : graph.params()) { + Element(_1 = beta1_ * _2 + (1 - beta1_) * _3, + *mtIt, *mtIt, param.grad()); + Element(_1 = beta2_ * _2 + (1 - beta2_) * _3 * _3, + *vtIt, *vtIt, param.grad()); + Element(_1 -= eta_ * (_2 / denom1) / (Sqrt(_3 / denom2) + eps_), + param.val(), *mtIt, *vtIt); + mtIt++; vtIt++; + } + } + + private: + float eta_; + float beta1_; + float beta2_; + float eps_; + size_t t_; + std::vector mt_; + std::vector vt_; +}; + } \ No newline at end of file diff --git a/src/test.cu b/src/test.cu index 8c7dfc54..0f6a6334 100644 --- a/src/test.cu +++ b/src/test.cu @@ -72,10 +72,10 @@ int main(int argc, char** argv) { Y.emplace_back(g.input(shape={batch_size, output_size})); } - Expr Wxh = g.param(shape={input_size, hidden_size}, init=uniform(), name="Wxh"); - Expr Whh = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh"); - Expr bh = g.param(shape={1, hidden_size}, init=uniform(), name="bh"); - Expr h0 = g.param(shape={1, hidden_size}, init=uniform(), name="h0"); + Expr Wxh = g.param(shape={input_size, hidden_size}, name="Wxh"); + Expr Whh = g.param(shape={hidden_size, hidden_size}, name="Whh"); + Expr bh = g.param(shape={1, hidden_size}, name="bh"); + Expr h0 = g.param(shape={1, hidden_size}, name="h0"); // read parallel corpus from file std::fstream sourceFile("../examples/mt/dev/newstest2013.de"); @@ -94,8 +94,8 @@ int main(int argc, char** argv) { H.emplace_back(tanh(dot(X[t], Wxh) + dot(H[t-1], Whh) + bh)); } - Expr Why = g.param(shape={hidden_size, output_size}, init=uniform(), name="Why"); - Expr by = g.param(shape={1, output_size}, init=uniform(), name="by"); + Expr Why = g.param(shape={hidden_size, output_size}, name="Why"); + Expr by = g.param(shape={1, output_size}, name="by"); std::cerr << "Building output layer..." 
<< std::endl; std::vector Yp; diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index 690d6f40..94107cc2 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -2,6 +2,7 @@ #include "marian.h" #include "mnist.h" #include "npz_converter.h" +#include "sgd.h" using namespace marian; using namespace keywords; @@ -26,13 +27,23 @@ ExpressionGraph build_graph() { auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x"); auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y"); - auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE}, - init=from_vector(wData)), "w"); - auto b = named(g.param(shape={1, LABEL_SIZE}, - init=from_vector(bData)), "b"); + //auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE}, + // init=from_vector(wData)), "w"); + //auto b = named(g.param(shape={1, LABEL_SIZE}, + // init=from_vector(bData)), "b"); + auto w1 = named(g.param(shape={IMAGE_SIZE, 100}, + init=uniform()), "w1"); + auto b1 = named(g.param(shape={1, 100}, + init=uniform()), "b1"); + auto w2 = named(g.param(shape={100, LABEL_SIZE}, + init=uniform()), "w2"); + auto b2 = named(g.param(shape={1, LABEL_SIZE}, + init=uniform()), "b2"); + + auto lr = tanh(dot(x, w1) + b1); auto probs = named( - softmax(dot(x, w) + b), //, axis=1), + softmax(dot(lr, w2) + b2), //, axis=1), "probs" ); @@ -60,10 +71,16 @@ int main(int argc, char** argv) { g["x"] = (xt << testImages); g["y"] = (yt << testLabels); - std::cout << g.graphviz() << std::endl; + Adam opt; + for(size_t j = 0; j < 10; ++j) { + for(size_t i = 0; i < 60; ++i) { + opt(g, BATCH_SIZE); + } + std::cerr << g["cost"].val()[0] << std::endl; + } + + //std::cout << g.graphviz() << std::endl; - g.forward(BATCH_SIZE); - std::vector results; results << g["probs"].val(); @@ -78,33 +95,5 @@ int main(int argc, char** argv) { acc += (correct == proposed); } std::cerr << "Cost: " << g["cost"].val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl; - - float eta = 0.1; - for (size_t j = 0; j < 10; ++j) { - for(size_t i = 
0; i < 60; ++i) { - g.backward(); - - auto update_rule = _1 -= eta * _2; - for(auto param : g.params()) - Element(update_rule, param.val(), param.grad()); - - g.forward(BATCH_SIZE); - } - std::cerr << "Epoch: " << j << std::endl; - std::vector results; - results << g["probs"].val(); - - size_t acc = 0; - for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) { - size_t correct = 0; - size_t proposed = 0; - for (size_t j = 0; j < LABEL_SIZE; ++j) { - if (testLabels[i+j]) correct = j; - if (results[i + j] > results[i + proposed]) proposed = j; - } - acc += (correct == proposed); - } - std::cerr << "Cost: " << g["cost"].val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl; - } return 0; }