mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
correctly calculate batch size
This commit is contained in:
parent
eeb18f33cd
commit
52dcfdfce5
11
src/sgd.cu
11
src/sgd.cu
@ -39,10 +39,12 @@ void SGD::Run()
|
|||||||
for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
|
for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
|
||||||
std::cerr << "Starting epoch #" << numEpoch << std::endl;
|
std::cerr << "Starting epoch #" << numEpoch << std::endl;
|
||||||
size_t startId = 0;
|
size_t startId = 0;
|
||||||
size_t endId = startId + maxBatchSize_;
|
|
||||||
|
|
||||||
while (endId < numExamples) {
|
while (startId < numExamples) {
|
||||||
PrepareBatch(startId, endId, maxBatchSize_, shuffle, xt, yt);
|
size_t batchSize = std::min(maxBatchSize_, numExamples - startId);
|
||||||
|
size_t endId = startId + batchSize;
|
||||||
|
|
||||||
|
PrepareBatch(startId, endId, batchSize, shuffle, xt, yt);
|
||||||
*inX_ = xt;
|
*inX_ = xt;
|
||||||
*inY_ = yt;
|
*inY_ = yt;
|
||||||
|
|
||||||
@ -52,7 +54,6 @@ void SGD::Run()
|
|||||||
UpdateModel();
|
UpdateModel();
|
||||||
|
|
||||||
startId += maxBatchSize_;
|
startId += maxBatchSize_;
|
||||||
endId += maxBatchSize_;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -86,6 +87,7 @@ void SGD::PrepareBatch(
|
|||||||
std::vector<float> x(batchSize * numFeatures_);
|
std::vector<float> x(batchSize * numFeatures_);
|
||||||
std::vector<float> y(batchSize * numClasses_);
|
std::vector<float> y(batchSize * numClasses_);
|
||||||
|
|
||||||
|
//cerr << "batchSize=" << batchSize << endl;
|
||||||
/*
|
/*
|
||||||
cerr << "startId=" << startId
|
cerr << "startId=" << startId
|
||||||
<< " " << endId
|
<< " " << endId
|
||||||
@ -116,6 +118,7 @@ void SGD::PrepareBatch(
|
|||||||
<< " " << startYDataId << "-" << endYDataId
|
<< " " << startYDataId << "-" << endYDataId
|
||||||
<< endl;
|
<< endl;
|
||||||
*/
|
*/
|
||||||
|
|
||||||
std::copy(xData_.begin() + startXDataId,
|
std::copy(xData_.begin() + startXDataId,
|
||||||
xData_.begin() + endXDataId,
|
xData_.begin() + endXDataId,
|
||||||
x.begin() + startXId);
|
x.begin() + startXId);
|
||||||
|
Loading…
Reference in New Issue
Block a user