Demonstrate settings that show the need for a numerically safe softmax

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-16 23:45:35 +02:00
parent c0b676c7c9
commit 6e90198426
2 changed files with 14 additions and 21 deletions

View File

@ -8,7 +8,7 @@ namespace marian {
class Sgd {
public:
Sgd(float eta=0.01) : eta_(eta) {}
Sgd(float eta=0.001) : eta_(eta) {}
void operator()(ExpressionGraph& graph, int batchSize) {
graph.backprop(batchSize);
@ -25,7 +25,7 @@ class Sgd {
// @TODO: Add serialization for historic gradients and parameters
class Adagrad {
public:
Adagrad(float eta=0.01, float eps=10e-8)
Adagrad(float eta=0.001, float eps=10e-8)
: eta_(eta), eps_(eps) {}
void operator()(ExpressionGraph& graph, int batchSize) {
@ -37,10 +37,11 @@ class Adagrad {
auto gtIt = gt_.begin();
for(auto& param : graph.params()) {
Element(_1 += _2 * _2, *gtIt, param.grad());
Element(_1 += _2 * _2,
*gtIt, param.grad());
Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3,
param.val(), *gtIt, param.grad());
it++;
gtIt++;
}
}
@ -51,9 +52,10 @@ class Adagrad {
};
// @TODO: Add serialization for historic gradients and parameters
// https://arxiv.org/pdf/1412.6980v8.pdf
class Adam {
public:
Adam(float eta=0.01, float beta1=0.999, float beta2=0.999, float eps=10e-8)
Adam(float eta=0.001, float beta1=0.999, float beta2=0.999, float eps=10e-8)
: eta_(eta), beta1_(beta1), beta2_(beta2), eps_(eps), t_(0) {}
void operator()(ExpressionGraph& graph, int batchSize) {

View File

@ -27,23 +27,14 @@ ExpressionGraph build_graph() {
auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
//auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
// init=from_vector(wData)), "w");
//auto b = named(g.param(shape={1, LABEL_SIZE},
// init=from_vector(bData)), "b");
auto w1 = named(g.param(shape={IMAGE_SIZE, 100},
init=uniform()), "w1");
auto b1 = named(g.param(shape={1, 100},
init=uniform()), "b1");
auto w2 = named(g.param(shape={100, LABEL_SIZE},
init=uniform()), "w2");
auto b2 = named(g.param(shape={1, LABEL_SIZE},
init=uniform()), "b2");
auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
init=from_vector(wData)), "w");
auto b = named(g.param(shape={1, LABEL_SIZE},
init=from_vector(bData)), "b");
auto lr = tanh(dot(x, w1) + b1);
auto probs = named(
softmax(dot(lr, w2) + b2), //, axis=1),
softmax(dot(x, w) + b),
"probs"
);
@ -72,7 +63,7 @@ int main(int argc, char** argv) {
g["y"] = (yt << testLabels);
Adam opt;
for(size_t j = 0; j < 10; ++j) {
for(size_t j = 0; j < 20; ++j) {
for(size_t i = 0; i < 60; ++i) {
opt(g, BATCH_SIZE);
}