mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merge branch 'master' of github.com:emjotde/marian
This commit is contained in:
commit
0f72f927b7
@ -30,6 +30,15 @@
|
||||
using namespace marian;
|
||||
using namespace keywords;
|
||||
|
||||
/**
 * @brief Fill @p indices with a randomly shuffled permutation of 0..n-1.
 *
 * Any previous contents of @p indices are discarded. A non-positive @p n
 * leaves the vector empty.
 *
 * @param n        Number of indices to generate.
 * @param indices  Output vector that receives the permutation; must not be null.
 */
void random_permutation(int n, std::vector<size_t> *indices) {
  // Minimal adaptor exposing std::rand() through the uniform random bit
  // generator interface expected by std::shuffle. std::random_shuffle was
  // deprecated in C++14 and removed in C++17, so it is avoided here.
  struct RandEngine {
    typedef unsigned int result_type;
    static constexpr result_type min() { return 0; }
    static constexpr result_type max() { return RAND_MAX; }
    result_type operator()() const {
      return static_cast<result_type>(std::rand());
    }
  };

  // Seed the C generator only once per process. Re-seeding from the clock on
  // every call would reproduce the same permutation for calls that fall
  // within the same second.
  static const bool seeded =
      (std::srand(static_cast<unsigned int>(std::time(0))), true);
  (void)seeded;

  indices->clear();

  // Guard before converting to an unsigned count: a negative n would
  // otherwise become a huge loop bound.
  if(n <= 0) {
    return;
  }

  const size_t count = static_cast<size_t>(n);
  indices->reserve(count);
  for(size_t i = 0; i < count; ++i) {
    indices->push_back(i);
  }

  std::shuffle(indices->begin(), indices->end(), RandEngine());
}
|
||||
|
||||
ExpressionGraph build_graph(int source_vocabulary_size,
|
||||
int target_vocabulary_size,
|
||||
int embedding_size,
|
||||
@ -104,14 +113,17 @@ ExpressionGraph build_graph(int source_vocabulary_size,
|
||||
std::cerr << "Building output layer..." << std::endl;
|
||||
|
||||
// Softmax layer and cost function.
|
||||
std::vector<Expr> Yp;
|
||||
Yp.emplace_back(named(softmax(dot(h0_d, Why) + by), "pred"));
|
||||
Expr cross_entropy = sum(Y[0] * log(Yp[0]), axis=1);
|
||||
//std::vector<Expr> Yp;
|
||||
//Yp.emplace_back(named(softmax(dot(h0_d, Why) + by), "pred"));
|
||||
//Expr word_cost = sum(Y[0] * log(Yp[0]), axis=1);
|
||||
Expr word_cost = cross_entropy(dot(h0_d, Why) + by, Y[0]);
|
||||
for (int t = 1; t <= num_outputs; ++t) {
|
||||
Yp.emplace_back(named(softmax(dot(S[t-1], Why) + by), "pred"));
|
||||
cross_entropy = cross_entropy + sum(Y[t] * log(Yp[t]), axis=1);
|
||||
//Yp.emplace_back(named(softmax(dot(S[t-1], Why) + by), "pred"));
|
||||
//word_cost = word_cost + sum(Y[t] * log(Yp[t]), axis=1);
|
||||
word_cost = word_cost + cross_entropy(dot(S[t-1], Why) + by, Y[t]);
|
||||
}
|
||||
auto cost = named(-mean(cross_entropy, axis=0), "cost");
|
||||
//auto cost = named(-mean(word_cost, axis=0), "cost");
|
||||
auto cost = named(mean(word_cost, axis=0), "cost");
|
||||
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
@ -128,7 +140,7 @@ int main(int argc, char** argv) {
|
||||
|
||||
// Right now we're only reading the first few sentence pairs, and defining
|
||||
// that as the step size.
|
||||
int batch_size = 20;
|
||||
int batch_size = 100;
|
||||
int num_source_tokens = -1;
|
||||
int num_target_tokens = -1;
|
||||
std::vector<std::vector<size_t> > source_sentences, target_sentences;
|
||||
@ -145,7 +157,7 @@ int main(int argc, char** argv) {
|
||||
if (num_target_tokens < 0 || target_ids.size() > num_target_tokens) {
|
||||
num_target_tokens = target_ids.size();
|
||||
}
|
||||
if (source_sentences.size() == 1000) break;
|
||||
//if (source_sentences.size() == 1000) break;
|
||||
}
|
||||
std::cerr << "Done." << std::endl;
|
||||
std::cerr << source_sentences.size()
|
||||
@ -226,10 +238,14 @@ int main(int argc, char** argv) {
|
||||
int num_epochs = 20;
|
||||
for(int epoch = 1; epoch <= num_epochs; ++epoch) {
|
||||
boost::timer::cpu_timer timer;
|
||||
// TODO: shuffle the data.
|
||||
// TODO: shuffle the batches.
|
||||
// shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE);
|
||||
std::vector<size_t> indices;
|
||||
int num_batches = num_training_examples / batch_size;
|
||||
random_permutation(num_batches, &indices);
|
||||
float cost = 0;
|
||||
for(int j = 0; j < num_training_examples / batch_size; j++) {
|
||||
for(int j = 0; j < num_batches; j++) {
|
||||
int b = indices[j]; // Batch index.
|
||||
// Attaching the data to the computation graph...
|
||||
// Convert the data to dense one-hot vectors.
|
||||
// TODO: make the graph handle sparse indices with a proper lookup layer.
|
||||
@ -239,7 +255,7 @@ int main(int argc, char** argv) {
|
||||
std::vector<float> values(batch_size * source_vocab.Size(), 0.0);
|
||||
int k = 0;
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
values[k + source_sentences[i + j*batch_size][t]] = 1.0;
|
||||
values[k + source_sentences[i + b*batch_size][t]] = 1.0;
|
||||
k += source_vocab.Size();
|
||||
}
|
||||
thrust::copy(values.begin(), values.end(), Xt.begin());
|
||||
@ -254,7 +270,7 @@ int main(int argc, char** argv) {
|
||||
std::vector<float> values(batch_size * target_vocab.Size(), 0.0);
|
||||
int k = 0;
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
values[k + target_sentences[i + j*batch_size][t]] = 1.0;
|
||||
values[k + target_sentences[i + b*batch_size][t]] = 1.0;
|
||||
k += target_vocab.Size();
|
||||
}
|
||||
thrust::copy(values.begin(), values.end(), Yt.begin());
|
||||
@ -273,16 +289,5 @@ int main(int argc, char** argv) {
|
||||
}
|
||||
std::cerr << "Total: " << total.format(3, "%ws") << std::endl;
|
||||
|
||||
#if 0
|
||||
Adam opt;
|
||||
int num_epochs = 20;
|
||||
for(size_t epoch = 0; epoch < num_epochs; ++epoch) {
|
||||
opt(g, batch_size); // Full batch for now.
|
||||
std::cerr << "Epoch " << epoch << ": "
|
||||
<< "Loss = " << g["cost"].val()[0]
|
||||
<< std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user