mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Adam optimizer
This commit is contained in:
parent
cbc29a0ab1
commit
c46012a6c1
51
src/sgd.h
51
src/sgd.h
@ -8,7 +8,7 @@ namespace marian {
|
||||
|
||||
class Sgd {
|
||||
public:
|
||||
Sgd(float eta=0.1) : eta_(eta) {}
|
||||
Sgd(float eta=0.01) : eta_(eta) {}
|
||||
|
||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||
graph.backprop(batchSize);
|
||||
@ -24,10 +24,10 @@ class Sgd {
|
||||
|
||||
class Adagrad {
|
||||
public:
|
||||
Adagrad(float eta=0.1) : eta_(eta) {}
|
||||
Adagrad(float eta=0.01, float eps=10e-8)
|
||||
: eta_(eta), eps_(eps) {}
|
||||
|
||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||
float fudgeFactor = 1e-6;
|
||||
graph.backprop(batchSize);
|
||||
|
||||
if(history_.size() < graph.params().size())
|
||||
@ -37,7 +37,7 @@ class Adagrad {
|
||||
auto it = history_.begin();
|
||||
for(auto& param : graph.params()) {
|
||||
Element(_1 += _2 * _2, *it, param.grad());
|
||||
Element(_1 -= eta_ / (fudgeFactor + Sqrt(_2)) * _3,
|
||||
Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3,
|
||||
param.val(), *it, param.grad());
|
||||
it++;
|
||||
}
|
||||
@ -45,7 +45,50 @@ class Adagrad {
|
||||
|
||||
private:
|
||||
float eta_;
|
||||
float eps_;
|
||||
std::vector<Tensor> history_;
|
||||
};
|
||||
|
||||
class Adam {
|
||||
public:
|
||||
Adam(float eta=0.01, float beta1=0.999, float beta2=0.999, float eps=10e-8)
|
||||
: eta_(eta), beta1_(beta1), beta2_(beta2), eps_(eps), t_(0) {}
|
||||
|
||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||
graph.backprop(batchSize);
|
||||
|
||||
if(mt_.size() < graph.params().size()) {
|
||||
for(auto& param : graph.params()) {
|
||||
mt_.emplace_back(Tensor(param.grad().shape(), 0));
|
||||
vt_.emplace_back(Tensor(param.grad().shape(), 0));
|
||||
}
|
||||
}
|
||||
|
||||
t_++;
|
||||
float denom1 = 1 - pow(beta1_, t_);
|
||||
float denom2 = 1 - pow(beta2_, t_);
|
||||
|
||||
auto mtIt = mt_.begin();
|
||||
auto vtIt = vt_.begin();
|
||||
for(auto& param : graph.params()) {
|
||||
Element(_1 = beta1_ * _2 + (1 - beta1_) * _3,
|
||||
*mtIt, *mtIt, param.grad());
|
||||
Element(_1 = beta2_ * _2 + (1 - beta2_) * _3 * _3,
|
||||
*vtIt, *vtIt, param.grad());
|
||||
Element(_1 -= eta_ * (_2 / denom1) / (Sqrt(_3 / denom2) + eps_),
|
||||
param.val(), *mtIt, *vtIt);
|
||||
mtIt++; vtIt++;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
float eta_;
|
||||
float beta1_;
|
||||
float beta2_;
|
||||
float eps_;
|
||||
size_t t_;
|
||||
std::vector<Tensor> mt_;
|
||||
std::vector<Tensor> vt_;
|
||||
};
|
||||
|
||||
}
|
12
src/test.cu
12
src/test.cu
@ -72,10 +72,10 @@ int main(int argc, char** argv) {
|
||||
Y.emplace_back(g.input(shape={batch_size, output_size}));
|
||||
}
|
||||
|
||||
Expr Wxh = g.param(shape={input_size, hidden_size}, init=uniform(), name="Wxh");
|
||||
Expr Whh = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh");
|
||||
Expr bh = g.param(shape={1, hidden_size}, init=uniform(), name="bh");
|
||||
Expr h0 = g.param(shape={1, hidden_size}, init=uniform(), name="h0");
|
||||
Expr Wxh = g.param(shape={input_size, hidden_size}, name="Wxh");
|
||||
Expr Whh = g.param(shape={hidden_size, hidden_size}, name="Whh");
|
||||
Expr bh = g.param(shape={1, hidden_size}, name="bh");
|
||||
Expr h0 = g.param(shape={1, hidden_size}, name="h0");
|
||||
|
||||
// read parallel corpus from file
|
||||
std::fstream sourceFile("../examples/mt/dev/newstest2013.de");
|
||||
@ -94,8 +94,8 @@ int main(int argc, char** argv) {
|
||||
H.emplace_back(tanh(dot(X[t], Wxh) + dot(H[t-1], Whh) + bh));
|
||||
}
|
||||
|
||||
Expr Why = g.param(shape={hidden_size, output_size}, init=uniform(), name="Why");
|
||||
Expr by = g.param(shape={1, output_size}, init=uniform(), name="by");
|
||||
Expr Why = g.param(shape={hidden_size, output_size}, name="Why");
|
||||
Expr by = g.param(shape={1, output_size}, name="by");
|
||||
|
||||
std::cerr << "Building output layer..." << std::endl;
|
||||
std::vector<Expr> Yp;
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include "marian.h"
|
||||
#include "mnist.h"
|
||||
#include "npz_converter.h"
|
||||
#include "sgd.h"
|
||||
|
||||
using namespace marian;
|
||||
using namespace keywords;
|
||||
@ -26,13 +27,23 @@ ExpressionGraph build_graph() {
|
||||
auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
|
||||
auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
|
||||
|
||||
auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
|
||||
init=from_vector(wData)), "w");
|
||||
auto b = named(g.param(shape={1, LABEL_SIZE},
|
||||
init=from_vector(bData)), "b");
|
||||
//auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
|
||||
// init=from_vector(wData)), "w");
|
||||
//auto b = named(g.param(shape={1, LABEL_SIZE},
|
||||
// init=from_vector(bData)), "b");
|
||||
auto w1 = named(g.param(shape={IMAGE_SIZE, 100},
|
||||
init=uniform()), "w1");
|
||||
auto b1 = named(g.param(shape={1, 100},
|
||||
init=uniform()), "b1");
|
||||
auto w2 = named(g.param(shape={100, LABEL_SIZE},
|
||||
init=uniform()), "w2");
|
||||
auto b2 = named(g.param(shape={1, LABEL_SIZE},
|
||||
init=uniform()), "b2");
|
||||
|
||||
|
||||
auto lr = tanh(dot(x, w1) + b1);
|
||||
auto probs = named(
|
||||
softmax(dot(x, w) + b), //, axis=1),
|
||||
softmax(dot(lr, w2) + b2), //, axis=1),
|
||||
"probs"
|
||||
);
|
||||
|
||||
@ -60,37 +71,16 @@ int main(int argc, char** argv) {
|
||||
g["x"] = (xt << testImages);
|
||||
g["y"] = (yt << testLabels);
|
||||
|
||||
std::cout << g.graphviz() << std::endl;
|
||||
|
||||
g.forward(BATCH_SIZE);
|
||||
|
||||
std::vector<float> results;
|
||||
results << g["probs"].val();
|
||||
|
||||
size_t acc = 0;
|
||||
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
|
||||
size_t correct = 0;
|
||||
size_t proposed = 0;
|
||||
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
||||
if (testLabels[i+j]) correct = j;
|
||||
if (results[i + j] > results[i + proposed]) proposed = j;
|
||||
}
|
||||
acc += (correct == proposed);
|
||||
}
|
||||
std::cerr << "Cost: " << g["cost"].val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
|
||||
|
||||
float eta = 0.1;
|
||||
for (size_t j = 0; j < 10; ++j) {
|
||||
Adam opt;
|
||||
for(size_t j = 0; j < 10; ++j) {
|
||||
for(size_t i = 0; i < 60; ++i) {
|
||||
g.backward();
|
||||
|
||||
auto update_rule = _1 -= eta * _2;
|
||||
for(auto param : g.params())
|
||||
Element(update_rule, param.val(), param.grad());
|
||||
|
||||
g.forward(BATCH_SIZE);
|
||||
opt(g, BATCH_SIZE);
|
||||
}
|
||||
std::cerr << "Epoch: " << j << std::endl;
|
||||
std::cerr << g["cost"].val()[0] << std::endl;
|
||||
}
|
||||
|
||||
//std::cout << g.graphviz() << std::endl;
|
||||
|
||||
std::vector<float> results;
|
||||
results << g["probs"].val();
|
||||
|
||||
@ -105,6 +95,5 @@ int main(int argc, char** argv) {
|
||||
acc += (correct == proposed);
|
||||
}
|
||||
std::cerr << "Cost: " << g["cost"].val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user