mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
move cuda device into graph
This commit is contained in:
parent
517fb6c385
commit
7c63606b75
@ -10,7 +10,7 @@ const size_t IMAGE_SIZE = 784;
|
||||
const size_t LABEL_SIZE = 10;
|
||||
int BATCH_SIZE = 10000;
|
||||
|
||||
ExpressionGraph build_graph() {
|
||||
ExpressionGraph build_graph(int cudaDevice) {
|
||||
std::cerr << "Loading model params...";
|
||||
NpzConverter converter("../scripts/test_model_single/model.npz");
|
||||
|
||||
@ -22,7 +22,7 @@ ExpressionGraph build_graph() {
|
||||
|
||||
std::cerr << "Building model...";
|
||||
|
||||
ExpressionGraph g(0);
|
||||
ExpressionGraph g(cudaDevice);
|
||||
auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
|
||||
auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
|
||||
|
||||
@ -47,14 +47,12 @@ ExpressionGraph build_graph() {
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
||||
cudaSetDevice(0);
|
||||
|
||||
std::cerr << "Loading test set...";
|
||||
std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE);
|
||||
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
ExpressionGraph g = build_graph();
|
||||
ExpressionGraph g = build_graph(0);
|
||||
|
||||
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
|
||||
Tensor yt({BATCH_SIZE, LABEL_SIZE});
|
||||
|
Loading…
Reference in New Issue
Block a user