Tidying up optimizers

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-16 23:32:09 +02:00
parent 70f8277eb1
commit c0b676c7c9

View File

@ -22,23 +22,24 @@ class Sgd {
float eta_;
};
// @TODO: Add serialization for historic gradients and parameters
class Adagrad {
public:
Adagrad(float eta=0.01, float eps=10e-8)
: eta_(eta), eps_(eps) {}
void operator()(ExpressionGraph& graph, int batchSize) {
graph.backprop(batchSize);
graph.backprop(batchSize);
if(history_.size() < graph.params().size())
if(gt_.size() < graph.params().size())
for(auto& param : graph.params())
history_.emplace_back(Tensor(param.grad().shape(), 0));
gt_.emplace_back(Tensor(param.grad().shape(), 0));
auto it = history_.begin();
auto gtIt = gt_.begin();
for(auto& param : graph.params()) {
Element(_1 += _2 * _2, *it, param.grad());
Element(_1 += _2 * _2, *gtIt, param.grad());
Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3,
param.val(), *it, param.grad());
param.val(), *gtIt, param.grad());
it++;
}
}
@ -46,9 +47,10 @@ class Adagrad {
private:
float eta_;
float eps_;
std::vector<Tensor> history_;
std::vector<Tensor> gt_;
};
// @TODO: Add serialization for historic gradients and parameters
class Adam {
public:
Adam(float eta=0.01, float beta1=0.999, float beta2=0.999, float eps=10e-8)