Some cleaning.

This commit is contained in:
Andre Martins 2016-09-17 02:00:03 +01:00
parent 22bbac3287
commit eb57df2a3e

View File

@ -1,4 +1,3 @@
#include "marian.h"
#include "mnist.h"
#include "vocab.h"
@ -96,7 +95,6 @@ ExpressionGraph build_graph(int source_vocabulary_size,
}
int main(int argc, char** argv) {
#if 1
std::cerr << "Loading the data... ";
Vocab source_vocab, target_vocab;
@ -193,66 +191,6 @@ int main(int argc, char** argv) {
g[ss.str()] = Yt;
}
#else
int source_vocabulary_size = 10;
int target_vocabulary_size = 15;
int embedding_size = 8;
int hidden_size = 5;
int batch_size = 25;
int num_source_tokens = 8;
int num_target_tokens = 6;
// Build the encoder-decoder computation graph.
ExpressionGraph g = build_graph(0, // cuda device.
source_vocabulary_size,
target_vocabulary_size,
embedding_size,
hidden_size,
num_source_tokens,
num_target_tokens);
int input_size = source_vocabulary_size;
int output_size = target_vocabulary_size;
int num_inputs = num_source_tokens;
int num_outputs = num_target_tokens;
// Generate input data (include the stop symbol).
for (int t = 0; t <= num_inputs; ++t) {
Tensor Xt({batch_size, input_size});
float max = 1.;
std::vector<float> values(batch_size * input_size);
std::vector<float> classes(batch_size * output_size, 0.0);
int k = 0;
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < input_size; ++j, ++k) {
values[k] = max * (2.0*static_cast<float>(rand()) / RAND_MAX - 1.0);
}
}
thrust::copy(values.begin(), values.end(), Xt.begin());
std::stringstream ss;
ss << "X" << t;
g[ss.str()] = Xt;
}
// Generate output data (include the stop symbol).
for (int t = 0; t <= num_outputs; ++t) {
Tensor Yt({batch_size, output_size});
std::vector<float> classes(batch_size * output_size, 0.0);
int l = 0;
for (int i = 0; i < batch_size; ++i) {
int gold = output_size * static_cast<float>(rand()) / RAND_MAX;
classes[l + gold] = 1.0;
l += output_size;
}
thrust::copy(classes.begin(), classes.end(), Yt.begin());
std::stringstream ss;
ss << "Y" << t;
g[ss.str()] = Yt;
}
#endif
std::cerr << "Printing the computation graph..." << std::endl;
std::cout << g.graphviz() << std::endl;