backward_numeric()

This commit is contained in:
Hieu Hoang 2016-09-19 15:22:44 +01:00
parent 39eb9f50f5
commit 8b55721206
4 changed files with 30 additions and 10 deletions

View File

@ -34,6 +34,8 @@ struct Chainable {
virtual ~Chainable() { }
virtual void forward() { }
virtual void backward() { }
virtual void backward_numeric() { }
virtual void check() { }
virtual void init_dependent() { }
virtual void set_zero_adjoint() { }

View File

@ -127,6 +127,19 @@ class ExpressionGraph {
(*it)->backward();
}
void backward_numeric() {
for(auto&& v : *stack_)
v->set_zero_adjoint();
typedef typename ChainableStack::reverse_iterator It;
stack_->back()->init_dependent();
for(It it = stack_->rbegin(); it != stack_->rend(); ++it) {
Chainable<Tensor> *chainable = *it;
//chainable->backward();
chainable->backward_numeric();
}
}
/**
* @brief Returns a string representing this expression graph in <code>graphviz</code> notation.
*

View File

@ -114,6 +114,10 @@ struct UnaryNodeOp : public Node {
UnaryNodeOp(ChainPtr a, Args ...args)
: Node(keywords::shape=a->shape(), //@TODO: Check keywords?
args...), a_(a) {}
void backward_numeric() {
backward();
}
};
struct LogitNodeOp : public UnaryNodeOp {
@ -289,7 +293,7 @@ struct NegNodeOp : public UnaryNodeOp {
void backward() {
Element(_1 += -_2, a_->grad(), adj_);
}
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
@ -308,6 +312,11 @@ struct BinaryNodeOp : public Node {
template <typename ...Args>
BinaryNodeOp(ChainPtr a, ChainPtr b, Args ...args)
: Node(args...), a_(a), b_(b) {}
void backward_numeric() {
backward();
}
};
/*** Matrix Product ***/
@ -450,7 +459,7 @@ struct DivNodeOp : public BinaryNodeOp {
Element(_1 -= _2 * _3 / (_4 * _4),
b_->grad(), adj_, a_->val(), b_->val());
}
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"÷\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
@ -519,7 +528,8 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"cross_entropy\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};

View File

@ -54,7 +54,8 @@ int main(int argc, char** argv)
// train
g.forward(batch_size);
g.backward();
//g.backward();
g.backward_numeric();
std::cout << g.graphviz() << std::endl;
@ -67,10 +68,4 @@ int main(int argc, char** argv)
std::cerr << "outGrad=" << outGrad.Debug() << std::endl;
Tensor costTensor = cost.val();
std::cerr << "costTensor=" << costTensor.Debug() << std::endl;
Tensor costGrad = cost.grad();
std::cerr << "costGrad=" << costGrad.Debug() << std::endl;
}