This commit is contained in:
Lane Schwartz 2016-09-16 19:21:20 +02:00
commit 015be2bf63
3 changed files with 7 additions and 19 deletions

View File

@ -1,5 +1,3 @@
The MIT License (MIT)
Copyright (c) 2016 Marcin Junczys-Dowmunt
Permission is hereby granted, free of charge, to any person obtaining a copy
@ -9,8 +7,8 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,

View File

@ -1,8 +1,6 @@
#include <sstream>
#include "expression_graph.h"
using namespace std;
namespace marian {
Expr::Expr(ExpressionGraphPtr g, Chainable<Tensor>* chainable)
@ -32,19 +30,10 @@ Expr::operator ChainPtr() {
/// Returns a textual description of this expression's shape.
///
/// The shape is obtained from pimpl_->shape() and formatted with
/// marian::Debug(); the formatted text is returned as a std::string.
std::string Expr::Debug() const
{
  // Declared exactly once and fully qualified, so this function does not
  // depend on a file-level `using namespace std;`.
  std::stringstream strm;
  const Shape &shape = pimpl_->shape();
  strm << marian::Debug(shape);
  return strm.str();
}
///////////////////////////////////////////////////////
//ExpressionGraph::ExpressionGraph(int cudaDevice)
//: stack_(new ChainableStack)
//{
// std::srand (time(NULL));
// cudaSetDevice(0);
//
//}
}

View File

@ -26,16 +26,17 @@ class Adagrad {
// Creates the optimiser. `eta` (default 0.1) is stored in eta_, which
// operator() uses as the scale factor of every parameter update.
Adagrad(float eta=0.1) : eta_(eta) {}
/// Performs one update step on every parameter of `graph`.
///
/// @param graph      Expression graph whose params() are modified in place.
/// @param batchSize  Passed through unchanged to graph.backprop().
void operator()(ExpressionGraph& graph, int batchSize) {
  // Small constant added to the denominator so the update never divides by
  // zero when an accumulated-history entry is 0.
  float fudgeFactor = 1e-6;

  // Run backprop exactly once per step, before the gradient shapes are read
  // for the history allocation below.
  graph.backprop(batchSize);

  // Lazily create one history tensor per parameter, shaped like its gradient.
  // NOTE(review): presumably Tensor(shape, 0) fills the tensor with zeros —
  // confirm against the Tensor constructor.
  if(history_.size() < graph.params().size())
    for(auto& param : graph.params())
      history_.emplace_back(Tensor(param.grad().shape(), 0));

  auto it = history_.begin();
  for(auto& param : graph.params()) {
    // Accumulate the squared gradient first, so the denominator used by the
    // update below already includes the current gradient.
    Element(_1 += _2 * _2, *it, param.grad());
    // Single scaled update: val -= eta / (fudge + sqrt(history)) * grad.
    Element(_1 -= eta_ / (fudgeFactor + Sqrt(_2)) * _3, param.val(), *it, param.grad());
    ++it;
  }
}