This commit is contained in:
Lane Schwartz 2016-09-16 19:21:20 +02:00
commit 015be2bf63
3 changed files with 7 additions and 19 deletions

View File

@ -1,5 +1,3 @@
The MIT License (MIT)
Copyright (c) 2016 Marcin Junczys-Dowmunt Copyright (c) 2016 Marcin Junczys-Dowmunt
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
@ -9,8 +7,8 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions: furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all The above copyright notice and this permission notice shall be included in
copies or substantial portions of the Software. all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,

View File

@ -1,8 +1,6 @@
#include <sstream> #include <sstream>
#include "expression_graph.h" #include "expression_graph.h"
using namespace std;
namespace marian { namespace marian {
Expr::Expr(ExpressionGraphPtr g, Chainable<Tensor>* chainable) Expr::Expr(ExpressionGraphPtr g, Chainable<Tensor>* chainable)
@ -32,19 +30,10 @@ Expr::operator ChainPtr() {
std::string Expr::Debug() const std::string Expr::Debug() const
{ {
stringstream strm; std::stringstream strm;
const Shape &shape = pimpl_->shape(); const Shape &shape = pimpl_->shape();
strm << marian::Debug(shape); strm << marian::Debug(shape);
return strm.str(); return strm.str();
} }
///////////////////////////////////////////////////////
//ExpressionGraph::ExpressionGraph(int cudaDevice)
//: stack_(new ChainableStack)
//{
// std::srand (time(NULL));
// cudaSetDevice(0);
//
//}
} }

View File

@ -26,16 +26,17 @@ class Adagrad {
Adagrad(float eta=0.1) : eta_(eta) {} Adagrad(float eta=0.1) : eta_(eta) {}
void operator()(ExpressionGraph& graph, int batchSize) { void operator()(ExpressionGraph& graph, int batchSize) {
float fudgeFactor = 1e-6;
graph.backprop(batchSize);
if(history_.size() < graph.params().size()) if(history_.size() < graph.params().size())
for(auto& param : graph.params()) for(auto& param : graph.params())
history_.emplace_back(Tensor(param.grad().shape(), 0)); history_.emplace_back(Tensor(param.grad().shape(), 0));
graph.backprop(batchSize);
auto it = history_.begin(); auto it = history_.begin();
for(auto& param : graph.params()) { for(auto& param : graph.params()) {
Element(_1 -= eta_ / Sqrt(_2) * _3, param.val(), *it, param.grad());
Element(_1 += _2 * _2, *it, param.grad()); Element(_1 += _2 * _2, *it, param.grad());
Element(_1 -= eta_ / (fudgeFactor + Sqrt(_2)) * _3, param.val(), *it, param.grad());
it++; it++;
} }
} }