From f044b8dfbb8084579d0af232870805bef0242311 Mon Sep 17 00:00:00 2001
From: Hieu Hoang
Date: Thu, 15 Sep 2016 16:14:06 +0200
Subject: [PATCH] Shuffle data

Visit the training examples in a random order. SGD::Run() builds one
permutation of the example indices with CreateShuffle(), and PrepareBatch()
gathers every batch through that permutation instead of copying a
contiguous slice of xData_/yData_.

---
Review notes (free text before the diffstat; not part of the commit message):
 - NOTE(review): the copy of this patch under review had lost its line
   breaks, its indentation and every <...> token. The layout is restored
   here and untouched context lines are kept as found. Header names and
   template arguments on added lines (size_t indices, float data) are
   assumptions - confirm against the tree before applying.
 - CreateShuffle() numbered the examples from 1, so the last index would
   point one example past the data and example 0 was never selected. It
   now starts at 0.
 - PrepareBatch() never read `shuffle` and copied batchSize examples per
   loop iteration into buffers sized for batchSize examples in total. It
   now copies exactly one shuffled example per iteration.
 - Dropped the file-scope `using namespace std;`, the dump of every index
   to cerr, and std::srand/std::random_shuffle (replaced by std::shuffle
   with a locally seeded std::mt19937).

 src/sgd.cu | 50 ++++++++++++++++++++++++++++++++++++++++++++++++--
 src/sgd.h  | 14 ++++++++++----
 2 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/src/sgd.cu b/src/sgd.cu
index 469d0976..0213f6d5 100644
--- a/src/sgd.cu
+++ b/src/sgd.cu
@@ -1,6 +1,9 @@
+#include <algorithm>
+#include <numeric>
+#include <random>
 #include "sgd.h"
 #include "thrust_functions.h"
 
 namespace marian {
 SGD::SGD(Expr& cost_func, Expr& inX, Expr& inY,
          const std::vector params, float eta,
@@ -22,17 +25,19 @@ SGD::SGD(Expr& cost_func, Expr& inX, Expr& inY,
 
 void SGD::Run()
 {
   size_t numExamples = xData_.size()/ numFeatures_;
   Tensor xt({(int)batchSize_, (int)numExamples}, 0.0f);
   Tensor yt({(int)batchSize_, (int)numClasses_}, 0.0f);
 
+  std::vector<size_t> shuffle = CreateShuffle(numExamples);
+
   for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
     std::cerr << "Starting epoch #" << numEpoch << std::endl;
     size_t startId = 0;
     size_t endId = startId + batchSize_;
 
     while (endId < numExamples) {
-      PrepareBatch(startId, endId, xt, yt);
+      PrepareBatch(startId, batchSize_, shuffle, xt, yt);
       *inX_ = xt;
       *inY_ = yt;
 
@@ -47,11 +52,52 @@ void SGD::Run()
   }
 }
 
-void SGD::PrepareBatch(size_t startId, size_t endId, Tensor& xt, Tensor& yt) {
+// Returns a random permutation of the example indices 0 .. numExamples-1.
+std::vector<size_t> SGD::CreateShuffle(size_t numExamples) const {
+  std::vector<size_t> ret(numExamples);
+  // Indices are 0-based: PrepareBatch() turns them into offsets into xData_/yData_.
+  std::iota(ret.begin(), ret.end(), size_t(0));
+
+  std::random_device seedSource;
+  std::mt19937 rng(seedSource());
+  std::shuffle(ret.begin(), ret.end(), rng);
+
+  return ret;
+}
+
+// Copies the features and labels of the examples shuffle[startId] ..
+// shuffle[startId + batchSize - 1], back to back, and stores them in xt and yt.
+void SGD::PrepareBatch(
+    size_t startId,
+    size_t batchSize,
+    const std::vector<size_t> &shuffle,
+    Tensor& xt,
+    Tensor& yt) {
+  /*
   std::vector x(xData_.begin() + startId * numFeatures_,
                 xData_.begin() + endId * numFeatures_);
   std::vector y(yData_.begin() + startId * numClasses_,
                 yData_.begin() + endId * numClasses_);
+  */
+  std::vector<float> x(batchSize * numFeatures_);
+  std::vector<float> y(batchSize * numClasses_);
+
+  std::vector<float>::iterator xOut = x.begin();
+  std::vector<float>::iterator yOut = y.begin();
+
+  const size_t endId = startId + batchSize;
+  for (size_t i = startId; i < endId; ++i) {
+    // Batch position i takes the example selected by the permutation.
+    const size_t exampleId = shuffle[i];
+
+    // Copy exactly one example; std::copy returns the advanced output position.
+    xOut = std::copy(xData_.begin() + exampleId * numFeatures_,
+                     xData_.begin() + (exampleId + 1) * numFeatures_,
+                     xOut);
+    yOut = std::copy(yData_.begin() + exampleId * numClasses_,
+                     yData_.begin() + (exampleId + 1) * numClasses_,
+                     yOut);
+  }
 
   xt.set(x);
   yt.set(y);
diff --git a/src/sgd.h b/src/sgd.h
index 17bc038e..fedfa8a5 100644
--- a/src/sgd.h
+++ b/src/sgd.h
@@ -19,10 +19,6 @@ class SGD {
 
   void Run();
 
-  void PrepareBatch(size_t startId, size_t endId, Tensor& xt, Tensor& yt);
-
-  void UpdateModel();
-
  private:
   std::shared_ptr cost_function_;
   std::shared_ptr inX_;
@@ -35,6 +31,16 @@ class SGD {
   const size_t numClasses_;
   const size_t epochs_;
   const size_t batchSize_;
+
+  std::vector<size_t> CreateShuffle(size_t numExamples) const;
+  void PrepareBatch(
+    size_t startId,
+    size_t batchSize,
+    const std::vector<size_t> &shuffle,
+    Tensor& xt,
+    Tensor& yt);
+
+  void UpdateModel();
 };
 
 }  // namespace marian