correctly calculate batch size

This commit is contained in:
Hieu Hoang 2016-09-15 17:52:14 +01:00
parent eeb18f33cd
commit 52dcfdfce5

View File

@ -39,10 +39,12 @@ void SGD::Run()
for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
std::cerr << "Starting epoch #" << numEpoch << std::endl;
size_t startId = 0;
size_t endId = startId + maxBatchSize_;
while (endId < numExamples) {
PrepareBatch(startId, endId, maxBatchSize_, shuffle, xt, yt);
while (startId < numExamples) {
size_t batchSize = std::min(maxBatchSize_, numExamples - startId);
size_t endId = startId + batchSize;
PrepareBatch(startId, endId, batchSize, shuffle, xt, yt);
*inX_ = xt;
*inY_ = yt;
@ -52,7 +54,6 @@ void SGD::Run()
UpdateModel();
startId += maxBatchSize_;
endId += maxBatchSize_;
}
}
}
@ -86,6 +87,7 @@ void SGD::PrepareBatch(
std::vector<float> x(batchSize * numFeatures_);
std::vector<float> y(batchSize * numClasses_);
//cerr << "batchSize=" << batchSize << endl;
/*
cerr << "startId=" << startId
<< " " << endId
@ -116,6 +118,7 @@ void SGD::PrepareBatch(
<< " " << startYDataId << "-" << endYDataId
<< endl;
*/
std::copy(xData_.begin() + startXDataId,
xData_.begin() + endXDataId,
x.begin() + startXId);