diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index 510b8dd4..f4fb76be 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -12,40 +12,31 @@ int main(int argc, char** argv) { const size_t IMAGE_SIZE = 784; const size_t LABEL_SIZE = 10; - int numofdata; - + std::cerr << "Loading test set..."; std::vector testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE); std::vector testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE); - std::cerr << "\tDone." << std::endl; + std::cerr << "Done." << std::endl; + const size_t BATCH_SIZE = testLabels.size(); + std::cerr << "Loading model params..."; NpzConverter converter("../scripts/test_model/model.npz"); - std::vector wData; - Shape wShape; + std::vector wData, bData; + Shape wShape, bShape; converter.Load("weights", wData, wShape); - - std::vector bData; - Shape bShape; converter.Load("bias", bData, bShape); - auto initW = [wData](Tensor t) { - t.set(wData); - }; - - auto initB = [bData](Tensor t) { - t.set(bData); - }; - - std::cerr << "\tDone." << std::endl; - + std::cerr << "Done." << std::endl; auto x = input(shape={whatevs, IMAGE_SIZE}, name="X"); auto y = input(shape={whatevs, LABEL_SIZE}, name="Y"); - auto w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=initW); - auto b = param(shape={1, LABEL_SIZE}, name="b0", init=initB); + auto w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", + init=[wData](Tensor t) { t.set(wData); }); + auto b = param(shape={1, LABEL_SIZE}, name="b0", + init=[bData](Tensor t) {t.set(bData); }); std::cerr << "Building model..."; auto predict = softmax(dot(x, w) + b, @@ -61,7 +52,7 @@ int main(int argc, char** argv) { x = xt << testImages; y = yt << testLabels; - graph.forward(numofdata); + graph.forward(BATCH_SIZE); auto results = predict.val(); graph.backward(); @@ -85,7 +76,7 @@ int main(int argc, char** argv) { //} //std::cerr << ")" << std::endl; } - std::cerr << "ACC: " << float(acc)/numofdata << std::endl; + std::cerr << "Accuracy: " << float(acc)/BATCH_SIZE << std::endl; return 0; }