diff --git a/src/validate_encoder_decoder.cu b/src/validate_encoder_decoder.cu
index ec9951ec..15159ba6 100644
--- a/src/validate_encoder_decoder.cu
+++ b/src/validate_encoder_decoder.cu
@@ -1,7 +1,7 @@
 #include "marian.h"
 #include "mnist.h"
 #include "vocab.h"
-#include
+#include "optimizers.h"
 
 using namespace marian;
 using namespace keywords;
@@ -194,24 +194,15 @@ int main(int argc, char** argv) {
   std::cerr << "Printing the computation graph..." << std::endl;
   std::cout << g.graphviz() << std::endl;
 
-  std::cerr << "Running the forward step..." << std::endl;
-  g.forward(batch_size);
-  std::cerr << "Running the backward step..." << std::endl;
-  g.backward();
-  std::cerr << "Done." << std::endl;
-
-  std::cerr << g["cost"].val().Debug() << std::endl;
-
-#if 0
-  std::cerr << g["X0"].val().Debug() << std::endl;
-  std::cerr << g["Y0"].val().Debug() << std::endl;
-  std::cerr << g["Whh"].grad().Debug() << std::endl;
-  std::cerr << g["bh"].grad().Debug() << std::endl;
-  std::cerr << g["Why"].grad().Debug() << std::endl;
-  std::cerr << g["by"].grad().Debug() << std::endl;
-  std::cerr << g["Wxh"].grad().Debug() << std::endl;
-  std::cerr << g["h0"].grad().Debug() << std::endl;
-#endif
+  std::cerr << "Training..." << std::endl;
+  Adam opt;
+  const size_t num_epochs = 20; // size_t so the bound has the same type as the loop counter (no signed/unsigned comparison).
+  for(size_t epoch = 0; epoch < num_epochs; ++epoch) {
+    opt(g, batch_size); // Full batch for now.
+    std::cerr << "Epoch " << epoch << ": "
+              << "Loss = " << g["cost"].val()[0]
+              << std::endl;
+  }
 
   return 0;
 }