From 4bafa6c36010e33a3510f99820eca674db901551 Mon Sep 17 00:00:00 2001 From: Marcin Junczys-Dowmunt Date: Thu, 15 Sep 2016 00:55:24 +0200 Subject: [PATCH] toy training with initialized data --- src/graph_operators.h | 7 +------ src/validate_mnist.cu | 18 +++++++++++------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/graph_operators.h b/src/graph_operators.h index 2e016cac..32447e03 100644 --- a/src/graph_operators.h +++ b/src/graph_operators.h @@ -142,13 +142,8 @@ struct ArgmaxOp : public UnaryNodeOp { struct SoftmaxNodeOp : public UnaryNodeOp { template SoftmaxNodeOp(ChainPtr a, Args ...args) - : UnaryNodeOp(a, keywords::shape=newShape(a), + : UnaryNodeOp(a, keywords::shape=a->shape(), args...) { } - - Shape newShape(ChainPtr a) { - Shape shape = a->shape(); - return shape; - } void forward() { // B = softmax(A). diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index c31dd85a..a96b61a8 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -8,7 +8,7 @@ using namespace keywords; int main(int argc, char** argv) { - cudaSetDevice(1); + cudaSetDevice(0); const size_t IMAGE_SIZE = 784; const size_t LABEL_SIZE = 10; @@ -55,14 +55,18 @@ int main(int argc, char** argv) { y = yt << testLabels; graph.forward(BATCH_SIZE); - for(size_t i = 0; i < 1000; ++i) { - graph.backward(); - auto update_rule = _1 -= 0.1 * _2; - Element(update_rule, w.val(), w.grad()); - Element(update_rule, b.val(), b.grad()); + for (size_t j = 0; j < 10; ++j) { + for(size_t i = 0; i < 60; ++i) { + graph.backward(); - graph.forward(BATCH_SIZE); + auto update_rule = _1 -= 0.1 * _2; + Element(update_rule, w.val(), w.grad()); + Element(update_rule, b.val(), b.grad()); + + graph.forward(BATCH_SIZE); + } + std::cerr << "Epoch: " << j << std::endl; } auto results = predict.val();