don't use shuffle

This commit is contained in:
Hieu Hoang 2016-09-15 16:31:12 +02:00
parent f044b8dfbb
commit 88678e59bc
2 changed files with 14 additions and 13 deletions

View File

@ -21,7 +21,7 @@ SGD::SGD(Expr& cost_func, Expr& inX, Expr& inY,
yData_(yData), yData_(yData),
numClasses_(numClasses), numClasses_(numClasses),
epochs_(epochs), epochs_(epochs),
batchSize_(batchSize) maxBatchSize_(batchSize)
{} {}
void SGD::Run() void SGD::Run()
@ -29,28 +29,28 @@ void SGD::Run()
std::srand ( unsigned ( std::time(0) ) ); std::srand ( unsigned ( std::time(0) ) );
size_t numExamples = xData_.size()/ numFeatures_; size_t numExamples = xData_.size()/ numFeatures_;
Tensor xt({(int)batchSize_, (int)numExamples}, 0.0f); Tensor xt({(int)maxBatchSize_, (int)numExamples}, 0.0f);
Tensor yt({(int)batchSize_, (int)numClasses_}, 0.0f); Tensor yt({(int)maxBatchSize_, (int)numClasses_}, 0.0f);
vector<size_t> shuffle = CreateShuffle(numExamples); vector<size_t> shuffle = CreateShuffle(numExamples);
for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) { for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
std::cerr << "Starting epoch #" << numEpoch << std::endl; std::cerr << "Starting epoch #" << numEpoch << std::endl;
size_t startId = 0; size_t startId = 0;
size_t endId = startId + batchSize_; size_t endId = startId + maxBatchSize_;
while (endId < numExamples) { while (endId < numExamples) {
PrepareBatch(startId, batchSize_, shuffle, xt, yt); PrepareBatch(startId, endId, maxBatchSize_, shuffle, xt, yt);
*inX_ = xt; *inX_ = xt;
*inY_ = yt; *inY_ = yt;
cost_function_->forward(batchSize_); cost_function_->forward(maxBatchSize_);
cost_function_->backward(); cost_function_->backward();
UpdateModel(); UpdateModel();
startId += batchSize_; startId += maxBatchSize_;
endId += batchSize_; endId += maxBatchSize_;
} }
} }
} }
@ -69,23 +69,23 @@ std::vector<size_t> SGD::CreateShuffle(size_t numExamples) const {
void SGD::PrepareBatch( void SGD::PrepareBatch(
size_t startId, size_t startId,
size_t endId,
size_t batchSize, size_t batchSize,
const std::vector<size_t> &shuffle, const std::vector<size_t> &shuffle,
Tensor& xt, Tensor& xt,
Tensor& yt) { Tensor& yt) {
/*
std::vector<float> x(xData_.begin() + startId * numFeatures_, std::vector<float> x(xData_.begin() + startId * numFeatures_,
xData_.begin() + endId * numFeatures_); xData_.begin() + endId * numFeatures_);
std::vector<float> y(yData_.begin() + startId * numClasses_, std::vector<float> y(yData_.begin() + startId * numClasses_,
yData_.begin() + endId * numClasses_); yData_.begin() + endId * numClasses_);
*/ /*
std::vector<float> x(batchSize * numFeatures_); std::vector<float> x(batchSize * numFeatures_);
std::vector<float> y(batchSize * numClasses_); std::vector<float> y(batchSize * numClasses_);
std::vector<float>::iterator startXIter = x.begin(); std::vector<float>::iterator startXIter = x.begin();
std::vector<float>::iterator startYIter = y.begin(); std::vector<float>::iterator startYIter = y.begin();
size_t endId = startId + batchSize;
for (size_t i = startId; i < endId; ++i) { for (size_t i = startId; i < endId; ++i) {
size_t startXDataId = i * numFeatures_; size_t startXDataId = i * numFeatures_;
size_t startYDataId = i * numClasses_; size_t startYDataId = i * numClasses_;
@ -104,7 +104,7 @@ void SGD::PrepareBatch(
startXIter += batchSize * numFeatures_; startXIter += batchSize * numFeatures_;
startYIter += batchSize * numClasses_; startYIter += batchSize * numClasses_;
} }
*/
xt.set(x); xt.set(x);
yt.set(y); yt.set(y);
} }

View File

@ -30,11 +30,12 @@ class SGD {
std::vector<float>& yData_; std::vector<float>& yData_;
const size_t numClasses_; const size_t numClasses_;
const size_t epochs_; const size_t epochs_;
const size_t batchSize_; const size_t maxBatchSize_;
std::vector<size_t> CreateShuffle(size_t numExamples) const; std::vector<size_t> CreateShuffle(size_t numExamples) const;
void PrepareBatch( void PrepareBatch(
size_t startId, size_t startId,
size_t endId,
size_t batchSize, size_t batchSize,
const std::vector<size_t> &shuffle, const std::vector<size_t> &shuffle,
Tensor& xt, Tensor& xt,