mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
modularize numerical gradient 3
This commit is contained in:
parent
40f8999f3a
commit
03ca00ac22
@ -6,8 +6,7 @@ namespace marian {
|
||||
void Node::calc_numeric_grad(
|
||||
Float delta,
|
||||
Tensor input,
|
||||
Tensor grad,
|
||||
const std::vector<float> &prevCalcGrad
|
||||
Tensor grad
|
||||
)
|
||||
{
|
||||
using namespace std;
|
||||
@ -27,8 +26,8 @@ void Node::calc_numeric_grad(
|
||||
//cerr << "input=" << input.Debug() << endl;
|
||||
//cerr << "adj_=" << adj_.Debug() << endl;
|
||||
|
||||
std::vector<float> origGrad(inputSize);
|
||||
thrust::copy(grad.begin(), grad.end(), origGrad.begin());
|
||||
std::vector<float> prevCalcGrad(inputSize);
|
||||
thrust::copy(grad.begin(), grad.end(), prevCalcGrad.begin());
|
||||
//cerr << "origGrad=" << grad.Debug() << endl;
|
||||
//output("diffGrad", diffGrad);
|
||||
|
||||
|
@ -131,8 +131,7 @@ class Node : public Chainable<Tensor>,
|
||||
void calc_numeric_grad(
|
||||
Float delta,
|
||||
Tensor input,
|
||||
Tensor grad,
|
||||
const std::vector<float> &prevCalcGrad
|
||||
Tensor grad
|
||||
);
|
||||
void broadcast(const std::vector<float> &largeVec, std::vector<float> &smallVec);
|
||||
float L2Norm(const std::vector<float> &vec) const;
|
||||
|
@ -40,7 +40,7 @@ struct BinaryNodeOp : public Node {
|
||||
a_->grad().set(preCalcGradA);
|
||||
b_->grad().set(preCalcGradB);
|
||||
|
||||
calc_numeric_grad(delta, a_->val(), a_->grad(), preCalcGradA);
|
||||
calc_numeric_grad(delta, a_->val(), a_->grad());
|
||||
cerr << "numerical a_->grad()=" << a_->grad().Debug() << endl;
|
||||
|
||||
numericalGradA << a_->grad();
|
||||
@ -51,7 +51,7 @@ struct BinaryNodeOp : public Node {
|
||||
a_->grad().set(preCalcGradA);
|
||||
b_->grad().set(preCalcGradB);
|
||||
|
||||
calc_numeric_grad(delta, b_->val(), b_->grad(), preCalcGradB);
|
||||
calc_numeric_grad(delta, b_->val(), b_->grad());
|
||||
cerr << "numerical b_->grad()=" << b_->grad().Debug() << endl;
|
||||
|
||||
numericalGradB << b_->grad();
|
||||
|
@ -27,7 +27,7 @@ struct UnaryNodeOp : public Node {
|
||||
|
||||
a_->grad().set(preCalcGradA);
|
||||
|
||||
calc_numeric_grad(delta, a_->val(), a_->grad(), preCalcGradA);
|
||||
calc_numeric_grad(delta, a_->val(), a_->grad());
|
||||
cerr << "numerical a_->grad()=" << a_->grad().Debug() << endl;
|
||||
|
||||
numericalGradA << a_->grad();
|
||||
|
Loading…
Reference in New Issue
Block a user