diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index 449fc193..01fb4c50 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -10,7 +10,7 @@ const size_t IMAGE_SIZE = 784; const size_t LABEL_SIZE = 10; int BATCH_SIZE = 10000; -ExpressionGraph build_graph() { +ExpressionGraph build_graph(int cudaDevice) { std::cerr << "Loading model params..."; NpzConverter converter("../scripts/test_model_single/model.npz"); @@ -22,7 +22,7 @@ ExpressionGraph build_graph() { std::cerr << "Building model..."; - ExpressionGraph g(0); + ExpressionGraph g(cudaDevice); auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x"); auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y"); @@ -46,15 +46,13 @@ ExpressionGraph build_graph() { } int main(int argc, char** argv) { - - cudaSetDevice(0); - + std::cerr << "Loading test set..."; std::vector testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE); std::vector testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE); std::cerr << "Done." << std::endl; - ExpressionGraph g = build_graph(); + ExpressionGraph g = build_graph(0); Tensor xt({BATCH_SIZE, IMAGE_SIZE}); Tensor yt({BATCH_SIZE, LABEL_SIZE});