diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cb121111..6dc37391 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,6 +5,7 @@ cuda_add_library(marian_lib cnpy/cnpy.cpp exception.cpp expressions.cu + sgd.cu tensor.cu tensor_operators.cu ) diff --git a/src/sgd.cu b/src/sgd.cu new file mode 100644 index 00000000..469d0976 --- /dev/null +++ b/src/sgd.cu @@ -0,0 +1,68 @@ +#include "sgd.h" +#include "thrust_functions.h" + +namespace marian { +SGD::SGD(Expr& cost_func, Expr& inX, Expr& inY, + const std::vector params, float eta, + std::vector& xData, size_t numFeatures, + std::vector& yData, size_t numClasses, + size_t epochs, size_t batchSize) +: cost_function_(&cost_func), + inX_(&inX), + inY_(&inY), + params_(params), + eta_(eta), + xData_(xData), + numFeatures_(numFeatures), + yData_(yData), + numClasses_(numClasses), + epochs_(epochs), + batchSize_(batchSize) +{} + +void SGD::Run() +{ + size_t numExamples = xData_.size()/ numFeatures_; + Tensor xt({(int)batchSize_, (int)numExamples}, 0.0f); + Tensor yt({(int)batchSize_, (int)numClasses_}, 0.0f); + + for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) { + std::cerr << "Starting epoch #" << numEpoch << std::endl; + size_t startId = 0; + size_t endId = startId + batchSize_; + + while (endId < numExamples) { + PrepareBatch(startId, endId, xt, yt); + *inX_ = xt; + *inY_ = yt; + + cost_function_->forward(batchSize_); + cost_function_->backward(); + + UpdateModel(); + + startId += batchSize_; + endId += batchSize_; + } + } +} + +void SGD::PrepareBatch(size_t startId, size_t endId, Tensor& xt, Tensor& yt) { + std::vector x(xData_.begin() + startId * numFeatures_, + xData_.begin() + endId * numFeatures_); + std::vector y(yData_.begin() + startId * numClasses_, + yData_.begin() + endId * numClasses_); + + xt.set(x); + yt.set(y); +} + +void SGD::UpdateModel() { + for (auto& param : params_) { + using namespace thrust::placeholders; + Element(_1 = _1 - eta_ * _2, param->val(), param->grad()); + } +} + +} // namespace + diff --git a/src/sgd.h b/src/sgd.h index 0dab8df0..17bc038e 100644 --- a/src/sgd.h +++ b/src/sgd.h @@ -5,6 +5,7 @@ #include "expressions.h" #include "thrust_functions.h" +#include "tensor_operators.h" namespace marian { @@ -14,62 +15,13 @@ class SGD { const std::vector params, float eta, std::vector& xData, size_t numFeatures, std::vector& yData, size_t numClasses, - size_t epochs, size_t batchSize) - : cost_function_(&cost_func), - inX_(&inX), - inY_(&inY), - params_(params), - eta_(eta), - xData_(xData), - numFeatures_(numFeatures), - yData_(yData), - numClasses_(numClasses), - epochs_(epochs), - batchSize_(batchSize) - {} + size_t epochs, size_t batchSize); - void Run() { - size_t numExamples = xData_.size()/ numFeatures_; - Tensor xt({(int)batchSize_, (int)numExamples}, 0.0f); - Tensor yt({(int)batchSize_, (int)numClasses_}, 0.0f); + void Run(); - for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) { - std::cerr << "Starting epoch #" << numEpoch << std::endl; - size_t startId = 0; - size_t endId = startId + batchSize_; + void PrepareBatch(size_t startId, size_t endId, Tensor& xt, Tensor& yt); - while (endId < numExamples) { - PrepareBatch(startId, endId, xt, yt); - *inX_ = xt; - *inY_ = yt; - - cost_function_->forward(batchSize_); - cost_function_->backward(); - - UpdateModel(); - - startId += batchSize_; - endId += batchSize_; - } - } - } - - void PrepareBatch(size_t startId, size_t endId, Tensor& xt, Tensor& yt) { - std::vector x(xData_.begin() + startId * numFeatures_, - xData_.begin() + endId * numFeatures_); - std::vector y(yData_.begin() + startId * numClasses_, - yData_.begin() + endId * numClasses_); - - xt.set(x); - yt.set(y); - } - - void UpdateModel() { - for (auto& param : params_) { - using namespace thrust::placeholders; - Element(_1 = _1 - eta_ * _2, param->val(), param->grad()); - } - } + void UpdateModel(); private: std::shared_ptr cost_function_; diff --git a/src/validate_mnist.cu b/src/validate_mnist.cu index e9b5735d..82f5daca 100644 --- a/src/validate_mnist.cu +++ b/src/validate_mnist.cu @@ -22,7 +22,7 @@ int main(int argc, char** argv) { std::cerr << "Loading model params..."; - NpzConverter converter("../scripts/test_model/model.npz"); + NpzConverter converter("../scripts/test_model_single/model.npz"); std::vector wData, bData; Shape wShape, bShape;