mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
demonstrate settings for need for safe softmax
This commit is contained in:
parent
c0b676c7c9
commit
6e90198426
12
src/sgd.h
12
src/sgd.h
@ -8,7 +8,7 @@ namespace marian {
|
||||
|
||||
class Sgd {
|
||||
public:
|
||||
Sgd(float eta=0.01) : eta_(eta) {}
|
||||
Sgd(float eta=0.001) : eta_(eta) {}
|
||||
|
||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||
graph.backprop(batchSize);
|
||||
@ -25,7 +25,7 @@ class Sgd {
|
||||
// @TODO: Add serialization for historic gradients and parameters
|
||||
class Adagrad {
|
||||
public:
|
||||
Adagrad(float eta=0.01, float eps=10e-8)
|
||||
Adagrad(float eta=0.001, float eps=10e-8)
|
||||
: eta_(eta), eps_(eps) {}
|
||||
|
||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||
@ -37,10 +37,11 @@ class Adagrad {
|
||||
|
||||
auto gtIt = gt_.begin();
|
||||
for(auto& param : graph.params()) {
|
||||
Element(_1 += _2 * _2, *gtIt, param.grad());
|
||||
Element(_1 += _2 * _2,
|
||||
*gtIt, param.grad());
|
||||
Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3,
|
||||
param.val(), *gtIt, param.grad());
|
||||
it++;
|
||||
gtIt++;
|
||||
}
|
||||
}
|
||||
|
||||
@ -51,9 +52,10 @@ class Adagrad {
|
||||
};
|
||||
|
||||
// @TODO: Add serialization for historic gradients and parameters
|
||||
// https://arxiv.org/pdf/1412.6980v8.pdf
|
||||
class Adam {
|
||||
public:
|
||||
Adam(float eta=0.01, float beta1=0.999, float beta2=0.999, float eps=10e-8)
|
||||
Adam(float eta=0.001, float beta1=0.999, float beta2=0.999, float eps=10e-8)
|
||||
: eta_(eta), beta1_(beta1), beta2_(beta2), eps_(eps), t_(0) {}
|
||||
|
||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||
|
@ -27,23 +27,14 @@ ExpressionGraph build_graph() {
|
||||
auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
|
||||
auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
|
||||
|
||||
//auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
|
||||
// init=from_vector(wData)), "w");
|
||||
//auto b = named(g.param(shape={1, LABEL_SIZE},
|
||||
// init=from_vector(bData)), "b");
|
||||
auto w1 = named(g.param(shape={IMAGE_SIZE, 100},
|
||||
init=uniform()), "w1");
|
||||
auto b1 = named(g.param(shape={1, 100},
|
||||
init=uniform()), "b1");
|
||||
auto w2 = named(g.param(shape={100, LABEL_SIZE},
|
||||
init=uniform()), "w2");
|
||||
auto b2 = named(g.param(shape={1, LABEL_SIZE},
|
||||
init=uniform()), "b2");
|
||||
auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
|
||||
init=from_vector(wData)), "w");
|
||||
auto b = named(g.param(shape={1, LABEL_SIZE},
|
||||
init=from_vector(bData)), "b");
|
||||
|
||||
|
||||
auto lr = tanh(dot(x, w1) + b1);
|
||||
auto probs = named(
|
||||
softmax(dot(lr, w2) + b2), //, axis=1),
|
||||
softmax(dot(x, w) + b),
|
||||
"probs"
|
||||
);
|
||||
|
||||
@ -72,7 +63,7 @@ int main(int argc, char** argv) {
|
||||
g["y"] = (yt << testLabels);
|
||||
|
||||
Adam opt;
|
||||
for(size_t j = 0; j < 10; ++j) {
|
||||
for(size_t j = 0; j < 20; ++j) {
|
||||
for(size_t i = 0; i < 60; ++i) {
|
||||
opt(g, BATCH_SIZE);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user