mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
backward_numeric()
This commit is contained in:
parent
39eb9f50f5
commit
8b55721206
@ -34,6 +34,8 @@ struct Chainable {
|
||||
virtual ~Chainable() { }
|
||||
virtual void forward() { }
|
||||
virtual void backward() { }
|
||||
virtual void backward_numeric() { }
|
||||
|
||||
virtual void check() { }
|
||||
virtual void init_dependent() { }
|
||||
virtual void set_zero_adjoint() { }
|
||||
|
@ -127,6 +127,19 @@ class ExpressionGraph {
|
||||
(*it)->backward();
|
||||
}
|
||||
|
||||
void backward_numeric() {
|
||||
for(auto&& v : *stack_)
|
||||
v->set_zero_adjoint();
|
||||
|
||||
typedef typename ChainableStack::reverse_iterator It;
|
||||
stack_->back()->init_dependent();
|
||||
for(It it = stack_->rbegin(); it != stack_->rend(); ++it) {
|
||||
Chainable<Tensor> *chainable = *it;
|
||||
//chainable->backward();
|
||||
chainable->backward_numeric();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns a string representing this expression graph in <code>graphviz</code> notation.
|
||||
*
|
||||
|
@ -114,6 +114,10 @@ struct UnaryNodeOp : public Node {
|
||||
UnaryNodeOp(ChainPtr a, Args ...args)
|
||||
: Node(keywords::shape=a->shape(), //@TODO: Check keywords?
|
||||
args...), a_(a) {}
|
||||
|
||||
void backward_numeric() {
|
||||
backward();
|
||||
}
|
||||
};
|
||||
|
||||
struct LogitNodeOp : public UnaryNodeOp {
|
||||
@ -289,7 +293,7 @@ struct NegNodeOp : public UnaryNodeOp {
|
||||
void backward() {
|
||||
Element(_1 += -_2, a_->grad(), adj_);
|
||||
}
|
||||
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
@ -308,6 +312,11 @@ struct BinaryNodeOp : public Node {
|
||||
template <typename ...Args>
|
||||
BinaryNodeOp(ChainPtr a, ChainPtr b, Args ...args)
|
||||
: Node(args...), a_(a), b_(b) {}
|
||||
|
||||
void backward_numeric() {
|
||||
backward();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
/*** Matrix Product ***/
|
||||
@ -450,7 +459,7 @@ struct DivNodeOp : public BinaryNodeOp {
|
||||
Element(_1 -= _2 * _3 / (_4 * _4),
|
||||
b_->grad(), adj_, a_->val(), b_->val());
|
||||
}
|
||||
|
||||
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"÷\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
@ -519,7 +528,8 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"cross_entropy\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
|
@ -54,7 +54,8 @@ int main(int argc, char** argv)
|
||||
|
||||
// train
|
||||
g.forward(batch_size);
|
||||
g.backward();
|
||||
//g.backward();
|
||||
g.backward_numeric();
|
||||
|
||||
std::cout << g.graphviz() << std::endl;
|
||||
|
||||
@ -67,10 +68,4 @@ int main(int argc, char** argv)
|
||||
std::cerr << "outGrad=" << outGrad.Debug() << std::endl;
|
||||
|
||||
|
||||
Tensor costTensor = cost.val();
|
||||
std::cerr << "costTensor=" << costTensor.Debug() << std::endl;
|
||||
|
||||
Tensor costGrad = cost.grad();
|
||||
std::cerr << "costGrad=" << costGrad.Debug() << std::endl;
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user