mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
fudge factor for adagrad
This commit is contained in:
parent
09b7e15b05
commit
257b621db9
@ -26,16 +26,17 @@ class Adagrad {
|
||||
Adagrad(float eta=0.1) : eta_(eta) {}
|
||||
|
||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||
float fudgeFactor = 1e-6;
|
||||
graph.backprop(batchSize);
|
||||
|
||||
if(history_.size() < graph.params().size())
|
||||
for(auto& param : graph.params())
|
||||
history_.emplace_back(Tensor(param.grad().shape(), 0));
|
||||
|
||||
graph.backprop(batchSize);
|
||||
|
||||
auto it = history_.begin();
|
||||
for(auto& param : graph.params()) {
|
||||
Element(_1 -= eta_ / Sqrt(_2) * _3, param.val(), *it, param.grad());
|
||||
Element(_1 += _2 * _2, *it, param.grad());
|
||||
Element(_1 -= eta_ / (fudgeFactor + Sqrt(_2)) * _3, param.val(), *it, param.grad());
|
||||
it++;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user