diff --git a/src/expression_graph.cu b/src/expression_graph.cu index 61f8d2b5..22de1c89 100644 --- a/src/expression_graph.cu +++ b/src/expression_graph.cu @@ -37,5 +37,14 @@ std::string Expr::Debug() const strm << marian::Debug(shape); return strm.str(); } - + +/////////////////////////////////////////////////////// +ExpressionGraph::ExpressionGraph(int cudaDevice) +: stack_(new ChainableStack) +{ + std::srand (time(NULL)); + cudaSetDevice(0); + +} + } diff --git a/src/expression_graph.h b/src/expression_graph.h index f0d5f233..8b7d74f4 100644 --- a/src/expression_graph.h +++ b/src/expression_graph.h @@ -38,9 +38,7 @@ class Expr { class ExpressionGraph { public: - ExpressionGraph() - : stack_(new ChainableStack) - {} + ExpressionGraph(int cudaDevice); void forward(size_t batchSize) { for(auto&& v : *stack_) { diff --git a/src/sgd.cu b/src/sgd.cu index 5fe69138..598d9f6b 100644 --- a/src/sgd.cu +++ b/src/sgd.cu @@ -23,8 +23,6 @@ SGD::SGD(ExpressionGraph& g, float eta, void SGD::Run() { - std::srand ( unsigned ( std::time(0) ) ); - size_t numExamples = xData_.size()/ numFeatures_; Tensor xt({(int)maxBatchSize_, (int)numExamples}, 0.0f); Tensor yt({(int)maxBatchSize_, (int)numClasses_}, 0.0f); diff --git a/src/test.cu b/src/test.cu index 7da85c9d..ba798715 100644 --- a/src/test.cu +++ b/src/test.cu @@ -4,7 +4,6 @@ #include "vocab.h" int main(int argc, char** argv) { - cudaSetDevice(0); using namespace std; using namespace marian; @@ -22,7 +21,7 @@ int main(int argc, char** argv) { std::vector Y; std::vector H; - ExpressionGraph g; + ExpressionGraph g(0); for (int t = 0; t < num_inputs; ++t) { X.emplace_back(g.input(shape={batch_size, input_size})); diff --git a/src/train_mnist.cu b/src/train_mnist.cu index 64ccf564..09e08d15 100644 --- a/src/train_mnist.cu +++ b/src/train_mnist.cu @@ -16,7 +16,7 @@ int main(int argc, char** argv) { using namespace marian; using namespace keywords; - ExpressionGraph g; + ExpressionGraph g(0); Expr x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x"); Expr y = named(g.input(shape={whatevs, LABEL_SIZE}), "y"); diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index cbd3e0a3..36f0135a 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -10,7 +10,7 @@ const size_t IMAGE_SIZE = 784; const size_t LABEL_SIZE = 10; int BATCH_SIZE = 10000; -ExpressionGraph build_graph() { +ExpressionGraph build_graph(int cudaDevice) { std::cerr << "Loading model params..."; NpzConverter converter("../scripts/test_model_single/model.npz"); @@ -22,7 +22,7 @@ ExpressionGraph build_graph() { std::cerr << "Building model..."; - ExpressionGraph g; + ExpressionGraph g(cudaDevice); auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x"); auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y"); @@ -46,15 +46,12 @@ ExpressionGraph build_graph() { } int main(int argc, char** argv) { - - cudaSetDevice(1); - std::cerr << "Loading test set..."; std::vector testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE); std::vector testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE); std::cerr << "Done." << std::endl; - ExpressionGraph g = build_graph(); + ExpressionGraph g = build_graph(1); Tensor xt({BATCH_SIZE, IMAGE_SIZE}); Tensor yt({BATCH_SIZE, LABEL_SIZE}); diff --git a/src/validate_mnist_batch.cu b/src/validate_mnist_batch.cu index 50ab97b5..754d254c 100644 --- a/src/validate_mnist_batch.cu +++ b/src/validate_mnist_batch.cu @@ -7,9 +7,7 @@ using namespace marian; using namespace keywords; int main(int argc, char** argv) { - - cudaSetDevice(0); - + const size_t IMAGE_SIZE = 784; const size_t LABEL_SIZE = 10; const size_t BATCH_SIZE = 24; @@ -59,7 +57,7 @@ int main(int argc, char** argv) { std::cerr << "\tDone." << std::endl; - ExpressionGraph g; + ExpressionGraph g(0); auto x = g.input(shape={whatevs, IMAGE_SIZE}, name="X"); auto y = g.input(shape={whatevs, LABEL_SIZE}, name="Y"); diff --git a/src/vocab.cpp b/src/vocab.cpp index 705c21b2..ea3437c9 100644 --- a/src/vocab.cpp +++ b/src/vocab.cpp @@ -24,6 +24,21 @@ inline std::vector Tokenize(const std::string& str, return tokens; } //////////////////////////////////////////////////////// +size_t Vocab::GetUNK() const +{ + return std::numeric_limits::max(); +} + +size_t Vocab::GetPad() const +{ + return std::numeric_limits::max() - 1; +} + +size_t Vocab::GetEOS() const +{ + return std::numeric_limits::max() - 2; +} + size_t Vocab::GetOrCreate(const std::string &word) { diff --git a/src/vocab.h b/src/vocab.h index 5e055511..0cf42dac 100644 --- a/src/vocab.h +++ b/src/vocab.h @@ -10,6 +10,9 @@ public: size_t GetOrCreate(const std::string &word); std::vector ProcessSentence(const std::string &sentence); + size_t GetUNK() const; + size_t GetPad() const; + size_t GetEOS() const; protected: typedef std::unordered_map Coll; Coll coll_;