diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index f4fb76be..56023304 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -12,13 +12,13 @@ int main(int argc, char** argv) { const size_t IMAGE_SIZE = 784; const size_t LABEL_SIZE = 10; + int BATCH_SIZE = 10000; std::cerr << "Loading test set..."; - std::vector testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE); - std::vector testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE); + std::vector testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE); + std::vector testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE); std::cerr << "Done." << std::endl; - const size_t BATCH_SIZE = testLabels.size(); std::cerr << "Loading model params..."; NpzConverter converter("../scripts/test_model/model.npz"); @@ -43,19 +43,18 @@ int main(int argc, char** argv) { axis=1, name="pred"); auto graph = -mean(sum(y * log(predict), axis=1), axis=0, name="cost"); - std::cerr << "Done." << std::endl; - Tensor xt({numofdata, IMAGE_SIZE}); - Tensor yt({numofdata, LABEL_SIZE}); + Tensor xt({BATCH_SIZE, IMAGE_SIZE}); + Tensor yt({BATCH_SIZE, LABEL_SIZE}); x = xt << testImages; y = yt << testLabels; graph.forward(BATCH_SIZE); - auto results = predict.val(); graph.backward(); + auto results = predict.val(); std::vector resultsv(results.size()); resultsv << results;