Different stuff after train_mnist

Tomasz Dwojak 2016-09-14 14:06:42 +01:00
parent ef6246bad2
commit 67e717d366
8 changed files with 175 additions and 33 deletions

src/marian.h

@@ -6,12 +6,12 @@
 namespace marian {
   typedef float Float;
-  typedef std::vector<size_t> Shape;
+  typedef std::vector<int> Shape;
   const int whatevs{-1};
 }

 #include "keywords.h"
-#include "tensor.h"
+// #include "tensor.h"

 namespace marian {
   class Tensor;
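
Side note on the Shape change above: switching the element type from size_t to int presumably lets the whatevs placeholder (-1) be stored directly in a shape to mark a dimension that is only known at run time, as the new train_mnist.cu below does with shape={whatevs, IMAGE_SIZE}. A minimal standalone sketch of the idea (illustration only, not marian code):

#include <vector>

typedef std::vector<int> Shape;  // the new, signed element type
const int whatevs{-1};           // placeholder for a runtime-sized dimension

int main() {
  Shape s = {whatevs, 784};        // fine with int; with size_t, -1 would
  return s[0] == whatevs ? 0 : 1;  // silently wrap to a huge positive value
}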

src/expressions.cu

@@ -10,7 +10,7 @@ Expr::Expr(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
 Expr::Expr(Float v) : pimpl_(new ConstantNode(keywords::value=v,
                                               keywords::shape={1,1})) {}

-Tensor &Expr::val() {
+Tensor Expr::val() {
   return pimpl_->val();
 }

src/expressions.h

@@ -9,25 +9,25 @@ class Expr {
   public:
     Expr(Chainable<Tensor>* chainable);
     Expr(Float v);

+    Expr operator=(Tensor t) {
+      pimpl_->setVal(t);
+      return *this;
+    }

-    Tensor &val();
+    Tensor val();
     Tensor grad();
     void forward(size_t batchSize);
     void backward();
     ChainPtr node();
     operator ChainPtr();
     std::string Debug() const;

   private:
     ChainPtr pimpl_;
 };
 }
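
The new operator=(Tensor), together with val() now returning a Tensor by value, is what lets the training loop in sgd.h below write *inX_ = xt; to rebind an input node to a fresh batch tensor. A standalone toy sketch of the mechanism (stand-in types, not marian code):

#include <iostream>
#include <memory>
#include <vector>

struct Tensor { std::vector<float> data; };  // stand-in for marian::Tensor

struct Node {                                // stand-in for Chainable<Tensor>
  Tensor val_;
  void setVal(Tensor t) { val_ = t; }
  Tensor val() { return val_; }
};
typedef std::shared_ptr<Node> ChainPtr;

struct Expr {
  ChainPtr pimpl_{new Node};
  Expr operator=(Tensor t) { pimpl_->setVal(t); return *this; }  // as in the diff
  Tensor val() { return pimpl_->val(); }                         // by value now
};

int main() {
  Expr x;
  Tensor batch{{0.f, 1.f, 2.f}};
  x = batch;                                 // what "*inX_ = xt;" does in sgd.h
  std::cout << x.val().data.size() << "\n";  // prints 3
}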

src/sgd.h

@@ -4,41 +4,49 @@
 #include <iostream>
 #include "expressions.h"
 #include "thrust_functions.h"

 namespace marian {

 class SGD {
   public:
-    SGD(Expr& cost_func, Expr& inX, Expr& inY, float eta, std::vector<std::vector<float>> &xData,
-        std::vector<float> &yData, size_t numClasses, size_t epochs, size_t batchSize)
-      : cost_function_(&cost_func),
-        inX_(&inX),
-        inY_(&inY),
-        eta_(eta),
-        xData_(xData),
-        yData_(yData),
-        epochs_(epochs),
-        batchSize_(batchSize),
-        numClasses_(numClasses) {}
+    SGD(Expr& cost_func, Expr& inX, Expr& inY,
+        const std::vector<Expr*> params, float eta,
+        std::vector<float>& xData, size_t numFeatures,
+        std::vector<float>& yData, size_t numClasses,
+        size_t epochs, size_t batchSize)
+      : cost_function_(&cost_func),
+        inX_(&inX),
+        inY_(&inY),
+        params_(params),
+        eta_(eta),
+        xData_(xData),
+        numFeatures_(numFeatures),
+        yData_(yData),
+        numClasses_(numClasses),
+        epochs_(epochs),
+        batchSize_(batchSize)
+    {}

-    void run() {
-      auto numExamples = xData_[0].size();
+    void Run() {
+      size_t numExamples = xData_.size() / numFeatures_;
       Tensor xt({(int)batchSize_, (int)numExamples}, 0.0f);
       Tensor yt({(int)batchSize_, (int)numClasses_}, 0.0f);

       for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
         std::cerr << "Starting epoch #" << numEpoch << std::endl;
         size_t startId = 0;
         size_t endId = startId + batchSize_;

         while (endId < numExamples) {
-          prepareBatch(startId, xt, yt);
+          PrepareBatch(startId, endId, xt, yt);
           *inX_ = xt;
           *inY_ = yt;

           cost_function_->forward(batchSize_);
           cost_function_->backward();

-          updateModel();
+          UpdateModel();

           startId += batchSize_;
           endId += batchSize_;
@@ -46,22 +54,35 @@ class SGD {
       }
     }

-    void prepareBatch(const size_t index, Tensor& xt, Tensor& yt) {
+    void PrepareBatch(size_t startId, size_t endId, Tensor& xt, Tensor& yt) {
+      std::vector<float> x(xData_.begin() + startId * numFeatures_,
+                           xData_.begin() + endId * numFeatures_);
+      std::vector<float> y(yData_.begin() + startId * numClasses_,
+                           yData_.begin() + endId * numClasses_);
+      xt.Load(x);
+      yt.Load(y);
     }

-    void updateModel() {
+    void UpdateModel() {
+      for (auto& param : params_) {
+        using namespace thrust::placeholders;
+        Element(_1 = _1 - eta_ * _2, param->val(), param->grad());
+      }
     }

   private:
     std::shared_ptr<Expr> cost_function_;
     std::shared_ptr<Expr> inX_;
     std::shared_ptr<Expr> inY_;
+    std::vector<Expr*> params_;

     const float eta_;
-    std::vector<std::vector<float>> &xData_;
-    std::vector<float> &yData_;
+    std::vector<float>& xData_;
+    const size_t numFeatures_;
+    std::vector<float>& yData_;
+    const size_t numClasses_;
     const size_t epochs_;
     const size_t batchSize_;
-    const size_t numClasses_;
 };

 } // namespace marian
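
Two details of the class above, restated outside the diff. PrepareBatch copies the example rows [startId, endId) out of the flat, row-major training arrays (numFeatures_ floats per example, numClasses_ floats per one-hot label). And the thrust placeholder expression in UpdateModel encodes plain SGD, w <- w - eta * grad, applied element-wise to every parameter tensor. A CPU sketch of that update (illustration only; marian evaluates it on the GPU through Element):

#include <vector>

// Same arithmetic as Element(_1 = _1 - eta_ * _2, param->val(), param->grad()):
// one in-place gradient-descent step over every element of a parameter.
void sgd_step(std::vector<float>& w, const std::vector<float>& g, float eta) {
  for (size_t i = 0; i < w.size(); ++i)
    w[i] -= eta * g[i];  // _1 is the weight, _2 the gradient
}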

src/tensor.cu

@@ -83,6 +83,12 @@ void Tensor::Load(const std::string &path)
   Load(hostData.begin(), hostData.begin());
 }
+
+void Tensor::Load(const std::vector<float>& data)
+{
+  pimpl_->set(data.begin(), data.end());
+}
+
 void Tensor::Load(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end)
 {
   pimpl_->set(begin, end);

src/tensor.h

@@ -35,7 +35,7 @@ struct Handles {
 const Handles handles;

-typedef std::vector<int> Shape;
+// typedef std::vector<int> Shape;

 inline std::string Debug(const Shape &shape)
 {
@@ -199,13 +199,13 @@ class Tensor {
     typedef TensorImpl<Float>::value_type value_type;

     Tensor() {}
-    Tensor(Shape shape, value_type value = 0) {
+    Tensor(const Shape& shape, value_type value = 0) {
       allocate(shape, value);
     }

     ~Tensor() {}

-    void allocate(Shape shape, value_type value = 0) {
+    void allocate(const Shape& shape, value_type value = 0) {
       if(!pimpl_)
         pimpl_.reset(new TensorImpl<Float>(shape, value));
     }
@@ -275,6 +275,7 @@ class Tensor {
   }

   void Load(const std::string &path);
+  void Load(const std::vector<float>& data);
   void Load(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end);
 };
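
Usage sketch for the new overload, assuming only what this diff shows (the int-based Shape and the vector-based Load declared above):

#include <vector>
#include "tensor.h"  // marian::Tensor as modified in this commit

int main() {
  std::vector<float> host(2 * 3, 1.0f);  // six host-side floats
  marian::Tensor t({2, 3}, 0.0f);        // Shape is std::vector<int> now
  t.Load(host);                          // copies data.begin()..data.end()
  return 0;
}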

src/train_mnist.cu (new file, 37 lines)

@@ -0,0 +1,37 @@
+#include "marian.h"
+#include "mnist.h"
+#include "sgd.h"
+
+using namespace std;
+
+int main(int argc, char** argv) {
+  const size_t IMAGE_SIZE = 784;
+  const size_t LABEL_SIZE = 10;
+  int numofdata;
+
+  vector<float> trainImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
+  vector<float> trainLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
+
+  using namespace marian;
+  using namespace keywords;
+
+  Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X");
+  Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y");
+
+  Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0");
+  Expr b = param(shape={1, LABEL_SIZE}, name="b0");
+
+  std::vector<Expr*> params;
+  params.push_back(&w);
+  params.push_back(&b);
+
+  auto scores = dot(x, w) + b;
+  auto lr = softmax_fast(scores, axis=1, name="pred");
+  auto cost = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
+  cerr << "lr=" << lr.Debug() << endl;
+
+  SGD opt(cost, x, y, params, 0.9, trainImages, IMAGE_SIZE, trainLabels, LABEL_SIZE, 3, 24);
+  opt.Run();
+
+  return 0;
+}
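
The cost expression above is average cross-entropy over the batch: cost = -mean_i sum_j y[i][j] * log(p[i][j]) for one-hot labels y and softmax outputs p. A CPU sketch of the same quantity (illustration only, not marian code), with both arrays stored row-major as in this file:

#include <cmath>
#include <vector>

float cross_entropy(const std::vector<float>& y,  // batch*numClasses, one-hot
                    const std::vector<float>& p,  // batch*numClasses, softmax
                    size_t numClasses) {
  size_t batch = y.size() / numClasses;
  float sum = 0.f;
  for (size_t i = 0; i < batch; ++i)
    for (size_t j = 0; j < numClasses; ++j)
      sum += y[i * numClasses + j] * std::log(p[i * numClasses + j]);
  return -sum / batch;  // negate and average over the batch
}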

src/validate_mnist.cu (new file, 77 lines)

@@ -0,0 +1,77 @@
+#include "marian.h"
+#include "mnist.h"
+#include "npz_converter.h"
+
+using namespace marian;
+using namespace keywords;
+
+int main(int argc, char** argv) {
+  const size_t IMAGE_SIZE = 784;
+  const size_t LABEL_SIZE = 10;
+  int numofdata;
+
+  std::cerr << "Loading test set...";
+  std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
+  std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
+  std::cerr << "\tDone." << std::endl;
+
+  std::cerr << "Loading model params...";
+  NpzConverter converter("../scripts/test_model/model.npz");
+
+  std::vector<float> wData;
+  Shape wShape;
+  converter.Load("weights", wData, wShape);
+
+  std::vector<float> bData;
+  Shape bShape;
+  converter.Load("bias", bData, bShape);
+
+  auto initW = [&wData](Tensor t) {
+    thrust::copy(wData.begin(), wData.end(), t.begin());
+  };
+
+  auto initB = [&bData](Tensor t) {
+    thrust::copy(bData.begin(), bData.end(), t.begin());
+  };
+  std::cerr << "\tDone." << std::endl;
+
+  Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X");
+
+  Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=initW);
+  Expr b = param(shape={1, LABEL_SIZE}, name="b0", init=initB);
+
+  std::cerr << "Building model...";
+  auto scores = dot(x, w) + b;
+  auto predict = softmax(scores, axis=1, name="pred");
+  std::cerr << "\tDone." << std::endl;
+
+  Tensor xt({numofdata, IMAGE_SIZE});
+  xt.Load(testImages);
+
+  predict.forward(numofdata);
+
+  auto results = predict.val();
+
+  size_t acc = 0;
+  for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
+    size_t correct = 0;
+    size_t predicted = 0;
+    for (size_t j = 0; j < LABEL_SIZE; ++j) {
+      if (testLabels[i + j]) correct = j;
+      if (results[i + j] > results[i + predicted]) predicted = j;
+    }
+    acc += (correct == predicted);
+    std::cerr << "corect: " << correct << " | " << predicted << "(";
+    for (size_t j = 0; j < LABEL_SIZE; ++j) {
+      std::cerr << results[i + j] << " ";
+    }
+    std::cerr << std::endl;
+  }
+  std::cerr << "ACC: " << float(acc) / numofdata << std::endl;
+
+  return 0;
+}
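
The evaluation loop above is an argmax comparison: an example counts as correct when the index of its largest softmax output matches the index of the 1 in its one-hot label. An equivalent standalone sketch (illustration only, not marian code):

#include <algorithm>
#include <vector>

size_t count_correct(const std::vector<float>& labels,  // one-hot, row-major
                     const std::vector<float>& probs,   // softmax, row-major
                     size_t numClasses) {
  size_t correct = 0;
  for (size_t i = 0; i + numClasses <= labels.size(); i += numClasses) {
    auto lab  = std::max_element(labels.begin() + i, labels.begin() + i + numClasses);
    auto pred = std::max_element(probs.begin() + i, probs.begin() + i + numClasses);
    correct += (lab - labels.begin()) == (pred - probs.begin());
  }
  return correct;  // accuracy = correct / number of examples
}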