mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
Tidying up optimizers
This commit is contained in:
parent
70f8277eb1
commit
c0b676c7c9
14
src/sgd.h
14
src/sgd.h
@ -22,6 +22,7 @@ class Sgd {
|
||||
float eta_;
|
||||
};
|
||||
|
||||
// @TODO: Add serialization for historic gradients and parameters
|
||||
class Adagrad {
|
||||
public:
|
||||
Adagrad(float eta=0.01, float eps=10e-8)
|
||||
@ -30,15 +31,15 @@ class Adagrad {
|
||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||
graph.backprop(batchSize);
|
||||
|
||||
if(history_.size() < graph.params().size())
|
||||
if(gt_.size() < graph.params().size())
|
||||
for(auto& param : graph.params())
|
||||
history_.emplace_back(Tensor(param.grad().shape(), 0));
|
||||
gt_.emplace_back(Tensor(param.grad().shape(), 0));
|
||||
|
||||
auto it = history_.begin();
|
||||
auto gtIt = gt_.begin();
|
||||
for(auto& param : graph.params()) {
|
||||
Element(_1 += _2 * _2, *it, param.grad());
|
||||
Element(_1 += _2 * _2, *gtIt, param.grad());
|
||||
Element(_1 -= eta_ / (Sqrt(_2) + eps_) * _3,
|
||||
param.val(), *it, param.grad());
|
||||
param.val(), *gtIt, param.grad());
|
||||
it++;
|
||||
}
|
||||
}
|
||||
@ -46,9 +47,10 @@ class Adagrad {
|
||||
private:
|
||||
float eta_;
|
||||
float eps_;
|
||||
std::vector<Tensor> history_;
|
||||
std::vector<Tensor> gt_;
|
||||
};
|
||||
|
||||
// @TODO: Add serialization for historic gradients and parameters
|
||||
class Adam {
|
||||
public:
|
||||
Adam(float eta=0.01, float beta1=0.999, float beta2=0.999, float eps=10e-8)
|
||||
|
Loading…
Reference in New Issue
Block a user