mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
Training the encoder-decoder (one batch for now).
This commit is contained in:
parent
f6de1677e1
commit
04fb8734a2
@ -1,7 +1,7 @@
|
||||
#include "marian.h"
|
||||
#include "mnist.h"
|
||||
#include "vocab.h"
|
||||
#include <assert.h>
|
||||
#include "optimizers.h"
|
||||
|
||||
using namespace marian;
|
||||
using namespace keywords;
|
||||
@ -194,24 +194,15 @@ int main(int argc, char** argv) {
|
||||
std::cerr << "Printing the computation graph..." << std::endl;
|
||||
std::cout << g.graphviz() << std::endl;
|
||||
|
||||
std::cerr << "Running the forward step..." << std::endl;
|
||||
g.forward(batch_size);
|
||||
std::cerr << "Running the backward step..." << std::endl;
|
||||
g.backward();
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
std::cerr << g["cost"].val().Debug() << std::endl;
|
||||
|
||||
#if 0
|
||||
std::cerr << g["X0"].val().Debug() << std::endl;
|
||||
std::cerr << g["Y0"].val().Debug() << std::endl;
|
||||
std::cerr << g["Whh"].grad().Debug() << std::endl;
|
||||
std::cerr << g["bh"].grad().Debug() << std::endl;
|
||||
std::cerr << g["Why"].grad().Debug() << std::endl;
|
||||
std::cerr << g["by"].grad().Debug() << std::endl;
|
||||
std::cerr << g["Wxh"].grad().Debug() << std::endl;
|
||||
std::cerr << g["h0"].grad().Debug() << std::endl;
|
||||
#endif
|
||||
std::cerr << "Training..." << std::endl;
|
||||
Adam opt;
|
||||
int num_epochs = 20;
|
||||
for(size_t epoch = 0; epoch < num_epochs; ++epoch) {
|
||||
opt(g, batch_size); // Full batch for now.
|
||||
std::cerr << "Epoch " << epoch << ": "
|
||||
<< "Loss = " << g["cost"].val()[0]
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user