mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
correctly calculate batch size
This commit is contained in:
parent
eeb18f33cd
commit
52dcfdfce5
11
src/sgd.cu
11
src/sgd.cu
@ -39,10 +39,12 @@ void SGD::Run()
|
||||
for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
|
||||
std::cerr << "Starting epoch #" << numEpoch << std::endl;
|
||||
size_t startId = 0;
|
||||
size_t endId = startId + maxBatchSize_;
|
||||
|
||||
while (endId < numExamples) {
|
||||
PrepareBatch(startId, endId, maxBatchSize_, shuffle, xt, yt);
|
||||
while (startId < numExamples) {
|
||||
size_t batchSize = std::min(maxBatchSize_, numExamples - startId);
|
||||
size_t endId = startId + batchSize;
|
||||
|
||||
PrepareBatch(startId, endId, batchSize, shuffle, xt, yt);
|
||||
*inX_ = xt;
|
||||
*inY_ = yt;
|
||||
|
||||
@ -52,7 +54,6 @@ void SGD::Run()
|
||||
UpdateModel();
|
||||
|
||||
startId += maxBatchSize_;
|
||||
endId += maxBatchSize_;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -86,6 +87,7 @@ void SGD::PrepareBatch(
|
||||
std::vector<float> x(batchSize * numFeatures_);
|
||||
std::vector<float> y(batchSize * numClasses_);
|
||||
|
||||
//cerr << "batchSize=" << batchSize << endl;
|
||||
/*
|
||||
cerr << "startId=" << startId
|
||||
<< " " << endId
|
||||
@ -116,6 +118,7 @@ void SGD::PrepareBatch(
|
||||
<< " " << startYDataId << "-" << endYDataId
|
||||
<< endl;
|
||||
*/
|
||||
|
||||
std::copy(xData_.begin() + startXDataId,
|
||||
xData_.begin() + endXDataId,
|
||||
x.begin() + startXId);
|
||||
|
Loading…
Reference in New Issue
Block a user