modularize numerical gradient 3

This commit is contained in:
Hieu Hoang 2016-09-23 11:15:18 +01:00
parent 40f8999f3a
commit 03ca00ac22
4 changed files with 7 additions and 9 deletions

View File

@ -6,8 +6,7 @@ namespace marian {
void Node::calc_numeric_grad(
Float delta,
Tensor input,
Tensor grad,
const std::vector<float> &prevCalcGrad
Tensor grad
)
{
using namespace std;
@ -27,8 +26,8 @@ void Node::calc_numeric_grad(
//cerr << "input=" << input.Debug() << endl;
//cerr << "adj_=" << adj_.Debug() << endl;
std::vector<float> origGrad(inputSize);
thrust::copy(grad.begin(), grad.end(), origGrad.begin());
std::vector<float> prevCalcGrad(inputSize);
thrust::copy(grad.begin(), grad.end(), prevCalcGrad.begin());
//cerr << "origGrad=" << grad.Debug() << endl;
//output("diffGrad", diffGrad);

View File

@ -131,8 +131,7 @@ class Node : public Chainable<Tensor>,
void calc_numeric_grad(
Float delta,
Tensor input,
Tensor grad,
const std::vector<float> &prevCalcGrad
Tensor grad
);
void broadcast(const std::vector<float> &largeVec, std::vector<float> &smallVec);
float L2Norm(const std::vector<float> &vec) const;

View File

@ -40,7 +40,7 @@ struct BinaryNodeOp : public Node {
a_->grad().set(preCalcGradA);
b_->grad().set(preCalcGradB);
calc_numeric_grad(delta, a_->val(), a_->grad(), preCalcGradA);
calc_numeric_grad(delta, a_->val(), a_->grad());
cerr << "numerical a_->grad()=" << a_->grad().Debug() << endl;
numericalGradA << a_->grad();
@ -51,7 +51,7 @@ struct BinaryNodeOp : public Node {
a_->grad().set(preCalcGradA);
b_->grad().set(preCalcGradB);
calc_numeric_grad(delta, b_->val(), b_->grad(), preCalcGradB);
calc_numeric_grad(delta, b_->val(), b_->grad());
cerr << "numerical b_->grad()=" << b_->grad().Debug() << endl;
numericalGradB << b_->grad();

View File

@ -27,7 +27,7 @@ struct UnaryNodeOp : public Node {
a_->grad().set(preCalcGradA);
calc_numeric_grad(delta, a_->val(), a_->grad(), preCalcGradA);
calc_numeric_grad(delta, a_->val(), a_->grad());
cerr << "numerical a_->grad()=" << a_->grad().Debug() << endl;
numericalGradA << a_->grad();