cleaned-up visualization

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-17 08:12:45 +02:00
parent 09ce7e6349
commit 5ed36889a6
4 changed files with 29 additions and 23 deletions

View File

@ -114,7 +114,7 @@ struct LogitNodeOp : public UnaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -138,7 +138,7 @@ struct TanhNodeOp : public UnaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"tanh\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -179,7 +179,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"softmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -208,7 +208,7 @@ struct ArgmaxNodeOp : public UnaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -231,7 +231,7 @@ struct LogNodeOp : public UnaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"log\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -254,7 +254,7 @@ struct ExpNodeOp : public UnaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"exp\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -276,7 +276,7 @@ struct NegNodeOp : public UnaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -329,8 +329,8 @@ struct DotNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"×\", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -356,8 +356,8 @@ struct PlusNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"+\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -383,8 +383,8 @@ struct MinusNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -410,8 +410,8 @@ struct MultNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};

View File

@ -6,9 +6,15 @@
namespace marian {
// @TODO: modify computation graph to group all parameters in a single matrix object.
// This will allow performing a single large SGD update per batch. Currently there
// are as many updates as there are different parameters.
// @TODO: Implement Element(...) with multiple functors to compact calls.
class Sgd {
public:
Sgd(float eta=0.001) : eta_(eta) {}
Sgd(float eta=0.01) : eta_(eta) {}
void operator()(ExpressionGraph& graph, int batchSize) {
graph.backprop(batchSize);
@ -25,7 +31,7 @@ class Sgd {
// @TODO: Add serialization for historic gradients and parameters
class Adagrad {
public:
Adagrad(float eta=0.001, float eps=10e-8)
Adagrad(float eta=0.01, float eps=10e-8)
: eta_(eta), eps_(eps) {}
void operator()(ExpressionGraph& graph, int batchSize) {

View File

@ -27,7 +27,7 @@ int main(int argc, char** argv) {
auto cost = named(-mean(sum(y * log(lr), axis=1), axis=0), "cost");
std::cerr << "lr=" << lr.Debug() << std::endl;
Adagrad opt;
Adam opt;
opt(g, 300);
return 0;

View File

@ -62,7 +62,7 @@ int main(int argc, char** argv) {
g["x"] = (xt << testImages);
g["y"] = (yt << testLabels);
Adam opt;
Adagrad opt;
for(size_t j = 0; j < 20; ++j) {
for(size_t i = 0; i < 60; ++i) {
opt(g, BATCH_SIZE);
@ -70,8 +70,6 @@ int main(int argc, char** argv) {
std::cerr << g["cost"].val()[0] << std::endl;
}
//std::cout << g.graphviz() << std::endl;
std::vector<float> results;
results << g["probs"].val();
@ -80,8 +78,10 @@ int main(int argc, char** argv) {
size_t correct = 0;
size_t proposed = 0;
for (size_t j = 0; j < LABEL_SIZE; ++j) {
if (testLabels[i+j]) correct = j;
if (results[i + j] > results[i + proposed]) proposed = j;
if (testLabels[i+j])
correct = j;
if (results[i + j] > results[i + proposed])
proposed = j;
}
acc += (correct == proposed);
}