mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
cleaned-up visualization
This commit is contained in:
parent
09ce7e6349
commit
5ed36889a6
@ -114,7 +114,7 @@ struct LogitNodeOp : public UnaryNodeOp {
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@ -138,7 +138,7 @@ struct TanhNodeOp : public UnaryNodeOp {
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"tanh\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@ -179,7 +179,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"softmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@ -208,7 +208,7 @@ struct ArgmaxNodeOp : public UnaryNodeOp {
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@ -231,7 +231,7 @@ struct LogNodeOp : public UnaryNodeOp {
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"log\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@ -254,7 +254,7 @@ struct ExpNodeOp : public UnaryNodeOp {
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"exp\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@ -276,7 +276,7 @@ struct NegNodeOp : public UnaryNodeOp {
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@ -329,8 +329,8 @@ struct DotNodeOp : public BinaryNodeOp {
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"×\", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
|
||||
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@ -356,8 +356,8 @@ struct PlusNodeOp : public BinaryNodeOp {
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"+\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@ -383,8 +383,8 @@ struct MinusNodeOp : public BinaryNodeOp {
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
@ -410,8 +410,8 @@ struct MultNodeOp : public BinaryNodeOp {
|
||||
virtual std::string graphviz() {
|
||||
std::stringstream ss;
|
||||
ss << "\"" << this << "\" [shape=\"box\", label=\"•\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
|
@ -6,9 +6,15 @@
|
||||
|
||||
namespace marian {
|
||||
|
||||
// @TODO: modify computation graph to group all paramters in single matrix object.
|
||||
// This will allow to perform a single large SGD update per batch. Currently there
|
||||
// are as many updates as different paramters.
|
||||
|
||||
// @TODO: Implement Element(...) with multiple functors for compacting of calls.
|
||||
|
||||
class Sgd {
|
||||
public:
|
||||
Sgd(float eta=0.001) : eta_(eta) {}
|
||||
Sgd(float eta=0.01) : eta_(eta) {}
|
||||
|
||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||
graph.backprop(batchSize);
|
||||
@ -25,7 +31,7 @@ class Sgd {
|
||||
// @TODO: Add serialization for historic gradients and parameters
|
||||
class Adagrad {
|
||||
public:
|
||||
Adagrad(float eta=0.001, float eps=10e-8)
|
||||
Adagrad(float eta=0.01, float eps=10e-8)
|
||||
: eta_(eta), eps_(eps) {}
|
||||
|
||||
void operator()(ExpressionGraph& graph, int batchSize) {
|
||||
|
@ -27,7 +27,7 @@ int main(int argc, char** argv) {
|
||||
auto cost = named(-mean(sum(y * log(lr), axis=1), axis=0), "cost");
|
||||
std::cerr << "lr=" << lr.Debug() << std::endl;
|
||||
|
||||
Adagrad opt;
|
||||
Adam opt;
|
||||
opt(g, 300);
|
||||
|
||||
return 0;
|
||||
|
@ -62,7 +62,7 @@ int main(int argc, char** argv) {
|
||||
g["x"] = (xt << testImages);
|
||||
g["y"] = (yt << testLabels);
|
||||
|
||||
Adam opt;
|
||||
Adagrad opt;
|
||||
for(size_t j = 0; j < 20; ++j) {
|
||||
for(size_t i = 0; i < 60; ++i) {
|
||||
opt(g, BATCH_SIZE);
|
||||
@ -70,8 +70,6 @@ int main(int argc, char** argv) {
|
||||
std::cerr << g["cost"].val()[0] << std::endl;
|
||||
}
|
||||
|
||||
//std::cout << g.graphviz() << std::endl;
|
||||
|
||||
std::vector<float> results;
|
||||
results << g["probs"].val();
|
||||
|
||||
@ -80,8 +78,10 @@ int main(int argc, char** argv) {
|
||||
size_t correct = 0;
|
||||
size_t proposed = 0;
|
||||
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
||||
if (testLabels[i+j]) correct = j;
|
||||
if (results[i + j] > results[i + proposed]) proposed = j;
|
||||
if (testLabels[i+j])
|
||||
correct = j;
|
||||
if (results[i + j] > results[i + proposed])
|
||||
proposed = j;
|
||||
}
|
||||
acc += (correct == proposed);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user