mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
Initialization of RNN parameters.
This commit is contained in:
parent
9015577101
commit
4c060397c0
16
src/test.cu
16
src/test.cu
@@ -23,10 +23,10 @@ int main(int argc, char** argv) {
|
||||
Y[t] = new Expr(input(shape={batch_size, output_size}));
|
||||
}
|
||||
|
||||
Expr Wxh = param(shape={input_size, hidden_size}, name="Wxh");
|
||||
Expr Whh = param(shape={hidden_size, hidden_size}, name="Whh");
|
||||
Expr bh = param(shape={1, hidden_size}, name="bh");
|
||||
Expr h0 = param(shape={1, hidden_size}, name="h0");
|
||||
Expr Wxh = param(shape={input_size, hidden_size}, init=uniform(), name="Wxh");
|
||||
Expr Whh = param(shape={hidden_size, hidden_size}, init=uniform(), name="Whh");
|
||||
Expr bh = param(shape={1, hidden_size}, init=uniform(), name="bh");
|
||||
Expr h0 = param(shape={1, hidden_size}, init=uniform(), name="h0");
|
||||
|
||||
std::cerr << "Building RNN..." << std::endl;
|
||||
H[0] = new Expr(tanh(dot(*X[0], Wxh) + dot(h0, Whh) + bh));
|
||||
@@ -34,8 +34,8 @@ int main(int argc, char** argv) {
|
||||
H[t] = new Expr(tanh(dot(*X[t], Wxh) + dot(*H[t-1], Whh) + bh));
|
||||
}
|
||||
|
||||
Expr Why = param(shape={hidden_size, output_size}, name="Why");
|
||||
Expr by = param(shape={1, output_size}, name="by");
|
||||
Expr Why = param(shape={hidden_size, output_size}, init=uniform(), name="Why");
|
||||
Expr by = param(shape={1, output_size}, init=uniform(), name="by");
|
||||
|
||||
std::cerr << "Building output layer..." << std::endl;
|
||||
std::vector<Expr*> Yp(num_inputs);
|
||||
@@ -80,6 +80,10 @@ int main(int argc, char** argv) {
|
||||
graph.backward();
|
||||
|
||||
std::cerr << graph.val().Debug() << std::endl;
|
||||
|
||||
std::cerr << X[0]->val().Debug() << std::endl;
|
||||
std::cerr << Y[0]->val().Debug() << std::endl;
|
||||
|
||||
std::cerr << Whh.grad().Debug() << std::endl;
|
||||
std::cerr << bh.grad().Debug() << std::endl;
|
||||
std::cerr << Why.grad().Debug() << std::endl;
|
||||
|
Loading…
Reference in New Issue
Block a user