mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
Included the embedding layer and the graphviz part in the encoder-decoder.
This commit is contained in:
parent
a5630d2786
commit
1b27accaa0
@ -9,8 +9,9 @@ using namespace keywords;
|
||||
|
||||
const int input_size = 10;
|
||||
const int output_size = 15;
|
||||
const int batch_size = 25;
|
||||
const int embedding_size = 8;
|
||||
const int hidden_size = 5;
|
||||
const int batch_size = 25;
|
||||
const int num_inputs = 8;
|
||||
const int num_outputs = 6;
|
||||
|
||||
@ -20,34 +21,47 @@ ExpressionGraph build_graph(int cuda_device) {
|
||||
ExpressionGraph g(cuda_device);
|
||||
std::vector<Expr> X, Y, H, S;
|
||||
|
||||
// For the stop symbol.
|
||||
// We're including the stop symbol here.
|
||||
for (int t = 0; t <= num_inputs; ++t) {
|
||||
std::stringstream ss;
|
||||
ss << "X" << t;
|
||||
X.emplace_back(named(g.input(shape={batch_size, input_size}), ss.str()));
|
||||
}
|
||||
|
||||
// For the stop symbol.
|
||||
// We're including the stop symbol here.
|
||||
for (int t = 0; t <= num_outputs; ++t) {
|
||||
std::stringstream ss;
|
||||
ss << "Y" << t;
|
||||
Y.emplace_back(named(g.input(shape={batch_size, output_size}), ss.str()));
|
||||
}
|
||||
|
||||
Expr Wxh = named(g.param(shape={input_size, hidden_size}, init=uniform()), "Wxh");
|
||||
Expr Whh = named(g.param(shape={hidden_size, hidden_size}, init=uniform()), "Whh");
|
||||
Expr bh = named(g.param(shape={1, hidden_size}, init=uniform()), "bh");
|
||||
Expr h0 = named(g.param(shape={1, hidden_size}, init=uniform()), "h0");
|
||||
// Source embeddings.
|
||||
Expr E = named(g.param(shape={input_size, embedding_size},
|
||||
init=uniform()), "E");
|
||||
|
||||
// Source RNN parameters.
|
||||
Expr Wxh = named(g.param(shape={embedding_size, hidden_size},
|
||||
init=uniform()), "Wxh");
|
||||
Expr Whh = named(g.param(shape={hidden_size, hidden_size},
|
||||
init=uniform()), "Whh");
|
||||
Expr bh = named(g.param(shape={1, hidden_size},
|
||||
init=uniform()), "bh");
|
||||
Expr h0 = named(g.param(shape={1, hidden_size},
|
||||
init=uniform()), "h0");
|
||||
|
||||
std::cerr << "Building encoder RNN..." << std::endl;
|
||||
H.emplace_back(tanh(dot(X[0], Wxh) + dot(h0, Whh) + bh));
|
||||
H.emplace_back(tanh(dot(dot(X[0], E), Wxh) + dot(h0, Whh) + bh));
|
||||
for (int t = 1; t <= num_inputs; ++t) {
|
||||
H.emplace_back(tanh(dot(X[t], Wxh) + dot(H[t-1], Whh) + bh));
|
||||
H.emplace_back(tanh(dot(dot(X[t], E), Wxh) + dot(H[t-1], Whh) + bh));
|
||||
}
|
||||
|
||||
Expr Wxh_d = named(g.param(shape={output_size, hidden_size}, init=uniform()), "Wxh_d");
|
||||
Expr Whh_d = named(g.param(shape={hidden_size, hidden_size}, init=uniform()), "Whh_d");
|
||||
Expr bh_d = named(g.param(shape={1, hidden_size}, init=uniform()), "bh_d");
|
||||
// Target RNN parameters.
|
||||
Expr Wxh_d = named(g.param(shape={output_size, hidden_size},
|
||||
init=uniform()), "Wxh_d");
|
||||
Expr Whh_d = named(g.param(shape={hidden_size, hidden_size},
|
||||
init=uniform()), "Whh_d");
|
||||
Expr bh_d = named(g.param(shape={1, hidden_size},
|
||||
init=uniform()), "bh_d");
|
||||
|
||||
std::cerr << "Building decoder RNN..." << std::endl;
|
||||
auto h0_d = H[num_inputs];
|
||||
@ -56,12 +70,16 @@ ExpressionGraph build_graph(int cuda_device) {
|
||||
S.emplace_back(tanh(dot(Y[t], Wxh_d) + dot(S[t-1], Whh_d) + bh_d));
|
||||
}
|
||||
|
||||
Expr Why = named(g.param(shape={hidden_size, output_size}, init=uniform()), "Why");
|
||||
Expr by = named(g.param(shape={1, output_size}, init=uniform()), "by");
|
||||
// Output linear layer before softmax.
|
||||
Expr Why = named(g.param(shape={hidden_size, output_size},
|
||||
init=uniform()), "Why");
|
||||
Expr by = named(g.param(shape={1, output_size},
|
||||
init=uniform()), "by");
|
||||
|
||||
std::cerr << "Building output layer..." << std::endl;
|
||||
std::vector<Expr> Yp;
|
||||
|
||||
// Softmax layer and cost function.
|
||||
std::vector<Expr> Yp;
|
||||
Yp.emplace_back(named(softmax_fast(dot(h0_d, Why) + by), "pred"));
|
||||
Expr cross_entropy = sum(Y[0] * log(Yp[0]), axis=1);
|
||||
for (int t = 1; t <= num_outputs; ++t) {
|
||||
@ -75,8 +93,6 @@ ExpressionGraph build_graph(int cuda_device) {
|
||||
return g;
|
||||
}
|
||||
|
||||
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
#if 1
|
||||
std::cerr << "Loading the data... ";
|
||||
@ -102,12 +118,12 @@ int main(int argc, char** argv) {
|
||||
std::cerr << "Target vocabulary size: " << targetVocab.Size() << std::endl;
|
||||
#endif
|
||||
|
||||
// Build the encoder-decoder computation graph.
|
||||
ExpressionGraph g = build_graph(0);
|
||||
|
||||
// For the stop symbol.
|
||||
// Generate input data (include the stop symbol).
|
||||
for (int t = 0; t <= num_inputs; ++t) {
|
||||
Tensor Xt({batch_size, input_size});
|
||||
|
||||
float max = 1.;
|
||||
std::vector<float> values(batch_size * input_size);
|
||||
std::vector<float> classes(batch_size * output_size, 0.0);
|
||||
@ -117,16 +133,13 @@ int main(int argc, char** argv) {
|
||||
values[k] = max * (2.0*static_cast<float>(rand()) / RAND_MAX - 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
thrust::copy(values.begin(), values.end(), Xt.begin());
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "X" << t;
|
||||
if (!g.has_node(ss.str())) std::cerr << "No node " << ss.str() << "!!!" << std::endl;
|
||||
g[ss.str()] = Xt;
|
||||
|
||||
}
|
||||
|
||||
// Generate output data (include the stop symbol).
|
||||
for (int t = 0; t <= num_outputs; ++t) {
|
||||
Tensor Yt({batch_size, output_size});
|
||||
|
||||
@ -137,37 +150,31 @@ int main(int argc, char** argv) {
|
||||
classes[l + gold] = 1.0;
|
||||
l += output_size;
|
||||
}
|
||||
|
||||
thrust::copy(classes.begin(), classes.end(), Yt.begin());
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "Y" << t;
|
||||
if (!g.has_node(ss.str())) std::cerr << "No node " << ss.str() << "!!!" << std::endl;
|
||||
g[ss.str()] = Yt;
|
||||
}
|
||||
|
||||
std::cerr << "Graphviz step" << std::endl;
|
||||
std::cerr << "Printing the computation graph..." << std::endl;
|
||||
std::cout << g.graphviz() << std::endl;
|
||||
|
||||
std::cerr << "Forward step" << std::endl;
|
||||
std::cerr << "Running the forward step..." << std::endl;
|
||||
g.forward(batch_size);
|
||||
std::cerr << "Backward step" << std::endl;
|
||||
std::cerr << "Running the backward step..." << std::endl;
|
||||
g.backward();
|
||||
std::cerr << "Done" << std::endl;
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
std::cerr << g["cost"].val().Debug() << std::endl;
|
||||
|
||||
std::cerr << g["X0"].val().Debug() << std::endl;
|
||||
std::cerr << g["Y0"].val().Debug() << std::endl;
|
||||
|
||||
#if 1
|
||||
std::cerr << g["Whh"].grad().Debug() << std::endl;
|
||||
std::cerr << g["bh"].grad().Debug() << std::endl;
|
||||
std::cerr << g["Why"].grad().Debug() << std::endl;
|
||||
std::cerr << g["by"].grad().Debug() << std::endl;
|
||||
std::cerr << g["Wxh"].grad().Debug() << std::endl;
|
||||
std::cerr << g["h0"].grad().Debug() << std::endl;
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user