diff --git a/src/sgd.cu b/src/sgd.cu index 0213f6d5..d46ece86 100644 --- a/src/sgd.cu +++ b/src/sgd.cu @@ -21,7 +21,7 @@ SGD::SGD(Expr& cost_func, Expr& inX, Expr& inY, yData_(yData), numClasses_(numClasses), epochs_(epochs), - batchSize_(batchSize) + maxBatchSize_(batchSize) {} void SGD::Run() @@ -29,28 +29,28 @@ void SGD::Run() std::srand ( unsigned ( std::time(0) ) ); size_t numExamples = xData_.size()/ numFeatures_; - Tensor xt({(int)batchSize_, (int)numExamples}, 0.0f); - Tensor yt({(int)batchSize_, (int)numClasses_}, 0.0f); + Tensor xt({(int)maxBatchSize_, (int)numExamples}, 0.0f); + Tensor yt({(int)maxBatchSize_, (int)numClasses_}, 0.0f); vector shuffle = CreateShuffle(numExamples); for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) { std::cerr << "Starting epoch #" << numEpoch << std::endl; size_t startId = 0; - size_t endId = startId + batchSize_; + size_t endId = startId + maxBatchSize_; while (endId < numExamples) { - PrepareBatch(startId, batchSize_, shuffle, xt, yt); + PrepareBatch(startId, endId, maxBatchSize_, shuffle, xt, yt); *inX_ = xt; *inY_ = yt; - cost_function_->forward(batchSize_); + cost_function_->forward(maxBatchSize_); cost_function_->backward(); UpdateModel(); - startId += batchSize_; - endId += batchSize_; + startId += maxBatchSize_; + endId += maxBatchSize_; } } } @@ -69,23 +69,23 @@ std::vector SGD::CreateShuffle(size_t numExamples) const { void SGD::PrepareBatch( size_t startId, + size_t endId, size_t batchSize, const std::vector &shuffle, Tensor& xt, Tensor& yt) { - /* + std::vector x(xData_.begin() + startId * numFeatures_, xData_.begin() + endId * numFeatures_); std::vector y(yData_.begin() + startId * numClasses_, yData_.begin() + endId * numClasses_); - */ + /* std::vector x(batchSize * numFeatures_); std::vector y(batchSize * numClasses_); std::vector::iterator startXIter = x.begin(); std::vector::iterator startYIter = y.begin(); - size_t endId = startId + batchSize; for (size_t i = startId; i < endId; ++i) { size_t startXDataId = i * numFeatures_; size_t startYDataId = i * numClasses_; @@ -104,7 +104,7 @@ void SGD::PrepareBatch( startXIter += batchSize * numFeatures_; startYIter += batchSize * numClasses_; } - + */ xt.set(x); yt.set(y); } diff --git a/src/sgd.h b/src/sgd.h index fedfa8a5..c5ea8dbc 100644 --- a/src/sgd.h +++ b/src/sgd.h @@ -30,11 +30,12 @@ class SGD { std::vector& yData_; const size_t numClasses_; const size_t epochs_; - const size_t batchSize_; + const size_t maxBatchSize_; std::vector CreateShuffle(size_t numExamples) const; void PrepareBatch( size_t startId, + size_t endId, size_t batchSize, const std::vector &shuffle, Tensor& xt,