diff --git a/src/sgd.h b/src/sgd.h index dbf14552..65b3c6f1 100644 --- a/src/sgd.h +++ b/src/sgd.h @@ -22,23 +22,24 @@ class Sgd { float eta_; }; +// @TODO: Add serialization for historic gradients and parameters class Adagrad { public: Adagrad(float eta=0.01, float eps=10e-8) : eta_(eta), eps_(eps) {} void operator()(ExpressionGraph& graph, int batchSize) { - graph.backprop(batchSize); + graph.backprop(batchSize); - if(history_.size() < graph.params().size()) + if(gt_.size() < graph.params().size()) for(auto& param : graph.params()) - history_.emplace_back(Tensor(param.grad().shape(), 0)); + gt_.emplace_back(Tensor(param.grad().shape(), 0)); - auto it = history_.begin(); + auto gtIt = gt_.begin(); for(auto& param : graph.params()) { - Element(_1 += _2 * _2, *it, param.grad()); + Element(_1 += _2 * _2, *gtIt, param.grad()); Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3, - param.val(), *it, param.grad()); + param.val(), *gtIt, param.grad()); it++; } } @@ -46,9 +47,10 @@ class Adagrad { private: float eta_; float eps_; - std::vector history_; + std::vector gt_; }; +// @TODO: Add serialization for historic gradients and parameters class Adam { public: Adam(float eta=0.01, float beta1=0.999, float beta2=0.999, float eps=10e-8)