mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Demonstrate settings that show the need for a safe softmax
This commit is contained in:
parent
c0b676c7c9
commit
6e90198426
12
src/sgd.h
12
src/sgd.h
@ -8,7 +8,7 @@ namespace marian {
|
|||||||
|
|
||||||
class Sgd {
|
class Sgd {
|
||||||
public:
|
public:
|
||||||
Sgd(float eta=0.01) : eta_(eta) {}
|
Sgd(float eta=0.001) : eta_(eta) {}
|
||||||
|
|
||||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||||
graph.backprop(batchSize);
|
graph.backprop(batchSize);
|
||||||
@ -25,7 +25,7 @@ class Sgd {
|
|||||||
// @TODO: Add serialization for historic gradients and parameters
|
// @TODO: Add serialization for historic gradients and parameters
|
||||||
class Adagrad {
|
class Adagrad {
|
||||||
public:
|
public:
|
||||||
Adagrad(float eta=0.01, float eps=10e-8)
|
Adagrad(float eta=0.001, float eps=10e-8)
|
||||||
: eta_(eta), eps_(eps) {}
|
: eta_(eta), eps_(eps) {}
|
||||||
|
|
||||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||||
@ -37,10 +37,11 @@ class Adagrad {
|
|||||||
|
|
||||||
auto gtIt = gt_.begin();
|
auto gtIt = gt_.begin();
|
||||||
for(auto& param : graph.params()) {
|
for(auto& param : graph.params()) {
|
||||||
Element(_1 += _2 * _2, *gtIt, param.grad());
|
Element(_1 += _2 * _2,
|
||||||
|
*gtIt, param.grad());
|
||||||
Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3,
|
Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3,
|
||||||
param.val(), *gtIt, param.grad());
|
param.val(), *gtIt, param.grad());
|
||||||
it++;
|
gtIt++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -51,9 +52,10 @@ class Adagrad {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// @TODO: Add serialization for historic gradients and parameters
|
// @TODO: Add serialization for historic gradients and parameters
|
||||||
|
// https://arxiv.org/pdf/1412.6980v8.pdf
|
||||||
class Adam {
|
class Adam {
|
||||||
public:
|
public:
|
||||||
Adam(float eta=0.01, float beta1=0.999, float beta2=0.999, float eps=10e-8)
|
Adam(float eta=0.001, float beta1=0.999, float beta2=0.999, float eps=10e-8)
|
||||||
: eta_(eta), beta1_(beta1), beta2_(beta2), eps_(eps), t_(0) {}
|
: eta_(eta), beta1_(beta1), beta2_(beta2), eps_(eps), t_(0) {}
|
||||||
|
|
||||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||||
|
@ -27,23 +27,14 @@ ExpressionGraph build_graph() {
|
|||||||
auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
|
auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
|
||||||
auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
|
auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
|
||||||
|
|
||||||
//auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
|
auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
|
||||||
// init=from_vector(wData)), "w");
|
init=from_vector(wData)), "w");
|
||||||
//auto b = named(g.param(shape={1, LABEL_SIZE},
|
auto b = named(g.param(shape={1, LABEL_SIZE},
|
||||||
// init=from_vector(bData)), "b");
|
init=from_vector(bData)), "b");
|
||||||
auto w1 = named(g.param(shape={IMAGE_SIZE, 100},
|
|
||||||
init=uniform()), "w1");
|
|
||||||
auto b1 = named(g.param(shape={1, 100},
|
|
||||||
init=uniform()), "b1");
|
|
||||||
auto w2 = named(g.param(shape={100, LABEL_SIZE},
|
|
||||||
init=uniform()), "w2");
|
|
||||||
auto b2 = named(g.param(shape={1, LABEL_SIZE},
|
|
||||||
init=uniform()), "b2");
|
|
||||||
|
|
||||||
|
|
||||||
auto lr = tanh(dot(x, w1) + b1);
|
|
||||||
auto probs = named(
|
auto probs = named(
|
||||||
softmax(dot(lr, w2) + b2), //, axis=1),
|
softmax(dot(x, w) + b),
|
||||||
"probs"
|
"probs"
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -72,7 +63,7 @@ int main(int argc, char** argv) {
|
|||||||
g["y"] = (yt << testLabels);
|
g["y"] = (yt << testLabels);
|
||||||
|
|
||||||
Adam opt;
|
Adam opt;
|
||||||
for(size_t j = 0; j < 10; ++j) {
|
for(size_t j = 0; j < 20; ++j) {
|
||||||
for(size_t i = 0; i < 60; ++i) {
|
for(size_t i = 0; i < 60; ++i) {
|
||||||
opt(g, BATCH_SIZE);
|
opt(g, BATCH_SIZE);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user