mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merge branch 'master' of https://github.com/emjotde/Marian
This commit is contained in:
commit
12decbeebd
@ -17,6 +17,8 @@ struct Chainable {
|
|||||||
virtual void set_zero_adjoint() { }
|
virtual void set_zero_adjoint() { }
|
||||||
|
|
||||||
virtual void allocate(size_t) = 0;
|
virtual void allocate(size_t) = 0;
|
||||||
|
virtual std::string graphviz() = 0;
|
||||||
|
|
||||||
|
|
||||||
virtual const Shape& shape() = 0;
|
virtual const Shape& shape() = 0;
|
||||||
virtual DataType &val() = 0;
|
virtual DataType &val() = 0;
|
||||||
|
@ -50,6 +50,18 @@ class ExpressionGraph {
|
|||||||
v->forward();
|
v->forward();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "digraph ExpressionGraph {" << std::endl;
|
||||||
|
ss << "rankdir=BT" << std::endl;
|
||||||
|
|
||||||
|
typedef typename ChainableStack::reverse_iterator It;
|
||||||
|
for(It it = stack_->rbegin(); it != stack_->rend(); ++it)
|
||||||
|
ss << (*it)->graphviz();
|
||||||
|
ss << "}" << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
void backward() {
|
void backward() {
|
||||||
for(auto&& v : *stack_)
|
for(auto&& v : *stack_)
|
||||||
v->set_zero_adjoint();
|
v->set_zero_adjoint();
|
||||||
|
@ -22,6 +22,13 @@ struct InputNode : public Node {
|
|||||||
|
|
||||||
void forward() {}
|
void forward() {}
|
||||||
void backward() {}
|
void backward() {}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"parallelogram\", label=\"input\", style=\"filled\", fillcolor=\"lawngreen\"]" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ConstantNode : public Node {
|
struct ConstantNode : public Node {
|
||||||
@ -35,6 +42,13 @@ struct ConstantNode : public Node {
|
|||||||
|
|
||||||
void forward() {}
|
void forward() {}
|
||||||
void backward() {}
|
void backward() {}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"diamond\", label=\"const\"]" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ParamNode : public Node {
|
struct ParamNode : public Node {
|
||||||
@ -60,6 +74,13 @@ struct ParamNode : public Node {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"hexagon\", label=\"param\", style=\"filled\", fillcolor=\"orangered\"]" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::function<void(Tensor)> init_;
|
std::function<void(Tensor)> init_;
|
||||||
bool initialized_;
|
bool initialized_;
|
||||||
@ -89,6 +110,14 @@ struct LogitNodeOp : public UnaryNodeOp {
|
|||||||
Element(_1 += _2 * _3 * (1 - _3),
|
Element(_1 += _2 * _3 * (1 - _3),
|
||||||
a_->grad(), adj_, val_);
|
a_->grad(), adj_, val_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||||
|
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TanhNodeOp : public UnaryNodeOp {
|
struct TanhNodeOp : public UnaryNodeOp {
|
||||||
@ -105,6 +134,14 @@ struct TanhNodeOp : public UnaryNodeOp {
|
|||||||
Element(_1 += _2 * (1 - _3 * _3),
|
Element(_1 += _2 * (1 - _3 * _3),
|
||||||
a_->grad(), adj_, val_);
|
a_->grad(), adj_, val_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"box\", label=\"tanh\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||||
|
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// @TODO, make this numerically safe(r):
|
// @TODO, make this numerically safe(r):
|
||||||
@ -131,6 +168,14 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
|
|||||||
SubtractMean(&result, val_);
|
SubtractMean(&result, val_);
|
||||||
Element(_1 += _2 * _3, a_->grad(), val_, result);
|
Element(_1 += _2 * _3, a_->grad(), val_, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"box\", label=\"softmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||||
|
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LogNodeOp : public UnaryNodeOp {
|
struct LogNodeOp : public UnaryNodeOp {
|
||||||
@ -146,6 +191,14 @@ struct LogNodeOp : public UnaryNodeOp {
|
|||||||
Element(_1 += _2 * 1.f / _3,
|
Element(_1 += _2 * 1.f / _3,
|
||||||
a_->grad(), adj_, a_->val());
|
a_->grad(), adj_, a_->val());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"box\", label=\"log\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||||
|
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ExpNodeOp : public UnaryNodeOp {
|
struct ExpNodeOp : public UnaryNodeOp {
|
||||||
@ -161,6 +214,14 @@ struct ExpNodeOp : public UnaryNodeOp {
|
|||||||
Element(_1 += _2 * Exp(_3),
|
Element(_1 += _2 * Exp(_3),
|
||||||
a_->grad(), adj_, a_->val());
|
a_->grad(), adj_, a_->val());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"box\", label=\"exp\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||||
|
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct NegNodeOp : public UnaryNodeOp {
|
struct NegNodeOp : public UnaryNodeOp {
|
||||||
@ -175,6 +236,14 @@ struct NegNodeOp : public UnaryNodeOp {
|
|||||||
void backward() {
|
void backward() {
|
||||||
Element(_1 += -_2, a_->grad(), adj_);
|
Element(_1 += -_2, a_->grad(), adj_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||||
|
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/******************************************************/
|
/******************************************************/
|
||||||
@ -220,6 +289,15 @@ struct DotNodeOp : public BinaryNodeOp {
|
|||||||
Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
|
Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
|
||||||
Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
|
Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"box\", label=\"×\", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
|
||||||
|
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||||
|
ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PlusNodeOp : public BinaryNodeOp {
|
struct PlusNodeOp : public BinaryNodeOp {
|
||||||
@ -238,6 +316,15 @@ struct PlusNodeOp : public BinaryNodeOp {
|
|||||||
Element(_1 += _2,
|
Element(_1 += _2,
|
||||||
b_->grad(), adj_);
|
b_->grad(), adj_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"box\", label=\"+\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||||
|
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||||
|
ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MinusNodeOp : public BinaryNodeOp {
|
struct MinusNodeOp : public BinaryNodeOp {
|
||||||
@ -256,6 +343,15 @@ struct MinusNodeOp : public BinaryNodeOp {
|
|||||||
Element(_1 -= _2,
|
Element(_1 -= _2,
|
||||||
b_->grad(), adj_);
|
b_->grad(), adj_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||||
|
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||||
|
ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MultNodeOp : public BinaryNodeOp {
|
struct MultNodeOp : public BinaryNodeOp {
|
||||||
@ -274,6 +370,15 @@ struct MultNodeOp : public BinaryNodeOp {
|
|||||||
Element(_1 += _2 * _3,
|
Element(_1 += _2 * _3,
|
||||||
b_->grad(), adj_, a_->val());
|
b_->grad(), adj_, a_->val());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"box\", label=\"•\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||||
|
ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||||
|
ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DivNodeOp : public BinaryNodeOp {
|
struct DivNodeOp : public BinaryNodeOp {
|
||||||
@ -291,7 +396,16 @@ struct DivNodeOp : public BinaryNodeOp {
|
|||||||
a_->grad(), adj_, b_->val());
|
a_->grad(), adj_, b_->val());
|
||||||
Element(_1 -= _2 * _3 / (_4 * _4),
|
Element(_1 -= _2 * _3 / (_4 * _4),
|
||||||
b_->grad(), adj_, a_->val(), b_->val());
|
b_->grad(), adj_, a_->val(), b_->val());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual std::string graphviz() {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "\"" << this << "\" [shape=\"box\", label=\"÷\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
|
||||||
|
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
|
||||||
|
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -91,6 +91,8 @@ int main(int argc, char** argv) {
|
|||||||
Y[t] = Yt;
|
Y[t] = Yt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::cout << g.graphviz() << std::endl;
|
||||||
|
|
||||||
g.forward(batch_size);
|
g.forward(batch_size);
|
||||||
g.backward();
|
g.backward();
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ ExpressionGraph build_graph() {
|
|||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
|
|
||||||
cudaSetDevice(1);
|
cudaSetDevice(0);
|
||||||
|
|
||||||
std::cerr << "Loading test set...";
|
std::cerr << "Loading test set...";
|
||||||
std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE);
|
std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE);
|
||||||
@ -62,6 +62,8 @@ int main(int argc, char** argv) {
|
|||||||
g["x"] = (xt << testImages);
|
g["x"] = (xt << testImages);
|
||||||
g["y"] = (yt << testLabels);
|
g["y"] = (yt << testLabels);
|
||||||
|
|
||||||
|
std::cout << g.graphviz() << std::endl;
|
||||||
|
|
||||||
g.forward(BATCH_SIZE);
|
g.forward(BATCH_SIZE);
|
||||||
|
|
||||||
std::vector<float> results;
|
std::vector<float> results;
|
||||||
|
Loading…
Reference in New Issue
Block a user