mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Adam optimizer
This commit is contained in:
parent
cbc29a0ab1
commit
c46012a6c1
53
src/sgd.h
53
src/sgd.h
@ -8,7 +8,7 @@ namespace marian {
|
|||||||
|
|
||||||
class Sgd {
|
class Sgd {
|
||||||
public:
|
public:
|
||||||
Sgd(float eta=0.1) : eta_(eta) {}
|
Sgd(float eta=0.01) : eta_(eta) {}
|
||||||
|
|
||||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||||
graph.backprop(batchSize);
|
graph.backprop(batchSize);
|
||||||
@ -24,11 +24,11 @@ class Sgd {
|
|||||||
|
|
||||||
class Adagrad {
|
class Adagrad {
|
||||||
public:
|
public:
|
||||||
Adagrad(float eta=0.1) : eta_(eta) {}
|
Adagrad(float eta=0.01, float eps=10e-8)
|
||||||
|
: eta_(eta), eps_(eps) {}
|
||||||
|
|
||||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||||
float fudgeFactor = 1e-6;
|
graph.backprop(batchSize);
|
||||||
graph.backprop(batchSize);
|
|
||||||
|
|
||||||
if(history_.size() < graph.params().size())
|
if(history_.size() < graph.params().size())
|
||||||
for(auto& param : graph.params())
|
for(auto& param : graph.params())
|
||||||
@ -37,7 +37,7 @@ class Adagrad {
|
|||||||
auto it = history_.begin();
|
auto it = history_.begin();
|
||||||
for(auto& param : graph.params()) {
|
for(auto& param : graph.params()) {
|
||||||
Element(_1 += _2 * _2, *it, param.grad());
|
Element(_1 += _2 * _2, *it, param.grad());
|
||||||
Element(_1 -= eta_ / (fudgeFactor + Sqrt(_2)) * _3,
|
Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3,
|
||||||
param.val(), *it, param.grad());
|
param.val(), *it, param.grad());
|
||||||
it++;
|
it++;
|
||||||
}
|
}
|
||||||
@ -45,7 +45,50 @@ class Adagrad {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
float eta_;
|
float eta_;
|
||||||
|
float eps_;
|
||||||
std::vector<Tensor> history_;
|
std::vector<Tensor> history_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class Adam {
|
||||||
|
public:
|
||||||
|
Adam(float eta=0.01, float beta1=0.999, float beta2=0.999, float eps=10e-8)
|
||||||
|
: eta_(eta), beta1_(beta1), beta2_(beta2), eps_(eps), t_(0) {}
|
||||||
|
|
||||||
|
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||||
|
graph.backprop(batchSize);
|
||||||
|
|
||||||
|
if(mt_.size() < graph.params().size()) {
|
||||||
|
for(auto& param : graph.params()) {
|
||||||
|
mt_.emplace_back(Tensor(param.grad().shape(), 0));
|
||||||
|
vt_.emplace_back(Tensor(param.grad().shape(), 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t_++;
|
||||||
|
float denom1 = 1 - pow(beta1_, t_);
|
||||||
|
float denom2 = 1 - pow(beta2_, t_);
|
||||||
|
|
||||||
|
auto mtIt = mt_.begin();
|
||||||
|
auto vtIt = vt_.begin();
|
||||||
|
for(auto& param : graph.params()) {
|
||||||
|
Element(_1 = beta1_ * _2 + (1 - beta1_) * _3,
|
||||||
|
*mtIt, *mtIt, param.grad());
|
||||||
|
Element(_1 = beta2_ * _2 + (1 - beta2_) * _3 * _3,
|
||||||
|
*vtIt, *vtIt, param.grad());
|
||||||
|
Element(_1 -= eta_ * (_2 / denom1) / (Sqrt(_3 / denom2) + eps_),
|
||||||
|
param.val(), *mtIt, *vtIt);
|
||||||
|
mtIt++; vtIt++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
float eta_;
|
||||||
|
float beta1_;
|
||||||
|
float beta2_;
|
||||||
|
float eps_;
|
||||||
|
size_t t_;
|
||||||
|
std::vector<Tensor> mt_;
|
||||||
|
std::vector<Tensor> vt_;
|
||||||
|
};
|
||||||
|
|
||||||
}
|
}
|
12
src/test.cu
12
src/test.cu
@ -72,10 +72,10 @@ int main(int argc, char** argv) {
|
|||||||
Y.emplace_back(g.input(shape={batch_size, output_size}));
|
Y.emplace_back(g.input(shape={batch_size, output_size}));
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr Wxh = g.param(shape={input_size, hidden_size}, init=uniform(), name="Wxh");
|
Expr Wxh = g.param(shape={input_size, hidden_size}, name="Wxh");
|
||||||
Expr Whh = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh");
|
Expr Whh = g.param(shape={hidden_size, hidden_size}, name="Whh");
|
||||||
Expr bh = g.param(shape={1, hidden_size}, init=uniform(), name="bh");
|
Expr bh = g.param(shape={1, hidden_size}, name="bh");
|
||||||
Expr h0 = g.param(shape={1, hidden_size}, init=uniform(), name="h0");
|
Expr h0 = g.param(shape={1, hidden_size}, name="h0");
|
||||||
|
|
||||||
// read parallel corpus from file
|
// read parallel corpus from file
|
||||||
std::fstream sourceFile("../examples/mt/dev/newstest2013.de");
|
std::fstream sourceFile("../examples/mt/dev/newstest2013.de");
|
||||||
@ -94,8 +94,8 @@ int main(int argc, char** argv) {
|
|||||||
H.emplace_back(tanh(dot(X[t], Wxh) + dot(H[t-1], Whh) + bh));
|
H.emplace_back(tanh(dot(X[t], Wxh) + dot(H[t-1], Whh) + bh));
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr Why = g.param(shape={hidden_size, output_size}, init=uniform(), name="Why");
|
Expr Why = g.param(shape={hidden_size, output_size}, name="Why");
|
||||||
Expr by = g.param(shape={1, output_size}, init=uniform(), name="by");
|
Expr by = g.param(shape={1, output_size}, name="by");
|
||||||
|
|
||||||
std::cerr << "Building output layer..." << std::endl;
|
std::cerr << "Building output layer..." << std::endl;
|
||||||
std::vector<Expr> Yp;
|
std::vector<Expr> Yp;
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
#include "marian.h"
|
#include "marian.h"
|
||||||
#include "mnist.h"
|
#include "mnist.h"
|
||||||
#include "npz_converter.h"
|
#include "npz_converter.h"
|
||||||
|
#include "sgd.h"
|
||||||
|
|
||||||
using namespace marian;
|
using namespace marian;
|
||||||
using namespace keywords;
|
using namespace keywords;
|
||||||
@ -26,13 +27,23 @@ ExpressionGraph build_graph() {
|
|||||||
auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
|
auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
|
||||||
auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
|
auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
|
||||||
|
|
||||||
auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
|
//auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
|
||||||
init=from_vector(wData)), "w");
|
// init=from_vector(wData)), "w");
|
||||||
auto b = named(g.param(shape={1, LABEL_SIZE},
|
//auto b = named(g.param(shape={1, LABEL_SIZE},
|
||||||
init=from_vector(bData)), "b");
|
// init=from_vector(bData)), "b");
|
||||||
|
auto w1 = named(g.param(shape={IMAGE_SIZE, 100},
|
||||||
|
init=uniform()), "w1");
|
||||||
|
auto b1 = named(g.param(shape={1, 100},
|
||||||
|
init=uniform()), "b1");
|
||||||
|
auto w2 = named(g.param(shape={100, LABEL_SIZE},
|
||||||
|
init=uniform()), "w2");
|
||||||
|
auto b2 = named(g.param(shape={1, LABEL_SIZE},
|
||||||
|
init=uniform()), "b2");
|
||||||
|
|
||||||
|
|
||||||
|
auto lr = tanh(dot(x, w1) + b1);
|
||||||
auto probs = named(
|
auto probs = named(
|
||||||
softmax(dot(x, w) + b), //, axis=1),
|
softmax(dot(lr, w2) + b2), //, axis=1),
|
||||||
"probs"
|
"probs"
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -60,10 +71,16 @@ int main(int argc, char** argv) {
|
|||||||
g["x"] = (xt << testImages);
|
g["x"] = (xt << testImages);
|
||||||
g["y"] = (yt << testLabels);
|
g["y"] = (yt << testLabels);
|
||||||
|
|
||||||
std::cout << g.graphviz() << std::endl;
|
Adam opt;
|
||||||
|
for(size_t j = 0; j < 10; ++j) {
|
||||||
|
for(size_t i = 0; i < 60; ++i) {
|
||||||
|
opt(g, BATCH_SIZE);
|
||||||
|
}
|
||||||
|
std::cerr << g["cost"].val()[0] << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
//std::cout << g.graphviz() << std::endl;
|
||||||
|
|
||||||
g.forward(BATCH_SIZE);
|
|
||||||
|
|
||||||
std::vector<float> results;
|
std::vector<float> results;
|
||||||
results << g["probs"].val();
|
results << g["probs"].val();
|
||||||
|
|
||||||
@ -78,33 +95,5 @@ int main(int argc, char** argv) {
|
|||||||
acc += (correct == proposed);
|
acc += (correct == proposed);
|
||||||
}
|
}
|
||||||
std::cerr << "Cost: " << g["cost"].val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
|
std::cerr << "Cost: " << g["cost"].val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
|
||||||
|
|
||||||
float eta = 0.1;
|
|
||||||
for (size_t j = 0; j < 10; ++j) {
|
|
||||||
for(size_t i = 0; i < 60; ++i) {
|
|
||||||
g.backward();
|
|
||||||
|
|
||||||
auto update_rule = _1 -= eta * _2;
|
|
||||||
for(auto param : g.params())
|
|
||||||
Element(update_rule, param.val(), param.grad());
|
|
||||||
|
|
||||||
g.forward(BATCH_SIZE);
|
|
||||||
}
|
|
||||||
std::cerr << "Epoch: " << j << std::endl;
|
|
||||||
std::vector<float> results;
|
|
||||||
results << g["probs"].val();
|
|
||||||
|
|
||||||
size_t acc = 0;
|
|
||||||
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
|
|
||||||
size_t correct = 0;
|
|
||||||
size_t proposed = 0;
|
|
||||||
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
|
||||||
if (testLabels[i+j]) correct = j;
|
|
||||||
if (results[i + j] > results[i + proposed]) proposed = j;
|
|
||||||
}
|
|
||||||
acc += (correct == proposed);
|
|
||||||
}
|
|
||||||
std::cerr << "Cost: " << g["cost"].val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
|
|
||||||
}
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user