mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Separating graph from data.
This commit is contained in:
parent
12decbeebd
commit
c7a1542b02
@ -2,11 +2,78 @@
|
|||||||
#include "marian.h"
|
#include "marian.h"
|
||||||
#include "mnist.h"
|
#include "mnist.h"
|
||||||
|
|
||||||
#if 0
|
using namespace marian;
|
||||||
ExpressionGraph build_graph() {
|
using namespace keywords;
|
||||||
std::cerr << "Loading model params...";
|
|
||||||
|
const int input_size = 10;
|
||||||
|
const int output_size = 15;
|
||||||
|
const int batch_size = 25;
|
||||||
|
const int hidden_size = 5;
|
||||||
|
const int num_inputs = 8;
|
||||||
|
const int num_outputs = 6;
|
||||||
|
|
||||||
|
ExpressionGraph build_graph(int cuda_device) {
|
||||||
|
std::cerr << "Building computation graph..." << std::endl;
|
||||||
|
|
||||||
|
ExpressionGraph g(cuda_device);
|
||||||
|
std::vector<Expr> X, Y, H, S;
|
||||||
|
|
||||||
|
// For the stop symbol.
|
||||||
|
for (int t = 0; t <= num_inputs; ++t) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "X" << t;
|
||||||
|
X.emplace_back(named(g.input(shape={batch_size, input_size}), ss.str()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// For the stop symbol.
|
||||||
|
for (int t = 0; t <= num_outputs; ++t) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "Y" << t;
|
||||||
|
Y.emplace_back(named(g.input(shape={batch_size, output_size}), ss.str()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr Wxh = g.param(shape={input_size, hidden_size}, init=uniform(), name="Wxh");
|
||||||
|
Expr Whh = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh");
|
||||||
|
Expr bh = g.param(shape={1, hidden_size}, init=uniform(), name="bh");
|
||||||
|
Expr h0 = g.param(shape={1, hidden_size}, init=uniform(), name="h0");
|
||||||
|
|
||||||
|
std::cerr << "Building encoder RNN..." << std::endl;
|
||||||
|
H.emplace_back(tanh(dot(X[0], Wxh) + dot(h0, Whh) + bh));
|
||||||
|
for (int t = 1; t <= num_inputs; ++t) {
|
||||||
|
H.emplace_back(tanh(dot(X[t], Wxh) + dot(H[t-1], Whh) + bh));
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr Wxh_d = g.param(shape={output_size, hidden_size}, init=uniform(), name="Wxh_d");
|
||||||
|
Expr Whh_d = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh_d");
|
||||||
|
Expr bh_d = g.param(shape={1, hidden_size}, init=uniform(), name="bh_d");
|
||||||
|
|
||||||
|
std::cerr << "Building decoder RNN..." << std::endl;
|
||||||
|
auto h0_d = H[num_inputs];
|
||||||
|
S.emplace_back(tanh(dot(Y[0], Wxh_d) + dot(h0_d, Whh_d) + bh_d));
|
||||||
|
for (int t = 1; t < num_outputs; ++t) {
|
||||||
|
S.emplace_back(tanh(dot(Y[t], Wxh_d) + dot(S[t-1], Whh_d) + bh_d));
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr Why = g.param(shape={hidden_size, output_size}, init=uniform(), name="Why");
|
||||||
|
Expr by = g.param(shape={1, output_size}, init=uniform(), name="by");
|
||||||
|
|
||||||
|
std::cerr << "Building output layer..." << std::endl;
|
||||||
|
std::vector<Expr> Yp;
|
||||||
|
|
||||||
|
Yp.emplace_back(named(softmax_fast(dot(h0_d, Why) + by), "pred"));
|
||||||
|
Expr cross_entropy = sum(Y[0] * log(Yp[0]), axis=1);
|
||||||
|
for (int t = 1; t <= num_outputs; ++t) {
|
||||||
|
Yp.emplace_back(named(softmax_fast(dot(S[t-1], Why) + by), "pred"));
|
||||||
|
cross_entropy = cross_entropy + sum(Y[t] * log(Yp[t]), axis=1);
|
||||||
|
}
|
||||||
|
auto graph = -mean(cross_entropy, axis=0, name="cost");
|
||||||
|
|
||||||
|
std::cerr << "Done." << std::endl;
|
||||||
|
|
||||||
|
return g;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
// read parallel corpus from file
|
// read parallel corpus from file
|
||||||
std::fstream sourceFile("../examples/mt/dev/newstest2013.de");
|
std::fstream sourceFile("../examples/mt/dev/newstest2013.de");
|
||||||
std::fstream targetFile("../examples/mt/dev/newstest2013.en");
|
std::fstream targetFile("../examples/mt/dev/newstest2013.en");
|
||||||
@ -21,73 +88,8 @@ ExpressionGraph build_graph() {
|
|||||||
|
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
cudaSetDevice(0);
|
|
||||||
|
|
||||||
using namespace marian;
|
ExpressionGraph g = build_graph(0);
|
||||||
using namespace keywords;
|
|
||||||
|
|
||||||
int input_size = 10;
|
|
||||||
int output_size = 15;
|
|
||||||
int batch_size = 25;
|
|
||||||
int hidden_size = 5;
|
|
||||||
int num_inputs = 8;
|
|
||||||
int num_outputs = 6;
|
|
||||||
|
|
||||||
ExpressionGraph g;
|
|
||||||
std::vector<Expr*> X(num_inputs+1); // For the stop symbol.
|
|
||||||
std::vector<Expr*> Y(num_outputs);
|
|
||||||
std::vector<Expr*> H(num_inputs+1); // For the stop symbol.
|
|
||||||
std::vector<Expr*> S(num_outputs);
|
|
||||||
|
|
||||||
// For the stop symbol.
|
|
||||||
for (int t = 0; t <= num_inputs; ++t) {
|
|
||||||
X[t] = new Expr(g.input(shape={batch_size, input_size}));
|
|
||||||
}
|
|
||||||
|
|
||||||
// For the stop symbol.
|
|
||||||
for (int t = 0; t <= num_outputs; ++t) {
|
|
||||||
Y[t] = new Expr(g.input(shape={batch_size, output_size}));
|
|
||||||
}
|
|
||||||
|
|
||||||
Expr Wxh = g.param(shape={input_size, hidden_size}, init=uniform(), name="Wxh");
|
|
||||||
Expr Whh = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh");
|
|
||||||
Expr bh = g.param(shape={1, hidden_size}, init=uniform(), name="bh");
|
|
||||||
Expr h0 = g.param(shape={1, hidden_size}, init=uniform(), name="h0");
|
|
||||||
|
|
||||||
std::cerr << "Building encoder RNN..." << std::endl;
|
|
||||||
H[0] = new Expr(tanh(dot(*X[0], Wxh) + dot(h0, Whh) + bh));
|
|
||||||
for (int t = 1; t <= num_inputs; ++t) {
|
|
||||||
H[t] = new Expr(tanh(dot(*X[t], Wxh) + dot(*H[t-1], Whh) + bh));
|
|
||||||
}
|
|
||||||
|
|
||||||
Expr Wxh_d = g.param(shape={output_size, hidden_size}, init=uniform(), name="Wxh_d");
|
|
||||||
Expr Whh_d = g.param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh_d");
|
|
||||||
Expr bh_d = g.param(shape={1, hidden_size}, init=uniform(), name="bh_d");
|
|
||||||
|
|
||||||
std::cerr << "Building decoder RNN..." << std::endl;
|
|
||||||
auto h0_d = *H[num_inputs];
|
|
||||||
S[0] = new Expr(tanh(dot(*Y[0], Wxh_d) + dot(h0_d, Whh_d) + bh_d));
|
|
||||||
for (int t = 1; t < num_outputs; ++t) {
|
|
||||||
S[t] = new Expr(tanh(dot(*Y[t], Wxh_d) + dot(*S[t-1], Whh_d) + bh_d));
|
|
||||||
}
|
|
||||||
|
|
||||||
Expr Why = g.param(shape={hidden_size, output_size}, init=uniform(), name="Why");
|
|
||||||
Expr by = g.param(shape={1, output_size}, init=uniform(), name="by");
|
|
||||||
|
|
||||||
std::cerr << "Building output layer..." << std::endl;
|
|
||||||
std::vector<Expr*> Yp(num_outputs+1); // For the stop symbol.
|
|
||||||
|
|
||||||
Expr* cross_entropy = NULL;
|
|
||||||
for (int t = 0; t <= num_outputs; ++t) {
|
|
||||||
if (t == 0) {
|
|
||||||
Yp[t] = new Expr(named(softmax_fast(dot(h0_d, Why) + by), "pred"));
|
|
||||||
cross_entropy = new Expr(sum(*Y[t] * log(*Yp[t]), axis=1));
|
|
||||||
} else {
|
|
||||||
Yp[t] = new Expr(named(softmax_fast(dot(*S[t-1], Why) + by), "pred"));
|
|
||||||
*cross_entropy = *cross_entropy + sum(*Y[t] * log(*Yp[t]), axis=1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto graph = -mean(*cross_entropy, axis=0, name="cost");
|
|
||||||
|
|
||||||
// For the stop symbol.
|
// For the stop symbol.
|
||||||
for (int t = 0; t <= num_inputs; ++t) {
|
for (int t = 0; t <= num_inputs; ++t) {
|
||||||
@ -105,10 +107,13 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
thrust::copy(values.begin(), values.end(), Xt.begin());
|
thrust::copy(values.begin(), values.end(), Xt.begin());
|
||||||
|
|
||||||
*X[t] = Xt;
|
std::stringstream ss;
|
||||||
|
ss << "X" << t;
|
||||||
|
g[ss.str()] = Xt;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int t = 0; t < num_outputs; ++t) {
|
for (int t = 0; t <= num_outputs; ++t) {
|
||||||
Tensor Yt({batch_size, output_size});
|
Tensor Yt({batch_size, output_size});
|
||||||
|
|
||||||
std::vector<float> classes(batch_size * output_size, 0.0);
|
std::vector<float> classes(batch_size * output_size, 0.0);
|
||||||
@ -121,23 +126,33 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
thrust::copy(classes.begin(), classes.end(), Yt.begin());
|
thrust::copy(classes.begin(), classes.end(), Yt.begin());
|
||||||
|
|
||||||
*Y[t] = Yt;
|
std::stringstream ss;
|
||||||
|
ss << "Y" << t;
|
||||||
|
g[ss.str()] = Yt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::cerr << "Graphviz step" << std::endl;
|
||||||
|
std::cout << g.graphviz() << std::endl;
|
||||||
|
|
||||||
|
std::cerr << "Forward step" << std::endl;
|
||||||
g.forward(batch_size);
|
g.forward(batch_size);
|
||||||
|
std::cerr << "Backward step" << std::endl;
|
||||||
g.backward();
|
g.backward();
|
||||||
|
std::cerr << "Done" << std::endl;
|
||||||
|
|
||||||
std::cerr << graph.val().Debug() << std::endl;
|
std::cerr << g["graph"].val().Debug() << std::endl;
|
||||||
|
|
||||||
std::cerr << X[0]->val().Debug() << std::endl;
|
std::cerr << g["X0"].val().Debug() << std::endl;
|
||||||
std::cerr << Y[0]->val().Debug() << std::endl;
|
std::cerr << g["Y0"].val().Debug() << std::endl;
|
||||||
|
|
||||||
|
#if 0
|
||||||
std::cerr << Whh.grad().Debug() << std::endl;
|
std::cerr << Whh.grad().Debug() << std::endl;
|
||||||
std::cerr << bh.grad().Debug() << std::endl;
|
std::cerr << bh.grad().Debug() << std::endl;
|
||||||
std::cerr << Why.grad().Debug() << std::endl;
|
std::cerr << Why.grad().Debug() << std::endl;
|
||||||
std::cerr << by.grad().Debug() << std::endl;
|
std::cerr << by.grad().Debug() << std::endl;
|
||||||
std::cerr << Wxh.grad().Debug() << std::endl;
|
std::cerr << Wxh.grad().Debug() << std::endl;
|
||||||
std::cerr << h0.grad().Debug() << std::endl;
|
std::cerr << h0.grad().Debug() << std::endl;
|
||||||
|
#endif
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user