From a573eecf5c9fb0eb83b7516ee03294ca6c3c682e Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Wed, 14 Sep 2016 14:33:30 +0100 Subject: [PATCH] debug --- src/test.cu | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/test.cu b/src/test.cu index 9eb9b498..777b4b39 100644 --- a/src/test.cu +++ b/src/test.cu @@ -20,15 +20,19 @@ int main(int argc, char** argv) { Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0"); Expr b = param(shape={1, LABEL_SIZE}, name="b0"); - + Expr z = dot(x, w) + b; Expr lr = softmax(z, axis=1, name="pred"); Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost"); - //cerr << "lr=" << Debug(lr.val().shape()) << endl; + //cerr << "x=" << Debug(lr.val().shape()) << endl; int numofdata; - vector images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE); - vector labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE); + //vector images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE); + //vector labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE); + vector images = datasets::mnist::ReadImages("../examples/mnist/train-images-idx3-ubyte", numofdata, IMAGE_SIZE); + vector labels = datasets::mnist::ReadLabels("../examples/mnist/train-labels-idx1-ubyte", numofdata, LABEL_SIZE); + + cerr << "images=" << images.size() << " labels=" << labels.size() << endl; cerr << "numofdata=" << numofdata << endl; @@ -38,14 +42,20 @@ int main(int argc, char** argv) { tx.Load(images); ty.Load(labels); - cerr << "tx=" << Debug(tx.shape()) << endl; - cerr << "ty=" << Debug(ty.shape()) << endl; + //cerr << "tx=" << Debug(tx.shape()) << endl; + //cerr << "ty=" << Debug(ty.shape()) << endl; x = tx; y = ty; + cerr << "x=" << Debug(x.val().shape()) << endl; + cerr << "y=" << Debug(y.val().shape()) << endl; + + graph.forward(500); + cerr << "w=" << Debug(w.val().shape()) << endl; + cerr << "b=" << Debug(b.val().shape()) << endl; std::cerr << "z: " << Debug(z.val().shape()) << endl; std::cerr << "lr: " << Debug(lr.val().shape()) << endl; std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ;