mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Shuffle data
This commit is contained in:
parent
b29628e0b6
commit
f044b8dfbb
56
src/sgd.cu
56
src/sgd.cu
@ -1,6 +1,10 @@
|
||||
#include <algorithm>
#include <cstdlib>
#include <numeric>
#include <random>
#include <vector>
#include "sgd.h"
#include "thrust_functions.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace marian {
|
||||
SGD::SGD(Expr& cost_func, Expr& inX, Expr& inY,
|
||||
const std::vector<Expr*> params, float eta,
|
||||
@ -22,17 +26,21 @@ SGD::SGD(Expr& cost_func, Expr& inX, Expr& inY,
|
||||
|
||||
void SGD::Run()
|
||||
{
|
||||
std::srand ( unsigned ( std::time(0) ) );
|
||||
|
||||
size_t numExamples = xData_.size()/ numFeatures_;
|
||||
Tensor xt({(int)batchSize_, (int)numExamples}, 0.0f);
|
||||
Tensor yt({(int)batchSize_, (int)numClasses_}, 0.0f);
|
||||
|
||||
vector<size_t> shuffle = CreateShuffle(numExamples);
|
||||
|
||||
for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
|
||||
std::cerr << "Starting epoch #" << numEpoch << std::endl;
|
||||
size_t startId = 0;
|
||||
size_t endId = startId + batchSize_;
|
||||
|
||||
while (endId < numExamples) {
|
||||
PrepareBatch(startId, endId, xt, yt);
|
||||
PrepareBatch(startId, batchSize_, shuffle, xt, yt);
|
||||
*inX_ = xt;
|
||||
*inY_ = yt;
|
||||
|
||||
@ -47,11 +55,55 @@ void SGD::Run()
|
||||
}
|
||||
}
|
||||
|
||||
void SGD::PrepareBatch(size_t startId, size_t endId, Tensor& xt, Tensor& yt) {
|
||||
/**
 * Builds a random visiting order over the training examples.
 *
 * @param numExamples number of examples to permute.
 * @return a vector of length numExamples that contains every zero-based
 *         index in [0, numExamples) exactly once, in random order. The
 *         result is empty when numExamples is 0.
 */
std::vector<size_t> SGD::CreateShuffle(size_t numExamples) const {
  std::vector<size_t> ret(numExamples);

  // Indices are zero-based. Starting the sequence at 1 would never visit
  // example 0 and would include numExamples itself, which is not a valid
  // zero-based index into a collection of numExamples items.
  std::iota(ret.begin(), ret.end(), size_t{0});

  // std::random_shuffle was removed in C++17; use std::shuffle instead.
  // The engine is seeded from std::rand() so that the std::srand() call
  // made in Run() still determines the resulting order.
  std::mt19937 engine(static_cast<std::mt19937::result_type>(std::rand()));
  std::shuffle(ret.begin(), ret.end(), engine);

  return ret;
}
|
||||
|
||||
void SGD::PrepareBatch(
|
||||
size_t startId,
|
||||
size_t batchSize,
|
||||
const std::vector<size_t> &shuffle,
|
||||
Tensor& xt,
|
||||
Tensor& yt) {
|
||||
/*
|
||||
std::vector<float> x(xData_.begin() + startId * numFeatures_,
|
||||
xData_.begin() + endId * numFeatures_);
|
||||
std::vector<float> y(yData_.begin() + startId * numClasses_,
|
||||
yData_.begin() + endId * numClasses_);
|
||||
*/
|
||||
std::vector<float> x(batchSize * numFeatures_);
|
||||
std::vector<float> y(batchSize * numClasses_);
|
||||
|
||||
std::vector<float>::iterator startXIter = x.begin();
|
||||
std::vector<float>::iterator startYIter = y.begin();
|
||||
|
||||
size_t endId = startId + batchSize;
|
||||
for (size_t i = startId; i < endId; ++i) {
|
||||
size_t startXDataId = i * numFeatures_;
|
||||
size_t startYDataId = i * numClasses_;
|
||||
|
||||
size_t endXDataId = startXDataId + batchSize * numFeatures_;
|
||||
size_t endYDataId = startYDataId + batchSize * numClasses_;
|
||||
|
||||
std::copy(xData_.begin() + startXDataId,
|
||||
xData_.begin() + endXDataId,
|
||||
startXIter);
|
||||
|
||||
std::copy(yData_.begin() + startYDataId,
|
||||
yData_.begin() + endYDataId,
|
||||
startYIter);
|
||||
|
||||
startXIter += batchSize * numFeatures_;
|
||||
startYIter += batchSize * numClasses_;
|
||||
}
|
||||
|
||||
xt.set(x);
|
||||
yt.set(y);
|
||||
|
14
src/sgd.h
14
src/sgd.h
@ -19,10 +19,6 @@ class SGD {
|
||||
|
||||
void Run();
|
||||
|
||||
void PrepareBatch(size_t startId, size_t endId, Tensor& xt, Tensor& yt);
|
||||
|
||||
void UpdateModel();
|
||||
|
||||
private:
|
||||
std::shared_ptr<Expr> cost_function_;
|
||||
std::shared_ptr<Expr> inX_;
|
||||
@ -35,6 +31,16 @@ class SGD {
|
||||
const size_t numClasses_;
|
||||
const size_t epochs_;
|
||||
const size_t batchSize_;
|
||||
|
||||
std::vector<size_t> CreateShuffle(size_t numExamples) const;
|
||||
void PrepareBatch(
|
||||
size_t startId,
|
||||
size_t batchSize,
|
||||
const std::vector<size_t> &shuffle,
|
||||
Tensor& xt,
|
||||
Tensor& yt);
|
||||
|
||||
void UpdateModel();
|
||||
};
|
||||
|
||||
} // namespace marian
|
||||
|
Loading…
Reference in New Issue
Block a user