This commit is contained in:
Hieu Hoang 2016-09-13 16:36:45 +02:00
parent 44dad9b45c
commit 686a8bcbd6
2 changed files with 3 additions and 2 deletions

View File

@ -25,13 +25,12 @@ ChainPtr Expr::node() {
void Expr::forward(size_t batchSize) {
UTIL_THROW_IF2(pimpl_.get() != Chainable<Tensor>::stack.back(),
"Trying to call forward on non-root of computation graph");
std::cerr << "forward:" << std::endl;
std::cerr << "a" << std::endl;
for(auto&& v : Chainable<Tensor>::stack) {
v->allocate(batchSize);
}
std::cerr << "f" << std::endl;
for(auto&& v : Chainable<Tensor>::stack)
v->forward();
}
@ -39,6 +38,7 @@ void Expr::forward(size_t batchSize) {
void Expr::backward() {
UTIL_THROW_IF2(pimpl_.get() != Chainable<Tensor>::stack.back(),
"Trying to call backward on non-root of computation graph");
std::cerr << "backward:" << std::endl;
for(auto&& v : Chainable<Tensor>::stack)
v->set_zero_adjoint();

View File

@ -30,6 +30,7 @@ int main(int argc, char** argv) {
y = ty;
graph.forward(500);
graph.backward();
//std::cerr << graph["pred"].val()[0] << std::endl;