Mirror of https://github.com/marian-nmt/marian.git (synced 2024-11-04 14:04:24 +03:00)
Different stuff after train_mnist
parent ef6246bad2
commit 67e717d366
@@ -6,12 +6,12 @@
 
 namespace marian {
   typedef float Float;
-  typedef std::vector<size_t> Shape;
+  typedef std::vector<int> Shape;
   const int whatevs{-1};
 }
 
 #include "keywords.h"
-#include "tensor.h"
+// #include "tensor.h"
 
 namespace marian {
 class Tensor;
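Side note: changing Shape from std::vector<size_t> to std::vector<int> makes room for the negative sentinel whatevs{-1}, which the new MNIST programs below use as a "decide later" dimension (the batch size). A minimal stand-alone sketch of that idea; the resolve() helper is illustrative only, not marian's API:

#include <cassert>
#include <vector>

typedef std::vector<int> Shape;   // as in the new typedef
const int whatevs = -1;           // sentinel: "fill this dimension in later"

// Hypothetical helper (not marian API): replace the -1 placeholder with a batch size.
Shape resolve(Shape shape, int batchSize) {
  for (int& dim : shape)
    if (dim == whatevs)
      dim = batchSize;
  return shape;
}

int main() {
  Shape x = {whatevs, 784};       // e.g. an input whose batch dimension is unknown
  Shape r = resolve(x, 24);
  assert(r[0] == 24 && r[1] == 784);
  return 0;
}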
@@ -10,7 +10,7 @@ Expr::Expr(Chainable<Tensor>* chainable) : pimpl_(chainable) {}
 Expr::Expr(Float v) : pimpl_(new ConstantNode(keywords::value=v,
                                               keywords::shape={1,1})) {}
 
-Tensor &Expr::val() {
+Tensor Expr::val() {
   return pimpl_->val();
 }
 
@@ -9,25 +9,25 @@ class Expr {
   public:
     Expr(Chainable<Tensor>* chainable);
     Expr(Float v);
 
     Expr operator=(Tensor t) {
       pimpl_->setVal(t);
       return *this;
     }
 
-    Tensor &val();
+    Tensor val();
     Tensor grad();
 
     void forward(size_t batchSize);
     void backward();
 
    ChainPtr node();
    operator ChainPtr();
 
    std::string Debug() const;
 
  private:
    ChainPtr pimpl_;
 };
 
 }
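Both the definition above and this declaration drop the reference: Expr::val() now returns a Tensor by value. That is cheap if Tensor is a thin handle whose copy only bumps a shared pointer to the real storage, which is what the pimpl_ member in tensor.h suggests. A small stand-alone sketch of the pattern, with made-up names:

#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical pimpl-style handle (not the real Tensor): copying it only copies a
// shared_ptr, so returning it by value still refers to the same underlying storage.
class Handle {
 public:
  explicit Handle(std::vector<float> data)
      : pimpl_(std::make_shared<std::vector<float>>(std::move(data))) {}
  float& at(std::size_t i) { return (*pimpl_)[i]; }
 private:
  std::shared_ptr<std::vector<float>> pimpl_;
};

Handle makeHandle() { return Handle({1.f, 2.f, 3.f}); }

int main() {
  Handle a = makeHandle();   // returned by value
  Handle b = a;              // shallow copy: shares the same buffer
  b.at(0) = 42.f;
  assert(a.at(0) == 42.f);   // the change is visible through the first handle too
  return 0;
}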
src/sgd.h (61 changed lines)
@@ -4,41 +4,49 @@
 #include <iostream>
 
 #include "expressions.h"
+#include "thrust_functions.h"
 
 namespace marian {
 
 class SGD {
   public:
-    SGD(Expr& cost_func, Expr& inX, Expr& inY, float eta, std::vector<std::vector<float>> &xData,
-        std::vector<float> &yData, size_t numClasses, size_t epochs, size_t batchSize)
-    : cost_function_(&cost_func),
-      inX_(&inX),
-      inY_(&inY),
-      eta_(eta),
-      xData_(xData),
-      yData_(yData),
-      epochs_(epochs),
-      batchSize_(batchSize),
-      numClasses_(numClasses) {}
+    SGD(Expr& cost_func, Expr& inX, Expr& inY,
+        const std::vector<Expr*> params, float eta,
+        std::vector<float>& xData, size_t numFeatures,
+        std::vector<float>& yData, size_t numClasses,
+        size_t epochs, size_t batchSize)
+    : cost_function_(&cost_func),
+      inX_(&inX),
+      inY_(&inY),
+      params_(params),
+      eta_(eta),
+      xData_(xData),
+      numFeatures_(numFeatures),
+      yData_(yData),
+      numClasses_(numClasses),
+      epochs_(epochs),
+      batchSize_(batchSize)
+    {}
 
-    void run() {
-      auto numExamples = xData_[0].size();
+    void Run() {
+      size_t numExamples = xData_.size() / numFeatures_;
       Tensor xt({(int)batchSize_, (int)numExamples}, 0.0f);
       Tensor yt({(int)batchSize_, (int)numClasses_}, 0.0f);
 
       for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
         std::cerr << "Starting epoch #" << numEpoch << std::endl;
         size_t startId = 0;
         size_t endId = startId + batchSize_;
 
         while (endId < numExamples) {
-          prepareBatch(startId, xt, yt);
+          PrepareBatch(startId, endId, xt, yt);
           *inX_ = xt;
           *inY_ = yt;
 
           cost_function_->forward(batchSize_);
           cost_function_->backward();
 
-          updateModel();
+          UpdateModel();
 
           startId += batchSize_;
           endId += batchSize_;
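The rewritten Run() walks the flattened training data epoch by epoch in fixed-size minibatches, advancing a [startId, endId) window by batchSize_ each step. A stand-alone sketch of just that index arithmetic (no marian types; the sizes are made up):

#include <cstddef>
#include <cstdio>

// Sketch of the epoch/minibatch index walk in SGD::Run().
int main() {
  const std::size_t numFeatures = 4;
  const std::size_t batchSize   = 8;
  const std::size_t epochs      = 2;
  const std::size_t dataSize    = 100 * numFeatures;      // flattened xData
  const std::size_t numExamples = dataSize / numFeatures; // 100

  for (std::size_t epoch = 0; epoch < epochs; ++epoch) {
    std::size_t startId = 0;
    std::size_t endId   = startId + batchSize;
    std::size_t batches = 0;
    while (endId < numExamples) {   // examples past the last full window are dropped
      // the real loop builds the batch, runs forward/backward and updates here
      ++batches;
      startId += batchSize;
      endId   += batchSize;
    }
    std::printf("epoch %zu: %zu batches of %zu examples\n", epoch, batches, batchSize);
  }
  return 0;
}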
@@ -46,22 +54,35 @@ class SGD {
         }
       }
     }
 
-    void prepareBatch(const size_t index, Tensor& xt, Tensor& yt) {
+    void PrepareBatch(size_t startId, size_t endId, Tensor& xt, Tensor& yt) {
+      std::vector<float> x(xData_.begin() + startId * numFeatures_,
+                           xData_.begin() + endId * numFeatures_);
+      std::vector<float> y(yData_.begin() + startId * numClasses_,
+                           yData_.begin() + endId * numClasses_);
+
+      xt.Load(x);
+      yt.Load(y);
     }
 
-    void updateModel() {
+    void UpdateModel() {
+      for (auto& param : params_) {
+        using namespace thrust::placeholders;
+        Element(_1 = _1 - eta_ * _2, param->val(), param->grad());
+      }
     }
 
   private:
     std::shared_ptr<Expr> cost_function_;
     std::shared_ptr<Expr> inX_;
     std::shared_ptr<Expr> inY_;
+    std::vector<Expr*> params_;
     const float eta_;
-    std::vector<std::vector<float>> &xData_;
-    std::vector<float> &yData_;
+    std::vector<float>& xData_;
+    const size_t numFeatures_;
+    std::vector<float>& yData_;
+    const size_t numClasses_;
     const size_t epochs_;
     const size_t batchSize_;
-    const size_t numClasses_;
 };
 
 } // namespace marian
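Two things land here: PrepareBatch() slices the [startId, endId) examples out of the flat, row-major xData_/yData_ buffers, and UpdateModel() applies plain SGD, w := w - eta * dw, to every parameter via the thrust placeholder expression passed to Element. A CPU-side sketch of both pieces, illustrative only (the real update runs on the GPU through Element):

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// 1) slicing a window of examples out of a flat, row-major buffer
std::vector<float> sliceBatch(const std::vector<float>& data,
                              std::size_t startId, std::size_t endId, std::size_t rowWidth) {
  return std::vector<float>(data.begin() + startId * rowWidth,
                            data.begin() + endId * rowWidth);
}

// 2) the plain SGD step that Element(_1 = _1 - eta_ * _2, val, grad) performs element-wise
void sgdStep(std::vector<float>& param, const std::vector<float>& grad, float eta) {
  for (std::size_t i = 0; i < param.size(); ++i)
    param[i] -= eta * grad[i];            // w := w - eta * dw
}

int main() {
  // 3 examples with 2 features each, flattened row-major.
  std::vector<float> xData = {0, 1,  2, 3,  4, 5};
  std::vector<float> batch = sliceBatch(xData, 1, 3, 2);   // examples 1 and 2
  assert(batch.size() == 4 && batch[0] == 2 && batch[3] == 5);

  std::vector<float> w = {1.0f, 1.0f};
  std::vector<float> g = {0.5f, -0.5f};
  sgdStep(w, g, 0.1f);
  assert(std::fabs(w[0] - 0.95f) < 1e-6f && std::fabs(w[1] - 1.05f) < 1e-6f);
  return 0;
}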
@@ -83,6 +83,12 @@ void Tensor::Load(const std::string &path)
   Load(hostData.begin(), hostData.begin());
 }
 
+void Tensor::Load(const std::vector<float>& data)
+{
+  pimpl_->set(data.begin(), data.end());
+}
+
+
 void Tensor::Load(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end)
 {
   pimpl_->set(begin, end);
@@ -35,7 +35,7 @@ struct Handles {
 
 const Handles handles;
 
-typedef std::vector<int> Shape;
+// typedef std::vector<int> Shape;
 
 inline std::string Debug(const Shape &shape)
 {
@@ -199,13 +199,13 @@ class Tensor {
     typedef TensorImpl<Float>::value_type value_type;
 
     Tensor() {}
-    Tensor(Shape shape, value_type value = 0) {
+    Tensor(const Shape& shape, value_type value = 0) {
       allocate(shape, value);
     }
 
     ~Tensor() {}
 
-    void allocate(Shape shape, value_type value = 0) {
+    void allocate(const Shape& shape, value_type value = 0) {
       if(!pimpl_)
         pimpl_.reset(new TensorImpl<Float>(shape, value));
     }
@@ -275,6 +275,7 @@ class Tensor {
     }
 
     void Load(const std::string &path);
+    void Load(const std::vector<float>& data);
     void Load(const std::vector<float>::const_iterator &begin, const std::vector<float>::const_iterator &end);
 
 };
src/train_mnist.cu (new file, 37 lines)
@@ -0,0 +1,37 @@
+
+#include "marian.h"
+#include "mnist.h"
+#include "sgd.h"
+
+using namespace std;
+
+int main(int argc, char** argv) {
+  const size_t IMAGE_SIZE = 784;
+  const size_t LABEL_SIZE = 10;
+  int numofdata;
+
+  vector<float> trainImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
+  vector<float> trainLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
+
+  using namespace marian;
+  using namespace keywords;
+
+  Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X");
+  Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y");
+
+  Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0");
+  Expr b = param(shape={1, LABEL_SIZE}, name="b0");
+
+  std::vector<Expr*> params;
+  params.push_back(&w);
+  params.push_back(&b);
+
+  auto scores = dot(x, w) + b;
+  auto lr = softmax_fast(scores, axis=1, name="pred");
+  auto cost = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
+  cerr << "lr=" << lr.Debug() << endl;
+
+  SGD opt(cost, x, y, params, 0.9, trainImages, IMAGE_SIZE, trainLabels, LABEL_SIZE, 3, 24);
+  opt.Run();
+  return 0;
+}
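For reference, the cost expression built above, -mean(sum(y * log(softmax(scores)), axis=1), axis=0), is batch-averaged softmax cross-entropy (softmax_fast is assumed here to compute an ordinary softmax over axis 1). A stand-alone sketch of the same arithmetic on a made-up batch of 2 examples and 3 classes:

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<std::vector<float>> scores = {{2.f, 1.f, 0.f}, {0.f, 0.f, 3.f}};
  std::vector<std::vector<float>> y      = {{1.f, 0.f, 0.f}, {0.f, 0.f, 1.f}};  // one-hot labels

  float cost = 0.f;
  for (std::size_t i = 0; i < scores.size(); ++i) {
    float denom = 0.f;                                  // softmax normalizer over axis=1
    for (float s : scores[i]) denom += std::exp(s);
    float xent = 0.f;                                   // -sum_j y_j * log(p_j)
    for (std::size_t j = 0; j < scores[i].size(); ++j) {
      float p = std::exp(scores[i][j]) / denom;
      xent -= y[i][j] * std::log(p);
    }
    cost += xent;
  }
  cost /= scores.size();                                // mean over axis=0 (the batch)
  std::printf("cost = %f\n", cost);                     // roughly 0.25 for these numbers
  return 0;
}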
src/validate_mnist.cu (new file, 77 lines)
@@ -0,0 +1,77 @@
+
+#include "marian.h"
+#include "mnist.h"
+#include "npz_converter.h"
+
+using namespace marian;
+using namespace keywords;
+
+int main(int argc, char** argv) {
+  const size_t IMAGE_SIZE = 784;
+  const size_t LABEL_SIZE = 10;
+  int numofdata;
+
+  std::cerr << "Loading test set...";
+  std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
+  std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
+  std::cerr << "\tDone." << std::endl;
+
+  std::cerr << "Loading model params...";
+  NpzConverter converter("../scripts/test_model/model.npz");
+
+  std::vector<float> wData;
+  Shape wShape;
+  converter.Load("weights", wData, wShape);
+
+  std::vector<float> bData;
+  Shape bShape;
+  converter.Load("bias", bData, bShape);
+
+  auto initW = [&wData](Tensor t) {
+    thrust::copy(wData.begin(), wData.end(), t.begin());
+  };
+
+  auto initB = [&bData](Tensor t) {
+    thrust::copy(bData.begin(), bData.end(), t.begin());
+  };
+
+  std::cerr << "\tDone." << std::endl;
+
+
+  Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X");
+
+  Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=initW);
+  Expr b = param(shape={1, LABEL_SIZE}, name="b0", init=initB);
+
+  std::cerr << "Building model...";
+  auto scores = dot(x, w) + b;
+  auto predict = softmax(scores, axis=1, name="pred");
+  std::cerr << "\tDone." << std::endl;
+
+  Tensor xt({numofdata, IMAGE_SIZE});
+  xt.Load(testImages);
+
+  predict.forward(numofdata);
+
+  auto results = predict.val();
+
+  size_t acc = 0;
+
+  for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
+    size_t correct = 0;
+    size_t predicted = 0;
+    for (size_t j = 0; j < LABEL_SIZE; ++j) {
+      if (testLabels[i+j]) correct = j;
+      if (results[i + j] > results[i + predicted]) predicted = j;
+    }
+    acc += (correct == predicted);
+    std::cerr << "corect: " << correct << " | " << predicted << "(";
+    for (size_t j = 0; j < LABEL_SIZE; ++j) {
+      std::cerr << results[i+j] << " ";
+    }
+    std::cerr << std::endl;
+  }
+  std::cerr << "ACC: " << float(acc)/numofdata << std::endl;
+
+  return 0;
+}
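The evaluation loop above treats every block of LABEL_SIZE entries as one example: the true class is the index of the 1 in the one-hot label row, the prediction is the argmax of the matching slice of results, and accuracy is the fraction of matches. The same logic on a tiny made-up example (LABEL_SIZE shrunk to 3):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const std::size_t LABEL_SIZE = 3;
  std::vector<float> labels  = {0, 1, 0,   1, 0, 0};                   // two one-hot examples
  std::vector<float> results = {0.2f, 0.7f, 0.1f,   0.1f, 0.6f, 0.3f};

  std::size_t acc = 0;
  for (std::size_t i = 0; i < labels.size(); i += LABEL_SIZE) {
    std::size_t correct = 0, predicted = 0;
    for (std::size_t j = 0; j < LABEL_SIZE; ++j) {
      if (labels[i + j]) correct = j;                                  // one-hot label index
      if (results[i + j] > results[i + predicted]) predicted = j;      // running argmax
    }
    acc += (correct == predicted);
  }
  std::printf("ACC: %f\n", float(acc) / (labels.size() / LABEL_SIZE)); // 0.5 here
  return 0;
}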