diff --git a/src/sgd.h b/src/sgd.h
index 7d3f3200..85a4e4af 100644
--- a/src/sgd.h
+++ b/src/sgd.h
@@ -26,16 +26,23 @@ class Adagrad {
     Adagrad(float eta=0.1) : eta_(eta) {}
 
     void operator()(ExpressionGraph& graph, int batchSize) {
+      // Small constant added to the denominator so the step never divides by
+      // zero while an accumulator entry is 0.
+      const float fudgeFactor = 1e-6f;
+
+      // Backprop runs before the history is sized from param.grad().shape()
+      // (presumably so the gradients are populated by then -- confirm).
+      graph.backprop(batchSize);
+
       if(history_.size() < graph.params().size())
         for(auto& param : graph.params())
           history_.emplace_back(Tensor(param.grad().shape(), 0));
 
-      graph.backprop(batchSize);
-
       auto it = history_.begin();
       for(auto& param : graph.params()) {
-        Element(_1 -= eta_ / Sqrt(_2) * _3, param.val(), *it, param.grad());
         Element(_1 += _2 * _2, *it, param.grad());
+        // The accumulator above already includes this step's squared gradient.
+        Element(_1 -= eta_ / (fudgeFactor + Sqrt(_2)) * _3, param.val(), *it, param.grad());
         it++;
       }
     }