fudge factor for adagrad

author Marcin Junczys-Dowmunt 2016-09-16 18:34:38 +02:00
parent 09b7e15b05
commit 257b621db9


@@ -26,16 +26,17 @@ class Adagrad {
     Adagrad(float eta=0.1) : eta_(eta) {}
     void operator()(ExpressionGraph& graph, int batchSize) {
+      float fudgeFactor = 1e-6;
+      graph.backprop(batchSize);
       if(history_.size() < graph.params().size())
         for(auto& param : graph.params())
           history_.emplace_back(Tensor(param.grad().shape(), 0));
-      graph.backprop(batchSize);
       auto it = history_.begin();
       for(auto& param : graph.params()) {
-        Element(_1 -= eta_ / Sqrt(_2) * _3, param.val(), *it, param.grad());
         Element(_1 += _2 * _2, *it, param.grad());
+        Element(_1 -= eta_ / (fudgeFactor + Sqrt(_2)) * _3, param.val(), *it, param.grad());
         it++;
       }
     }
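
The change guards the Adagrad denominator: history_ starts at zero, so without the fudge factor the first update would divide by Sqrt(0). Below is a minimal, self-contained sketch of the same per-element update on plain float arrays; the names adagradStep, params, grads, and history are illustrative, not part of the Marian code.

#include <cmath>
#include <cstddef>

// Sketch of the Adagrad step expressed by the two Element(...) calls in the
// diff above. All names here (adagradStep, params, grads, history) are
// illustrative assumptions, not Marian API.
void adagradStep(float* params, const float* grads, float* history,
                 std::size_t n, float eta = 0.1f, float fudgeFactor = 1e-6f) {
  for (std::size_t i = 0; i < n; ++i) {
    // _1 += _2 * _2 : accumulate the squared gradient first ...
    history[i] += grads[i] * grads[i];
    // _1 -= eta / (fudgeFactor + Sqrt(_2)) * _3 : ... then take the scaled
    // step; the fudge factor keeps the denominator nonzero even when the
    // accumulated squared gradient is still exactly zero.
    params[i] -= eta / (fudgeFactor + std::sqrt(history[i])) * grads[i];
  }
}

This is the usual Adagrad epsilon (commonly written as ε, with values around 1e-6 to 1e-8 in the literature): adding a small constant to the root of the accumulated squared gradients avoids division by zero and keeps early updates numerically stable.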