diff --git a/src/chainable.h b/src/chainable.h index 9fe6d208..885efdbd 100644 --- a/src/chainable.h +++ b/src/chainable.h @@ -17,6 +17,8 @@ struct Chainable { virtual void set_zero_adjoint() { } virtual void allocate(size_t) = 0; + virtual std::string graphviz() = 0; + virtual const Shape& shape() = 0; virtual DataType &val() = 0; diff --git a/src/expression_graph.h b/src/expression_graph.h index f0d5f233..7a766679 100644 --- a/src/expression_graph.h +++ b/src/expression_graph.h @@ -50,6 +50,18 @@ class ExpressionGraph { v->forward(); } + std::string graphviz() { + std::stringstream ss; + ss << "digraph ExpressionGraph {" << std::endl; + ss << "rankdir=BT" << std::endl; + + typedef typename ChainableStack::reverse_iterator It; + for(It it = stack_->rbegin(); it != stack_->rend(); ++it) + ss << (*it)->graphviz(); + ss << "}" << std::endl; + return ss.str(); + } + void backward() { for(auto&& v : *stack_) v->set_zero_adjoint(); diff --git a/src/node_operators.h b/src/node_operators.h index e5cf2110..8620a645 100644 --- a/src/node_operators.h +++ b/src/node_operators.h @@ -22,6 +22,13 @@ struct InputNode : public Node { void forward() {} void backward() {} + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"parallelogram\", label=\"input\", style=\"filled\", fillcolor=\"lawngreen\"]" << std::endl << std::endl; + return ss.str(); + }; + }; struct ConstantNode : public Node { @@ -35,6 +42,13 @@ struct ConstantNode : public Node { void forward() {} void backward() {} + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"diamond\", label=\"const\"]" << std::endl << std::endl; + return ss.str(); + }; + }; struct ParamNode : public Node { @@ -60,6 +74,13 @@ struct ParamNode : public Node { } } + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"hexagon\", label=\"param\", style=\"filled\", fillcolor=\"orangered\"]" << std::endl << std::endl; + return ss.str(); + }; + + private: std::function init_; bool initialized_; @@ -89,6 +110,14 @@ struct LogitNodeOp : public UnaryNodeOp { Element(_1 += _2 * _3 * (1 - _3), a_->grad(), adj_, val_); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct TanhNodeOp : public UnaryNodeOp { @@ -105,6 +134,14 @@ struct TanhNodeOp : public UnaryNodeOp { Element(_1 += _2 * (1 - _3 * _3), a_->grad(), adj_, val_); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"tanh\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; // @TODO, make this numerically safe(r): @@ -131,6 +168,14 @@ struct SoftmaxNodeOp : public UnaryNodeOp { SubtractMean(&result, val_); Element(_1 += _2 * _3, a_->grad(), val_, result); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"softmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct LogNodeOp : public UnaryNodeOp { @@ -146,6 +191,14 @@ struct LogNodeOp : public UnaryNodeOp { Element(_1 += _2 * 1.f / _3, a_->grad(), adj_, a_->val()); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"log\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct ExpNodeOp : public UnaryNodeOp { @@ -161,6 +214,14 @@ struct ExpNodeOp : public UnaryNodeOp { Element(_1 += _2 * Exp(_3), a_->grad(), adj_, a_->val()); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"exp\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct NegNodeOp : public UnaryNodeOp { @@ -175,6 +236,14 @@ struct NegNodeOp : public UnaryNodeOp { void backward() { Element(_1 += -_2, a_->grad(), adj_); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; /******************************************************/ @@ -220,6 +289,15 @@ struct DotNodeOp : public BinaryNodeOp { Prod(a_->grad(), adj_, b_->val(), false, true, 1.0); Prod(b_->grad(), a_->val(), adj_, true, false, 1.0); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"×\", style=\"filled\", fillcolor=\"orange\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct PlusNodeOp : public BinaryNodeOp { @@ -238,6 +316,15 @@ struct PlusNodeOp : public BinaryNodeOp { Element(_1 += _2, b_->grad(), adj_); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"+\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct MinusNodeOp : public BinaryNodeOp { @@ -256,6 +343,15 @@ struct MinusNodeOp : public BinaryNodeOp { Element(_1 -= _2, b_->grad(), adj_); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct MultNodeOp : public BinaryNodeOp { @@ -274,6 +370,15 @@ struct MultNodeOp : public BinaryNodeOp { Element(_1 += _2 * _3, b_->grad(), adj_, a_->val()); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"•\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct DivNodeOp : public BinaryNodeOp { @@ -291,7 +396,16 @@ struct DivNodeOp : public BinaryNodeOp { a_->grad(), adj_, b_->val()); Element(_1 -= _2 * _3 / (_4 * _4), b_->grad(), adj_, a_->val(), b_->val()); - } + } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"÷\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; } diff --git a/src/test.cu b/src/test.cu index d27591be..f3bcef11 100644 --- a/src/test.cu +++ b/src/test.cu @@ -75,6 +75,8 @@ int main(int argc, char** argv) { Y[t] = Yt; } + std::cout << g.graphviz() << std::endl; + g.forward(batch_size); g.backward(); diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index cbd3e0a3..f71be921 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -47,7 +47,7 @@ ExpressionGraph build_graph() { int main(int argc, char** argv) { - cudaSetDevice(1); + cudaSetDevice(0); std::cerr << "Loading test set..."; std::vector testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE); @@ -62,6 +62,8 @@ int main(int argc, char** argv) { g["x"] = (xt << testImages); g["y"] = (yt << testLabels); + std::cout << g.graphviz() << std::endl; + g.forward(BATCH_SIZE); std::vector results;