diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 0694509c..401d4d1f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,7 +5,6 @@ cuda_add_library(marian_lib cnpy/cnpy.cpp exception.cpp expression_graph.cu - sgd.cu tensor.cu tensor_operators.cu expression_operators.cu diff --git a/src/expression_graph.cu b/src/expression_graph.cu index 22de1c89..52a68893 100644 --- a/src/expression_graph.cu +++ b/src/expression_graph.cu @@ -39,12 +39,12 @@ std::string Expr::Debug() const } /////////////////////////////////////////////////////// -ExpressionGraph::ExpressionGraph(int cudaDevice) -: stack_(new ChainableStack) -{ - std::srand (time(NULL)); - cudaSetDevice(0); - -} +//ExpressionGraph::ExpressionGraph(int cudaDevice) +//: stack_(new ChainableStack) +//{ +// std::srand (time(NULL)); +// cudaSetDevice(0); +// +//} } diff --git a/src/expression_graph.h b/src/expression_graph.h index a092fcb1..df99e652 100644 --- a/src/expression_graph.h +++ b/src/expression_graph.h @@ -38,9 +38,14 @@ class Expr { class ExpressionGraph { public: - ExpressionGraph(int cudaDevice); + ExpressionGraph() : stack_(new ChainableStack) {} - void forward(size_t batchSize) { + void backprop(int batchSize) { + forward(batchSize); + backward(); + } + + void forward(int batchSize) { for(auto&& v : *stack_) { v->allocate(batchSize); } @@ -48,6 +53,16 @@ class ExpressionGraph { v->forward(); } + void backward() { + for(auto&& v : *stack_) + v->set_zero_adjoint(); + + typedef typename ChainableStack::reverse_iterator It; + stack_->back()->init_dependent(); + for(It it = stack_->rbegin(); it != stack_->rend(); ++it) + (*it)->backward(); + } + std::string graphviz() { std::stringstream ss; ss << "digraph ExpressionGraph {" << std::endl; @@ -60,19 +75,13 @@ class ExpressionGraph { return ss.str(); } - void backward() { - for(auto&& v : *stack_) - v->set_zero_adjoint(); - - typedef typename ChainableStack::reverse_iterator It; - stack_->back()->init_dependent(); - for(It it = stack_->rbegin(); it != stack_->rend(); ++it) - (*it)->backward(); - } + /*********************************************************/ template inline Expr input(Args ...args) { - return Expr(this, new InputNode(args...)); + Expr e(this, new InputNode(args...)); + inputs_.emplace_back(e); + return e; } template @@ -117,14 +126,20 @@ class ExpressionGraph { named_.emplace(name, e); } + std::vector& inputs() { + return inputs_; + } + std::vector& params() { return params_; } private: ChainableStackPtr stack_; + std::map named_; std::vector params_; + std::vector inputs_; }; } diff --git a/src/sgd.cu b/src/sgd.cu deleted file mode 100644 index 598d9f6b..00000000 --- a/src/sgd.cu +++ /dev/null @@ -1,140 +0,0 @@ -#include -#include -#include -#include "sgd.h" -#include "thrust_functions.h" - -using namespace std; - -namespace marian { -SGD::SGD(ExpressionGraph& g, float eta, - std::vector& xData, size_t numFeatures, - std::vector& yData, size_t numClasses, - size_t epochs, size_t batchSize) -: graph_(g), - eta_(eta), - xData_(xData), - numFeatures_(numFeatures), - yData_(yData), - numClasses_(numClasses), - epochs_(epochs), - maxBatchSize_(batchSize) -{} - -void SGD::Run() -{ - size_t numExamples = xData_.size()/ numFeatures_; - Tensor xt({(int)maxBatchSize_, (int)numExamples}, 0.0f); - Tensor yt({(int)maxBatchSize_, (int)numClasses_}, 0.0f); - - vector shuffle = CreateShuffle(numExamples); - //vector shuffle; - - for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) { - std::cerr << "Starting epoch #" << numEpoch << std::endl; - size_t startId = 0; - - while (startId < numExamples) { - size_t batchSize = std::min(maxBatchSize_, numExamples - startId); - size_t endId = startId + batchSize; - - PrepareBatch(startId, endId, batchSize, shuffle, xt, yt); - graph_["x"] = xt; - graph_["y"] = yt; - - graph_.forward(maxBatchSize_); - graph_.backward(); - - UpdateModel(); - - startId += maxBatchSize_; - } - } -} - -std::vector SGD::CreateShuffle(size_t numExamples) const { - vector ret(numExamples); - std::iota(ret.begin(), ret.end(), 0); - std::random_shuffle ( ret.begin(), ret.end() ); - /* - cerr << "shuffled" << endl; - for (size_t i = 0; i < ret.size(); ++i) { - cerr << ret[i] << " "; - } - */ - return ret; -} - -void SGD::PrepareBatch( - size_t startId, - size_t endId, - size_t batchSize, - const std::vector &shuffle, - Tensor& xt, - Tensor& yt) { - /* - std::vector x(xData_.begin() + startId * numFeatures_, - xData_.begin() + endId * numFeatures_); - std::vector y(yData_.begin() + startId * numClasses_, - yData_.begin() + endId * numClasses_); - */ - std::vector x(batchSize * numFeatures_); - std::vector y(batchSize * numClasses_); - - //cerr << "batchSize=" << batchSize << endl; - /* - cerr << "startId=" << startId - << " " << endId - << " " << batchSize - << endl; - cerr << "numExamples=" << shuffle.size() << endl; - cerr << "numFeatures_=" << numFeatures_ << " " << numClasses_ << endl; - cerr << "sizes=" << x.size() - << " " << y.size() - << " " << xData_.size() - << " " << yData_.size() - << endl; - */ - size_t startXId = 0; - size_t startYId = 0; - - for (size_t i = startId; i < endId; ++i) { - size_t ind = shuffle[i]; - size_t startXDataId = ind * numFeatures_; - size_t startYDataId = ind * numClasses_; - - size_t endXDataId = startXDataId + numFeatures_; - size_t endYDataId = startYDataId + numClasses_; - /* - cerr << "i=" << i - << " " << ind - << " " << startXDataId << "-" << endXDataId - << " " << startYDataId << "-" << endYDataId - << endl; - */ - - std::copy(xData_.begin() + startXDataId, - xData_.begin() + endXDataId, - x.begin() + startXId); - - std::copy(yData_.begin() + startYDataId, - yData_.begin() + endYDataId, - y.begin() + startYId); - - startXId += numFeatures_; - startYId += numClasses_; - } - - xt.set(x); - yt.set(y); -} - -void SGD::UpdateModel() { - for (auto& param : graph_.params()) { - using namespace thrust::placeholders; - Element(_1 -= eta_ * _2, param.val(), param.grad()); - } -} - -} // namespace - diff --git a/src/sgd.h b/src/sgd.h index a99acd75..7d3f3200 100644 --- a/src/sgd.h +++ b/src/sgd.h @@ -1,43 +1,48 @@ #pragma once -#include -#include - -#include "expression_graph.h" -#include "thrust_functions.h" +#include +#include #include "tensor_operators.h" namespace marian { -class SGD { +class Sgd { public: - SGD(ExpressionGraph& g, float eta, - std::vector& xData, size_t numFeatures, - std::vector& yData, size_t numClasses, - size_t epochs, size_t batchSize); - - void Run(); - + Sgd(float eta=0.1) : eta_(eta) {} + + void operator()(ExpressionGraph& graph, int batchSize) { + graph.backprop(batchSize); + + for(auto& param : graph.params()) + Element(_1 -= eta_ * _2, param.val(), param.grad()); + } + private: - ExpressionGraph& graph_; - const float eta_; - std::vector& xData_; - const size_t numFeatures_; - std::vector& yData_; - const size_t numClasses_; - const size_t epochs_; - const size_t maxBatchSize_; - - std::vector CreateShuffle(size_t numExamples) const; - void PrepareBatch( - size_t startId, - size_t endId, - size_t batchSize, - const std::vector &shuffle, - Tensor& xt, - Tensor& yt); - - void UpdateModel(); + float eta_; }; -} // namespace marian +class Adagrad { + public: + Adagrad(float eta=0.1) : eta_(eta) {} + + void operator()(ExpressionGraph& graph, int batchSize) { + if(history_.size() < graph.params().size()) + for(auto& param : graph.params()) + history_.emplace_back(Tensor(param.grad().shape(), 0)); + + graph.backprop(batchSize); + + auto it = history_.begin(); + for(auto& param : graph.params()) { + Element(_1 -= eta_ / Sqrt(_2) * _3, param.val(), *it, param.grad()); + Element(_1 += _2 * _2, *it, param.grad()); + it++; + } + } + + private: + float eta_; + std::vector history_; +}; + +} \ No newline at end of file diff --git a/src/tensor.cu b/src/tensor.cu index 0c3e8a3e..3ec4e71e 100644 --- a/src/tensor.cu +++ b/src/tensor.cu @@ -1,8 +1,6 @@ #include #include "tensor.h" -using namespace std; - namespace marian { void Tensor::set(const std::vector& data) diff --git a/src/test.cu b/src/test.cu index 8d7073f4..e352e7f4 100644 --- a/src/test.cu +++ b/src/test.cu @@ -21,7 +21,7 @@ int main(int argc, char** argv) { std::vector Y; std::vector H; - ExpressionGraph g(0); + ExpressionGraph g; for (int t = 0; t < num_inputs; ++t) { X.emplace_back(g.input(shape={batch_size, input_size})); @@ -39,10 +39,9 @@ int main(int argc, char** argv) { string sourceLine, targetLine; while (getline(sourceFile, sourceLine)) { - getline(targetFile, targetLine); - - std::vector sourceIds = sourceVocab.ProcessSentence(sourceLine); - std::vector targetIds = sourceVocab.ProcessSentence(targetLine); + getline(targetFile, targetLine); + std::vector sourceIds = sourceVocab.ProcessSentence(sourceLine); + std::vector targetIds = sourceVocab.ProcessSentence(targetLine); } std::cerr << "Building RNN..." << std::endl; diff --git a/src/thrust_functions.h b/src/thrust_functions.h index 2712fda7..1ab99473 100644 --- a/src/thrust_functions.h +++ b/src/thrust_functions.h @@ -85,6 +85,19 @@ namespace thrust return compose(unary_operator(), _1); } + template + struct unary_sqrt : public thrust::unary_function { + __host__ __device__ + T operator()(const T &x) const { return sqrtf(x); } + }; + + template + __host__ __device__ + actor, actor>> + Sqrt(const actor &_1) { + return compose(unary_operator(), _1); + } + template __host__ __device__ actor, actor, actor>> diff --git a/src/train_mnist.cu b/src/train_mnist.cu index 09e08d15..2dda8fde 100644 --- a/src/train_mnist.cu +++ b/src/train_mnist.cu @@ -11,12 +11,12 @@ int main(int argc, char** argv) { int numofdata; vector trainImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE); - vectortrainLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE); + vector trainLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE); using namespace marian; using namespace keywords; - ExpressionGraph g(0); + ExpressionGraph g; Expr x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x"); Expr y = named(g.input(shape={whatevs, LABEL_SIZE}), "y"); @@ -24,16 +24,13 @@ int main(int argc, char** argv) { Expr w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE}), "w"); Expr b = named(g.param(shape={1, LABEL_SIZE}), "b"); - std::vector params; - params.push_back(&w); - params.push_back(&b); - auto scores = dot(x, w) + b; auto lr = softmax_fast(scores); auto cost = named(-mean(sum(y * log(lr), axis=1), axis=0), "cost"); cerr << "lr=" << lr.Debug() << endl; - SGD opt(g, 0.9, trainImages, IMAGE_SIZE, trainLabels, LABEL_SIZE, 3, 24); - opt.Run(); + Adagrad opt; + opt(g, 300); + return 0; } diff --git a/src/validate_encoder_decoder.cu b/src/validate_encoder_decoder.cu index b053c0ea..3b516107 100644 --- a/src/validate_encoder_decoder.cu +++ b/src/validate_encoder_decoder.cu @@ -32,7 +32,7 @@ int main(int argc, char** argv) { int num_inputs = 8; int num_outputs = 6; - ExpressionGraph g(0); + ExpressionGraph g; std::vector X(num_inputs+1); // For the stop symbol. std::vector Y(num_outputs); std::vector H(num_inputs+1); // For the stop symbol. diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index 01fb4c50..f9bc0dcf 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -10,7 +10,7 @@ const size_t IMAGE_SIZE = 784; const size_t LABEL_SIZE = 10; int BATCH_SIZE = 10000; -ExpressionGraph build_graph(int cudaDevice) { +ExpressionGraph build_graph() { std::cerr << "Loading model params..."; NpzConverter converter("../scripts/test_model_single/model.npz"); @@ -22,7 +22,7 @@ ExpressionGraph build_graph(int cudaDevice) { std::cerr << "Building model..."; - ExpressionGraph g(cudaDevice); + ExpressionGraph g; auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x"); auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y"); @@ -52,7 +52,7 @@ int main(int argc, char** argv) { std::vector testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE); std::cerr << "Done." << std::endl; - ExpressionGraph g = build_graph(0); + ExpressionGraph g = build_graph(); Tensor xt({BATCH_SIZE, IMAGE_SIZE}); Tensor yt({BATCH_SIZE, LABEL_SIZE}); diff --git a/src/validate_mnist_batch.cu b/src/validate_mnist_batch.cu index 754d254c..d37e9ca3 100644 --- a/src/validate_mnist_batch.cu +++ b/src/validate_mnist_batch.cu @@ -56,8 +56,7 @@ int main(int argc, char** argv) { std::cerr << "\tDone." << std::endl; - - ExpressionGraph g(0); + ExpressionGraph g; auto x = g.input(shape={whatevs, IMAGE_SIZE}, name="X"); auto y = g.input(shape={whatevs, LABEL_SIZE}, name="Y");