mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
modularize numerical gradient
This commit is contained in:
parent
0d395604cf
commit
c8c7daf815
@ -49,7 +49,8 @@ struct Chainable {
|
||||
|
||||
virtual const Shape& shape() = 0;
|
||||
virtual const DataType &val() = 0;
|
||||
virtual const DataType &grad() = 0;
|
||||
virtual const DataType &grad() const = 0;
|
||||
virtual DataType &grad() = 0;
|
||||
virtual void setVal(DataType t) {
|
||||
UTIL_THROW2("Tensors can only be assigned to input and parameter nodes");
|
||||
};
|
||||
|
@ -80,11 +80,16 @@ class Node : public Chainable<Tensor>,
|
||||
return val_;
|
||||
};
|
||||
|
||||
virtual const Tensor &grad() {
|
||||
virtual const Tensor &grad() const {
|
||||
UTIL_THROW_IF2(!adj_, "Tensor has not been allocated");
|
||||
return adj_;
|
||||
};
|
||||
|
||||
|
||||
virtual Tensor &grad() {
|
||||
UTIL_THROW_IF2(!adj_, "Tensor has not been allocated");
|
||||
return adj_;
|
||||
};
|
||||
|
||||
virtual const Shape& shape() {
|
||||
return shape_;
|
||||
}
|
||||
|
@ -16,15 +16,20 @@ struct UnaryNodeOp : public Node {
|
||||
|
||||
cerr << "UnaryNodeOp::" << typeid(*this).name() << "::backward_numeric()" << endl;
|
||||
|
||||
std::vector<float> preCalcGradA;
|
||||
std::vector<float> preCalcGradA, diffGradA, numericalGradA;
|
||||
preCalcGradA << a_->grad();
|
||||
//output("preCalcGradA", preCalcGradA);
|
||||
|
||||
// use df/dx to calc grad
|
||||
backward();
|
||||
diffGradA << a_->grad();
|
||||
//cerr << "orig a_->grad()=" << a_->grad().Debug() << endl;
|
||||
|
||||
//a_->grad().set(preCalcGradA);
|
||||
calc_numeric_grad(delta, a_->val(), a_->grad(), preCalcGradA);
|
||||
numericalGradA << a_->grad();
|
||||
|
||||
a_->grad().set(diffGradA);
|
||||
}
|
||||
|
||||
};
|
||||
|
@ -416,7 +416,7 @@ class Tensor {
|
||||
*
|
||||
* @return True or False
|
||||
*/
|
||||
operator bool() {
|
||||
operator bool() const {
|
||||
return pimpl_ != nullptr;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user