diff --git a/src/sgd.cu b/src/sgd.cu index 26121f2f..0a276835 100644 --- a/src/sgd.cu +++ b/src/sgd.cu @@ -39,10 +39,12 @@ void SGD::Run() for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) { std::cerr << "Starting epoch #" << numEpoch << std::endl; size_t startId = 0; - size_t endId = startId + maxBatchSize_; - while (endId < numExamples) { - PrepareBatch(startId, endId, maxBatchSize_, shuffle, xt, yt); + while (startId < numExamples) { + size_t batchSize = std::min(maxBatchSize_, numExamples - startId); + size_t endId = startId + batchSize; + + PrepareBatch(startId, endId, batchSize, shuffle, xt, yt); *inX_ = xt; *inY_ = yt; @@ -52,7 +54,6 @@ void SGD::Run() UpdateModel(); startId += maxBatchSize_; - endId += maxBatchSize_; } } } @@ -86,6 +87,7 @@ void SGD::PrepareBatch( std::vector x(batchSize * numFeatures_); std::vector y(batchSize * numClasses_); + //cerr << "batchSize=" << batchSize << endl; /* cerr << "startId=" << startId << " " << endId @@ -116,6 +118,7 @@ void SGD::PrepareBatch( << " " << startYDataId << "-" << endYDataId << endl; */ + std::copy(xData_.begin() + startXDataId, xData_.begin() + endXDataId, x.begin() + startXId);