mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-05 01:31:46 +03:00
Added mini-batch training to the encoder-decoder benchmark.
This commit is contained in:
parent
828a0db8bc
commit
26ffc8644e
@ -19,6 +19,9 @@
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
|
||||
#include <chrono>
|
||||
#include <boost/timer/timer.hpp>
|
||||
|
||||
#include "marian.h"
|
||||
#include "mnist.h"
|
||||
#include "vocab.h"
|
||||
@ -125,7 +128,7 @@ int main(int argc, char** argv) {
|
||||
|
||||
// Right now we're only reading the first few sentence pairs, and defining
|
||||
// that as the step size.
|
||||
int batch_size = 64;
|
||||
int batch_size = 20;
|
||||
int num_source_tokens = -1;
|
||||
int num_target_tokens = -1;
|
||||
std::vector<std::vector<size_t> > source_sentences, target_sentences;
|
||||
@ -142,7 +145,7 @@ int main(int argc, char** argv) {
|
||||
if (num_target_tokens < 0 || target_ids.size() > num_target_tokens) {
|
||||
num_target_tokens = target_ids.size();
|
||||
}
|
||||
if (source_sentences.size() == batch_size) break;
|
||||
if (source_sentences.size() == 1000) break;
|
||||
}
|
||||
std::cerr << "Done." << std::endl;
|
||||
std::cerr << source_sentences.size()
|
||||
@ -214,6 +217,63 @@ int main(int argc, char** argv) {
|
||||
std::cout << g.graphviz() << std::endl;
|
||||
|
||||
std::cerr << "Training..." << std::endl;
|
||||
|
||||
int num_training_examples = source_sentences.size();
|
||||
std::cerr << num_training_examples << " training examples." << std::endl;
|
||||
|
||||
boost::timer::cpu_timer total;
|
||||
Adam opt;
|
||||
int num_epochs = 20;
|
||||
for(int epoch = 1; epoch <= num_epochs; ++epoch) {
|
||||
boost::timer::cpu_timer timer;
|
||||
// TODO: shuffle the data.
|
||||
// shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE);
|
||||
float cost = 0;
|
||||
for(int j = 0; j < num_training_examples / batch_size; j++) {
|
||||
// Attaching the data to the computation graph...
|
||||
// Convert the data to dense one-hot vectors.
|
||||
// TODO: make the graph handle sparse indices with a proper lookup layer.
|
||||
// TODO: use different sentence lengths for the batches.
|
||||
for (int t = 0; t < num_source_tokens; ++t) {
|
||||
Tensor Xt({batch_size, static_cast<int>(source_vocab.Size())});
|
||||
std::vector<float> values(batch_size * source_vocab.Size(), 0.0);
|
||||
int k = 0;
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
values[k + source_sentences[i + j*batch_size][t]] = 1.0;
|
||||
k += source_vocab.Size();
|
||||
}
|
||||
thrust::copy(values.begin(), values.end(), Xt.begin());
|
||||
// Attach this slice to the graph.
|
||||
std::stringstream ss;
|
||||
ss << "X" << t;
|
||||
g[ss.str()] = Xt;
|
||||
}
|
||||
|
||||
for (int t = 0; t < num_target_tokens; ++t) {
|
||||
Tensor Yt({batch_size, static_cast<int>(target_vocab.Size())});
|
||||
std::vector<float> values(batch_size * target_vocab.Size(), 0.0);
|
||||
int k = 0;
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
values[k + target_sentences[i + j*batch_size][t]] = 1.0;
|
||||
k += target_vocab.Size();
|
||||
}
|
||||
thrust::copy(values.begin(), values.end(), Yt.begin());
|
||||
// Attach this slice to the graph.
|
||||
std::stringstream ss;
|
||||
ss << "Y" << t;
|
||||
g[ss.str()] = Yt;
|
||||
}
|
||||
|
||||
opt(g, batch_size);
|
||||
cost += g["cost"].val()[0];
|
||||
}
|
||||
std::cerr << "Epoch: " << epoch << " - Cost: "
|
||||
<< cost / num_training_examples * batch_size
|
||||
<< " - " << timer.format(3, "%ws") << std::endl;
|
||||
}
|
||||
std::cerr << "Total: " << total.format(3, "%ws") << std::endl;
|
||||
|
||||
#if 0
|
||||
Adam opt;
|
||||
int num_epochs = 20;
|
||||
for(size_t epoch = 0; epoch < num_epochs; ++epoch) {
|
||||
@ -222,6 +282,7 @@ int main(int argc, char** argv) {
|
||||
<< "Loss = " << g["cost"].val()[0]
|
||||
<< std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user