diff --git a/src/node_operators.h b/src/node_operators.h index e7994c0a..2d675ae8 100644 --- a/src/node_operators.h +++ b/src/node_operators.h @@ -114,7 +114,7 @@ struct LogitNodeOp : public UnaryNodeOp { virtual std::string graphviz() { std::stringstream ss; ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; return ss.str(); }; @@ -138,7 +138,7 @@ struct TanhNodeOp : public UnaryNodeOp { virtual std::string graphviz() { std::stringstream ss; ss << "\"" << this << "\" [shape=\"box\", label=\"tanh\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; return ss.str(); }; @@ -179,7 +179,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp { virtual std::string graphviz() { std::stringstream ss; ss << "\"" << this << "\" [shape=\"box\", label=\"softmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; return ss.str(); }; @@ -208,7 +208,7 @@ struct ArgmaxNodeOp : public UnaryNodeOp { virtual std::string graphviz() { std::stringstream ss; ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; return ss.str(); }; @@ -231,7 +231,7 @@ struct LogNodeOp : public UnaryNodeOp { virtual std::string graphviz() { std::stringstream ss; ss << "\"" << this << "\" [shape=\"box\", label=\"log\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; return ss.str(); }; @@ -254,7 +254,7 @@ struct ExpNodeOp : public UnaryNodeOp { virtual std::string graphviz() { std::stringstream ss; ss << "\"" << this << "\" [shape=\"box\", label=\"exp\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; return ss.str(); }; @@ -276,7 +276,7 @@ struct NegNodeOp : public UnaryNodeOp { virtual std::string graphviz() { std::stringstream ss; ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl; return ss.str(); }; @@ -329,8 +329,8 @@ struct DotNodeOp : public BinaryNodeOp { virtual std::string graphviz() { std::stringstream ss; ss << "\"" << this << "\" [shape=\"box\", label=\"×\", style=\"filled\", fillcolor=\"orange\"]" << std::endl; - ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl; - ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; return ss.str(); }; @@ -356,8 +356,8 @@ struct PlusNodeOp : public BinaryNodeOp { virtual std::string graphviz() { std::stringstream ss; ss << "\"" << this << "\" [shape=\"box\", label=\"+\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl; - ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; return ss.str(); }; @@ -383,8 +383,8 @@ struct MinusNodeOp : public BinaryNodeOp { virtual std::string graphviz() { std::stringstream ss; ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl; - ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; return ss.str(); }; @@ -410,8 +410,8 @@ struct MultNodeOp : public BinaryNodeOp { virtual std::string graphviz() { std::stringstream ss; ss << "\"" << this << "\" [shape=\"box\", label=\"•\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl; - ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl; - ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; + ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl; + ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl; return ss.str(); }; diff --git a/src/optimizers.h b/src/optimizers.h index a977d7f8..6bfe7861 100644 --- a/src/optimizers.h +++ b/src/optimizers.h @@ -6,9 +6,15 @@ namespace marian { +// @TODO: modify computation graph to group all paramters in single matrix object. +// This will allow to perform a single large SGD update per batch. Currently there +// are as many updates as different paramters. + +// @TODO: Implement Element(...) with multiple functors for compacting of calls. + class Sgd { public: - Sgd(float eta=0.001) : eta_(eta) {} + Sgd(float eta=0.01) : eta_(eta) {} void operator()(ExpressionGraph& graph, int batchSize) { graph.backprop(batchSize); @@ -25,7 +31,7 @@ class Sgd { // @TODO: Add serialization for historic gradients and parameters class Adagrad { public: - Adagrad(float eta=0.001, float eps=10e-8) + Adagrad(float eta=0.01, float eps=10e-8) : eta_(eta), eps_(eps) {} void operator()(ExpressionGraph& graph, int batchSize) { diff --git a/src/train_mnist.cu b/src/train_mnist.cu index 5b32cc9f..c4b932c0 100644 --- a/src/train_mnist.cu +++ b/src/train_mnist.cu @@ -27,7 +27,7 @@ int main(int argc, char** argv) { auto cost = named(-mean(sum(y * log(lr), axis=1), axis=0), "cost"); std::cerr << "lr=" << lr.Debug() << std::endl; - Adagrad opt; + Adam opt; opt(g, 300); return 0; diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index 34364929..0f8e0482 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -62,7 +62,7 @@ int main(int argc, char** argv) { g["x"] = (xt << testImages); g["y"] = (yt << testLabels); - Adam opt; + Adagrad opt; for(size_t j = 0; j < 20; ++j) { for(size_t i = 0; i < 60; ++i) { opt(g, BATCH_SIZE); @@ -70,8 +70,6 @@ int main(int argc, char** argv) { std::cerr << g["cost"].val()[0] << std::endl; } - //std::cout << g.graphviz() << std::endl; - std::vector results; results << g["probs"].val(); @@ -80,8 +78,10 @@ int main(int argc, char** argv) { size_t correct = 0; size_t proposed = 0; for (size_t j = 0; j < LABEL_SIZE; ++j) { - if (testLabels[i+j]) correct = j; - if (results[i + j] > results[i + proposed]) proposed = j; + if (testLabels[i+j]) + correct = j; + if (results[i + j] > results[i + proposed]) + proposed = j; } acc += (correct == proposed); }