diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e58473cf..0694509c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -36,11 +36,16 @@ cuda_add_executable( validate_mnist_batch validate_mnist_batch.cu ) +cuda_add_executable( + validate_encoder_decoder + validate_encoder_decoder.cu +) target_link_libraries(validate_mnist marian_lib) target_link_libraries(validate_mnist_batch marian_lib) +target_link_libraries(validate_encoder_decoder marian_lib) -foreach(exec marian train_mnist validate_mnist validate_mnist_batch ) +foreach(exec marian train_mnist validate_mnist validate_mnist_batch validate_encoder_decoder) target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn) cuda_add_cublas_to_target(${exec}) set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}") diff --git a/src/chainable.h b/src/chainable.h index 9fe6d208..885efdbd 100644 --- a/src/chainable.h +++ b/src/chainable.h @@ -17,6 +17,8 @@ struct Chainable { virtual void set_zero_adjoint() { } virtual void allocate(size_t) = 0; + virtual std::string graphviz() = 0; + virtual const Shape& shape() = 0; virtual DataType &val() = 0; diff --git a/src/expression_graph.h b/src/expression_graph.h index 8b7d74f4..a092fcb1 100644 --- a/src/expression_graph.h +++ b/src/expression_graph.h @@ -48,6 +48,18 @@ class ExpressionGraph { v->forward(); } + std::string graphviz() { + std::stringstream ss; + ss << "digraph ExpressionGraph {" << std::endl; + ss << "rankdir=BT" << std::endl; + + typedef typename ChainableStack::reverse_iterator It; + for(It it = stack_->rbegin(); it != stack_->rend(); ++it) + ss << (*it)->graphviz(); + ss << "}" << std::endl; + return ss.str(); + } + void backward() { for(auto&& v : *stack_) v->set_zero_adjoint(); diff --git a/src/node_operators.h b/src/node_operators.h index e5cf2110..8620a645 100644 --- a/src/node_operators.h +++ b/src/node_operators.h @@ -22,6 +22,13 @@ struct InputNode : public Node { void forward() {} void backward() {} + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"parallelogram\", label=\"input\", style=\"filled\", fillcolor=\"lawngreen\"]" << std::endl << std::endl; + return ss.str(); + }; + }; struct ConstantNode : public Node { @@ -35,6 +42,13 @@ struct ConstantNode : public Node { void forward() {} void backward() {} + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"diamond\", label=\"const\"]" << std::endl << std::endl; + return ss.str(); + }; + }; struct ParamNode : public Node { @@ -60,6 +74,13 @@ struct ParamNode : public Node { } } + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"hexagon\", label=\"param\", style=\"filled\", fillcolor=\"orangered\"]" << std::endl << std::endl; + return ss.str(); + }; + + private: std::function init_; bool initialized_; @@ -89,6 +110,14 @@ struct LogitNodeOp : public UnaryNodeOp { Element(_1 += _2 * _3 * (1 - _3), a_->grad(), adj_, val_); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct TanhNodeOp : public UnaryNodeOp { @@ -105,6 +134,14 @@ struct TanhNodeOp : public UnaryNodeOp { Element(_1 += _2 * (1 - _3 * _3), a_->grad(), adj_, val_); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"tanh\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; // @TODO, make this numerically safe(r): @@ -131,6 +168,14 @@ struct SoftmaxNodeOp : public UnaryNodeOp { SubtractMean(&result, val_); Element(_1 += _2 * _3, a_->grad(), val_, result); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"softmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct LogNodeOp : public UnaryNodeOp { @@ -146,6 +191,14 @@ struct LogNodeOp : public UnaryNodeOp { Element(_1 += _2 * 1.f / _3, a_->grad(), adj_, a_->val()); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"log\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct ExpNodeOp : public UnaryNodeOp { @@ -161,6 +214,14 @@ struct ExpNodeOp : public UnaryNodeOp { Element(_1 += _2 * Exp(_3), a_->grad(), adj_, a_->val()); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"exp\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct NegNodeOp : public UnaryNodeOp { @@ -175,6 +236,14 @@ struct NegNodeOp : public UnaryNodeOp { void backward() { Element(_1 += -_2, a_->grad(), adj_); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; /******************************************************/ @@ -220,6 +289,15 @@ struct DotNodeOp : public BinaryNodeOp { Prod(a_->grad(), adj_, b_->val(), false, true, 1.0); Prod(b_->grad(), a_->val(), adj_, true, false, 1.0); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"×\", style=\"filled\", fillcolor=\"orange\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct PlusNodeOp : public BinaryNodeOp { @@ -238,6 +316,15 @@ struct PlusNodeOp : public BinaryNodeOp { Element(_1 += _2, b_->grad(), adj_); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"+\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct MinusNodeOp : public BinaryNodeOp { @@ -256,6 +343,15 @@ struct MinusNodeOp : public BinaryNodeOp { Element(_1 -= _2, b_->grad(), adj_); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct MultNodeOp : public BinaryNodeOp { @@ -274,6 +370,15 @@ struct MultNodeOp : public BinaryNodeOp { Element(_1 += _2 * _3, b_->grad(), adj_, a_->val()); } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"•\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; struct DivNodeOp : public BinaryNodeOp { @@ -291,7 +396,16 @@ struct DivNodeOp : public BinaryNodeOp { a_->grad(), adj_, b_->val()); Element(_1 -= _2 * _3 / (_4 * _4), b_->grad(), adj_, a_->val(), b_->val()); - } + } + + virtual std::string graphviz() { + std::stringstream ss; + ss << "\"" << this << "\" [shape=\"box\", label=\"÷\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + return ss.str(); + }; + }; } diff --git a/src/test.cu b/src/test.cu index ba798715..8d7073f4 100644 --- a/src/test.cu +++ b/src/test.cu @@ -90,6 +90,8 @@ int main(int argc, char** argv) { Y[t] = Yt; } + std::cout << g.graphviz() << std::endl; + g.forward(batch_size); g.backward(); diff --git a/src/validate_encoder_decoder.cu b/src/validate_encoder_decoder.cu new file mode 100644 index 00000000..b053c0ea --- /dev/null +++ b/src/validate_encoder_decoder.cu @@ -0,0 +1,142 @@ + +#include "marian.h" +#include "mnist.h" + +#if 0 +ExpressionGraph build_graph() { + std::cerr << "Loading model params..."; +} + + // read parallel corpus from file + std::fstream sourceFile("../examples/mt/dev/newstest2013.de"); + std::fstream targetFile("../examples/mt/dev/newstest2013.en"); + + std::string sourceLine, targetLine; + while (getline(sourceFile, sourceLine)) { + getline(targetFile, targetLine); + std::vector sourceIds = sourceVocab.ProcessSentence(sourceLine); + std::vector targetIds = sourceVocab.ProcessSentence(targetLine); + } +#endif + + +int main(int argc, char** argv) { + + using namespace marian; + using namespace keywords; + + int input_size = 10; + int output_size = 15; + int batch_size = 25; + int hidden_size = 5; + int num_inputs = 8; + int num_outputs = 6; + + ExpressionGraph g(0); + std::vector X(num_inputs+1); // For the stop symbol. + std::vector Y(num_outputs); + std::vector H(num_inputs+1); // For the stop symbol. + std::vector S(num_outputs); + + // For the stop symbol. + for (int t = 0; t <= num_inputs; ++t) { + X[t] = new Expr(g.input(shape={batch_size, input_size})); + } + + // For the stop symbol. + for (int t = 0; t <= num_outputs; ++t) { + Y[t] = new Expr(g.input(shape={batch_size, output_size})); + } + + Expr Wxh = g.param(shape={input_size, hidden_size}, init=uniform(), name="Wxh"); + Expr Whh = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh"); + Expr bh = g.param(shape={1, hidden_size}, init=uniform(), name="bh"); + Expr h0 = g.param(shape={1, hidden_size}, init=uniform(), name="h0"); + + std::cerr << "Building encoder RNN..." << std::endl; + H[0] = new Expr(tanh(dot(*X[0], Wxh) + dot(h0, Whh) + bh)); + for (int t = 1; t <= num_inputs; ++t) { + H[t] = new Expr(tanh(dot(*X[t], Wxh) + dot(*H[t-1], Whh) + bh)); + } + + Expr Wxh_d = g.param(shape={output_size, hidden_size}, init=uniform(), name="Wxh_d"); + Expr Whh_d = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh_d"); + Expr bh_d = g.param(shape={1, hidden_size}, init=uniform(), name="bh_d"); + + std::cerr << "Building decoder RNN..." << std::endl; + auto h0_d = *H[num_inputs]; + S[0] = new Expr(tanh(dot(*Y[0], Wxh_d) + dot(h0_d, Whh_d) + bh_d)); + for (int t = 1; t < num_outputs; ++t) { + S[t] = new Expr(tanh(dot(*Y[t], Wxh_d) + dot(*S[t-1], Whh_d) + bh_d)); + } + + Expr Why = g.param(shape={hidden_size, output_size}, init=uniform(), name="Why"); + Expr by = g.param(shape={1, output_size}, init=uniform(), name="by"); + + std::cerr << "Building output layer..." << std::endl; + std::vector Yp(num_outputs+1); // For the stop symbol. + + Expr* cross_entropy = NULL; + for (int t = 0; t <= num_outputs; ++t) { + if (t == 0) { + Yp[t] = new Expr(named(softmax_fast(dot(h0_d, Why) + by), "pred")); + cross_entropy = new Expr(sum(*Y[t] * log(*Yp[t]), axis=1)); + } else { + Yp[t] = new Expr(named(softmax_fast(dot(*S[t-1], Why) + by), "pred")); + *cross_entropy = *cross_entropy + sum(*Y[t] * log(*Yp[t]), axis=1); + } + } + auto graph = -mean(*cross_entropy, axis=0, name="cost"); + + // For the stop symbol. + for (int t = 0; t <= num_inputs; ++t) { + Tensor Xt({batch_size, input_size}); + + float max = 1.; + std::vector values(batch_size * input_size); + std::vector classes(batch_size * output_size, 0.0); + int k = 0; + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < input_size; ++j, ++k) { + values[k] = max * (2.0*static_cast(rand()) / RAND_MAX - 1.0); + } + } + + thrust::copy(values.begin(), values.end(), Xt.begin()); + + *X[t] = Xt; + } + + for (int t = 0; t < num_outputs; ++t) { + Tensor Yt({batch_size, output_size}); + + std::vector classes(batch_size * output_size, 0.0); + int l = 0; + for (int i = 0; i < batch_size; ++i) { + int gold = output_size * static_cast(rand()) / RAND_MAX; + classes[l + gold] = 1.0; + l += output_size; + } + + thrust::copy(classes.begin(), classes.end(), Yt.begin()); + + *Y[t] = Yt; + } + + g.forward(batch_size); + g.backward(); + + std::cerr << graph.val().Debug() << std::endl; + + std::cerr << X[0]->val().Debug() << std::endl; + std::cerr << Y[0]->val().Debug() << std::endl; + + std::cerr << Whh.grad().Debug() << std::endl; + std::cerr << bh.grad().Debug() << std::endl; + std::cerr << Why.grad().Debug() << std::endl; + std::cerr << by.grad().Debug() << std::endl; + std::cerr << Wxh.grad().Debug() << std::endl; + std::cerr << h0.grad().Debug() << std::endl; + + return 0; +} diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index 36f0135a..449fc193 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -10,7 +10,7 @@ const size_t IMAGE_SIZE = 784; const size_t LABEL_SIZE = 10; int BATCH_SIZE = 10000; -ExpressionGraph build_graph(int cudaDevice) { +ExpressionGraph build_graph() { std::cerr << "Loading model params..."; NpzConverter converter("../scripts/test_model_single/model.npz"); @@ -22,7 +22,7 @@ ExpressionGraph build_graph(int cudaDevice) { std::cerr << "Building model..."; - ExpressionGraph g(cudaDevice); + ExpressionGraph g(0); auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x"); auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y"); @@ -46,12 +46,15 @@ ExpressionGraph build_graph(int cudaDevice) { } int main(int argc, char** argv) { + + cudaSetDevice(0); + std::cerr << "Loading test set..."; std::vector testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE); std::vector testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE); std::cerr << "Done." << std::endl; - ExpressionGraph g = build_graph(1); + ExpressionGraph g = build_graph(); Tensor xt({BATCH_SIZE, IMAGE_SIZE}); Tensor yt({BATCH_SIZE, LABEL_SIZE}); @@ -59,6 +62,8 @@ int main(int argc, char** argv) { g["x"] = (xt << testImages); g["y"] = (yt << testLabels); + std::cout << g.graphviz() << std::endl; + g.forward(BATCH_SIZE); std::vector results;