correctly calculate batch size

This commit is contained in:
Hieu Hoang 2016-09-15 17:52:14 +01:00
parent eeb18f33cd
commit 52dcfdfce5

View File

@ -39,10 +39,12 @@ void SGD::Run()
for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) { for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
std::cerr << "Starting epoch #" << numEpoch << std::endl; std::cerr << "Starting epoch #" << numEpoch << std::endl;
size_t startId = 0; size_t startId = 0;
size_t endId = startId + maxBatchSize_;
while (endId < numExamples) { while (startId < numExamples) {
PrepareBatch(startId, endId, maxBatchSize_, shuffle, xt, yt); size_t batchSize = std::min(maxBatchSize_, numExamples - startId);
size_t endId = startId + batchSize;
PrepareBatch(startId, endId, batchSize, shuffle, xt, yt);
*inX_ = xt; *inX_ = xt;
*inY_ = yt; *inY_ = yt;
@ -52,7 +54,6 @@ void SGD::Run()
UpdateModel(); UpdateModel();
startId += maxBatchSize_; startId += maxBatchSize_;
endId += maxBatchSize_;
} }
} }
} }
@ -86,6 +87,7 @@ void SGD::PrepareBatch(
std::vector<float> x(batchSize * numFeatures_); std::vector<float> x(batchSize * numFeatures_);
std::vector<float> y(batchSize * numClasses_); std::vector<float> y(batchSize * numClasses_);
//cerr << "batchSize=" << batchSize << endl;
/* /*
cerr << "startId=" << startId cerr << "startId=" << startId
<< " " << endId << " " << endId
@ -116,6 +118,7 @@ void SGD::PrepareBatch(
<< " " << startYDataId << "-" << endYDataId << " " << startYDataId << "-" << endYDataId
<< endl; << endl;
*/ */
std::copy(xData_.begin() + startXDataId, std::copy(xData_.begin() + startXDataId,
xData_.begin() + endXDataId, xData_.begin() + endXDataId,
x.begin() + startXId); x.begin() + startXId);