From 00d688b9babe22874c173cea8a23960c975fedaa Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Thu, 15 Sep 2016 17:03:45 +0100 Subject: [PATCH] shuffling doesn't crash --- src/sgd.cu | 65 ++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/src/sgd.cu b/src/sgd.cu index f864e6db..26121f2f 100644 --- a/src/sgd.cu +++ b/src/sgd.cu @@ -33,8 +33,8 @@ void SGD::Run() Tensor xt({(int)maxBatchSize_, (int)numExamples}, 0.0f); Tensor yt({(int)maxBatchSize_, (int)numClasses_}, 0.0f); - //vector shuffle = CreateShuffle(numExamples); - vector shuffle; + vector shuffle = CreateShuffle(numExamples); + //vector shuffle; for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) { std::cerr << "Starting epoch #" << numEpoch << std::endl; @@ -59,13 +59,14 @@ void SGD::Run() std::vector SGD::CreateShuffle(size_t numExamples) const { vector ret(numExamples); - std::iota(ret.begin(), ret.end(), 1); + std::iota(ret.begin(), ret.end(), 0); std::random_shuffle ( ret.begin(), ret.end() ); - + /* + cerr << "shuffled" << endl; for (size_t i = 0; i < ret.size(); ++i) { - cerr << ret[i] << " "; + cerr << ret[i] << " "; } - + */ return ret; } @@ -76,37 +77,57 @@ void SGD::PrepareBatch( const std::vector &shuffle, Tensor& xt, Tensor& yt) { - + /* std::vector x(xData_.begin() + startId * numFeatures_, xData_.begin() + endId * numFeatures_); std::vector y(yData_.begin() + startId * numClasses_, yData_.begin() + endId * numClasses_); - /* + */ std::vector x(batchSize * numFeatures_); std::vector y(batchSize * numClasses_); - - std::vector::iterator startXIter = x.begin(); - std::vector::iterator startYIter = y.begin(); - + + /* + cerr << "startId=" << startId + << " " << endId + << " " << batchSize + << endl; + cerr << "numExamples=" << shuffle.size() << endl; + cerr << "numFeatures_=" << numFeatures_ << " " << numClasses_ << endl; + cerr << "sizes=" << x.size() + << " " << y.size() + << " " << xData_.size() + << " " << yData_.size() + << endl; + */ + size_t startXId = 0; + size_t startYId = 0; + for (size_t i = startId; i < endId; ++i) { - size_t startXDataId = i * numFeatures_; - size_t startYDataId = i * numClasses_; - - size_t endXDataId = startXDataId + batchSize * numFeatures_; - size_t endYDataId = startYDataId + batchSize * numClasses_; + size_t ind = shuffle[i]; + size_t startXDataId = ind * numFeatures_; + size_t startYDataId = ind * numClasses_; + size_t endXDataId = startXDataId + numFeatures_; + size_t endYDataId = startYDataId + numClasses_; + /* + cerr << "i=" << i + << " " << ind + << " " << startXDataId << "-" << endXDataId + << " " << startYDataId << "-" << endYDataId + << endl; + */ std::copy(xData_.begin() + startXDataId, xData_.begin() + endXDataId, - startXIter); + x.begin() + startXId); std::copy(yData_.begin() + startYDataId, yData_.begin() + endYDataId, - startYIter); + y.begin() + startYId); - startXIter += batchSize * numFeatures_; - startYIter += batchSize * numClasses_; + startXId += numFeatures_; + startYId += numClasses_; } - */ + xt.set(x); yt.set(y); }