From 7c63606b758158b3985d86d4f52ae608d1690da0 Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Fri, 16 Sep 2016 14:16:09 +0200 Subject: [PATCH] move cuda device into graph --- src/validate_mnist.cu | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index 449fc193..01fb4c50 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -10,7 +10,7 @@ const size_t IMAGE_SIZE = 784; const size_t LABEL_SIZE = 10; int BATCH_SIZE = 10000; -ExpressionGraph build_graph() { +ExpressionGraph build_graph(int cudaDevice) { std::cerr << "Loading model params..."; NpzConverter converter("../scripts/test_model_single/model.npz"); @@ -22,7 +22,7 @@ ExpressionGraph build_graph() { std::cerr << "Building model..."; - ExpressionGraph g(0); + ExpressionGraph g(cudaDevice); auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x"); auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y"); @@ -46,15 +46,13 @@ ExpressionGraph build_graph() { } int main(int argc, char** argv) { - - cudaSetDevice(0); - + std::cerr << "Loading test set..."; std::vector testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE); std::vector testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE); std::cerr << "Done." << std::endl; - ExpressionGraph g = build_graph(); + ExpressionGraph g = build_graph(0); Tensor xt({BATCH_SIZE, IMAGE_SIZE}); Tensor yt({BATCH_SIZE, LABEL_SIZE});