mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
Merge branch 'master' of https://github.com/emjotde/Marian
This commit is contained in:
commit
015be2bf63
@ -1,5 +1,3 @@
|
|||||||
The MIT License (MIT)
|
|
||||||
|
|
||||||
Copyright (c) 2016 Marcin Junczys-Dowmunt
|
Copyright (c) 2016 Marcin Junczys-Dowmunt
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
@ -9,8 +7,8 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|||||||
copies of the Software, and to permit persons to whom the Software is
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
furnished to do so, subject to the following conditions:
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
The above copyright notice and this permission notice shall be included in
|
||||||
copies or substantial portions of the Software.
|
all copies or substantial portions of the Software.
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include "expression_graph.h"
|
#include "expression_graph.h"
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
namespace marian {
|
namespace marian {
|
||||||
|
|
||||||
Expr::Expr(ExpressionGraphPtr g, Chainable<Tensor>* chainable)
|
Expr::Expr(ExpressionGraphPtr g, Chainable<Tensor>* chainable)
|
||||||
@ -32,19 +30,10 @@ Expr::operator ChainPtr() {
|
|||||||
|
|
||||||
std::string Expr::Debug() const
|
std::string Expr::Debug() const
|
||||||
{
|
{
|
||||||
stringstream strm;
|
std::stringstream strm;
|
||||||
const Shape &shape = pimpl_->shape();
|
const Shape &shape = pimpl_->shape();
|
||||||
strm << marian::Debug(shape);
|
strm << marian::Debug(shape);
|
||||||
return strm.str();
|
return strm.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////
|
|
||||||
//ExpressionGraph::ExpressionGraph(int cudaDevice)
|
|
||||||
//: stack_(new ChainableStack)
|
|
||||||
//{
|
|
||||||
// std::srand (time(NULL));
|
|
||||||
// cudaSetDevice(0);
|
|
||||||
//
|
|
||||||
//}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -26,16 +26,17 @@ class Adagrad {
|
|||||||
Adagrad(float eta=0.1) : eta_(eta) {}
|
Adagrad(float eta=0.1) : eta_(eta) {}
|
||||||
|
|
||||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||||
|
float fudgeFactor = 1e-6;
|
||||||
|
graph.backprop(batchSize);
|
||||||
|
|
||||||
if(history_.size() < graph.params().size())
|
if(history_.size() < graph.params().size())
|
||||||
for(auto& param : graph.params())
|
for(auto& param : graph.params())
|
||||||
history_.emplace_back(Tensor(param.grad().shape(), 0));
|
history_.emplace_back(Tensor(param.grad().shape(), 0));
|
||||||
|
|
||||||
graph.backprop(batchSize);
|
|
||||||
|
|
||||||
auto it = history_.begin();
|
auto it = history_.begin();
|
||||||
for(auto& param : graph.params()) {
|
for(auto& param : graph.params()) {
|
||||||
Element(_1 -= eta_ / Sqrt(_2) * _3, param.val(), *it, param.grad());
|
|
||||||
Element(_1 += _2 * _2, *it, param.grad());
|
Element(_1 += _2 * _2, *it, param.grad());
|
||||||
|
Element(_1 -= eta_ / (fudgeFactor + Sqrt(_2)) * _3, param.val(), *it, param.grad());
|
||||||
it++;
|
it++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user