Training the encoder-decoder (one batch for now).

This commit is contained in:
Andre Martins 2016-09-17 09:35:23 +01:00
parent f6de1677e1
commit 04fb8734a2

View File

@ -1,7 +1,7 @@
#include "marian.h" #include "marian.h"
#include "mnist.h" #include "mnist.h"
#include "vocab.h" #include "vocab.h"
#include <assert.h> #include "optimizers.h"
using namespace marian; using namespace marian;
using namespace keywords; using namespace keywords;
@ -194,24 +194,15 @@ int main(int argc, char** argv) {
std::cerr << "Printing the computation graph..." << std::endl; std::cerr << "Printing the computation graph..." << std::endl;
std::cout << g.graphviz() << std::endl; std::cout << g.graphviz() << std::endl;
std::cerr << "Running the forward step..." << std::endl; std::cerr << "Training..." << std::endl;
g.forward(batch_size); Adam opt;
std::cerr << "Running the backward step..." << std::endl; int num_epochs = 20;
g.backward(); for(size_t epoch = 0; epoch < num_epochs; ++epoch) {
std::cerr << "Done." << std::endl; opt(g, batch_size); // Full batch for now.
std::cerr << "Epoch " << epoch << ": "
std::cerr << g["cost"].val().Debug() << std::endl; << "Loss = " << g["cost"].val()[0]
<< std::endl;
#if 0 }
std::cerr << g["X0"].val().Debug() << std::endl;
std::cerr << g["Y0"].val().Debug() << std::endl;
std::cerr << g["Whh"].grad().Debug() << std::endl;
std::cerr << g["bh"].grad().Debug() << std::endl;
std::cerr << g["Why"].grad().Debug() << std::endl;
std::cerr << g["by"].grad().Debug() << std::endl;
std::cerr << g["Wxh"].grad().Debug() << std::endl;
std::cerr << g["h0"].grad().Debug() << std::endl;
#endif
return 0; return 0;
} }