sgd variants

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-16 17:43:29 +02:00
parent 7c63606b75
commit 15429dd88f
12 changed files with 99 additions and 214 deletions

View File

@ -5,7 +5,6 @@ cuda_add_library(marian_lib
cnpy/cnpy.cpp
exception.cpp
expression_graph.cu
sgd.cu
tensor.cu
tensor_operators.cu
expression_operators.cu

View File

@ -39,12 +39,12 @@ std::string Expr::Debug() const
}
///////////////////////////////////////////////////////
ExpressionGraph::ExpressionGraph(int cudaDevice)
: stack_(new ChainableStack)
{
std::srand (time(NULL));
cudaSetDevice(0);
}
//ExpressionGraph::ExpressionGraph(int cudaDevice)
//: stack_(new ChainableStack)
//{
// std::srand (time(NULL));
// cudaSetDevice(0);
//
//}
}

View File

@ -38,9 +38,14 @@ class Expr {
class ExpressionGraph {
public:
ExpressionGraph(int cudaDevice);
ExpressionGraph() : stack_(new ChainableStack) {}
void forward(size_t batchSize) {
void backprop(int batchSize) {
forward(batchSize);
backward();
}
void forward(int batchSize) {
for(auto&& v : *stack_) {
v->allocate(batchSize);
}
@ -48,6 +53,16 @@ class ExpressionGraph {
v->forward();
}
void backward() {
for(auto&& v : *stack_)
v->set_zero_adjoint();
typedef typename ChainableStack::reverse_iterator It;
stack_->back()->init_dependent();
for(It it = stack_->rbegin(); it != stack_->rend(); ++it)
(*it)->backward();
}
std::string graphviz() {
std::stringstream ss;
ss << "digraph ExpressionGraph {" << std::endl;
@ -60,19 +75,13 @@ class ExpressionGraph {
return ss.str();
}
void backward() {
for(auto&& v : *stack_)
v->set_zero_adjoint();
typedef typename ChainableStack::reverse_iterator It;
stack_->back()->init_dependent();
for(It it = stack_->rbegin(); it != stack_->rend(); ++it)
(*it)->backward();
}
/*********************************************************/
template <typename ...Args>
inline Expr input(Args ...args) {
return Expr(this, new InputNode(args...));
Expr e(this, new InputNode(args...));
inputs_.emplace_back(e);
return e;
}
template <typename ...Args>
@ -117,14 +126,20 @@ class ExpressionGraph {
named_.emplace(name, e);
}
std::vector<Expr>& inputs() {
return inputs_;
}
std::vector<Expr>& params() {
return params_;
}
private:
ChainableStackPtr stack_;
std::map<std::string, Expr> named_;
std::vector<Expr> params_;
std::vector<Expr> inputs_;
};
}

View File

@ -1,140 +0,0 @@
#include <ctime>
#include <algorithm>
#include <vector>
#include "sgd.h"
#include "thrust_functions.h"
using namespace std;
namespace marian {
SGD::SGD(ExpressionGraph& g, float eta,
std::vector<float>& xData, size_t numFeatures,
std::vector<float>& yData, size_t numClasses,
size_t epochs, size_t batchSize)
: graph_(g),
eta_(eta),
xData_(xData),
numFeatures_(numFeatures),
yData_(yData),
numClasses_(numClasses),
epochs_(epochs),
maxBatchSize_(batchSize)
{}
void SGD::Run()
{
size_t numExamples = xData_.size()/ numFeatures_;
Tensor xt({(int)maxBatchSize_, (int)numExamples}, 0.0f);
Tensor yt({(int)maxBatchSize_, (int)numClasses_}, 0.0f);
vector<size_t> shuffle = CreateShuffle(numExamples);
//vector<size_t> shuffle;
for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
std::cerr << "Starting epoch #" << numEpoch << std::endl;
size_t startId = 0;
while (startId < numExamples) {
size_t batchSize = std::min(maxBatchSize_, numExamples - startId);
size_t endId = startId + batchSize;
PrepareBatch(startId, endId, batchSize, shuffle, xt, yt);
graph_["x"] = xt;
graph_["y"] = yt;
graph_.forward(maxBatchSize_);
graph_.backward();
UpdateModel();
startId += maxBatchSize_;
}
}
}
std::vector<size_t> SGD::CreateShuffle(size_t numExamples) const {
vector<size_t> ret(numExamples);
std::iota(ret.begin(), ret.end(), 0);
std::random_shuffle ( ret.begin(), ret.end() );
/*
cerr << "shuffled" << endl;
for (size_t i = 0; i < ret.size(); ++i) {
cerr << ret[i] << " ";
}
*/
return ret;
}
void SGD::PrepareBatch(
size_t startId,
size_t endId,
size_t batchSize,
const std::vector<size_t> &shuffle,
Tensor& xt,
Tensor& yt) {
/*
std::vector<float> x(xData_.begin() + startId * numFeatures_,
xData_.begin() + endId * numFeatures_);
std::vector<float> y(yData_.begin() + startId * numClasses_,
yData_.begin() + endId * numClasses_);
*/
std::vector<float> x(batchSize * numFeatures_);
std::vector<float> y(batchSize * numClasses_);
//cerr << "batchSize=" << batchSize << endl;
/*
cerr << "startId=" << startId
<< " " << endId
<< " " << batchSize
<< endl;
cerr << "numExamples=" << shuffle.size() << endl;
cerr << "numFeatures_=" << numFeatures_ << " " << numClasses_ << endl;
cerr << "sizes=" << x.size()
<< " " << y.size()
<< " " << xData_.size()
<< " " << yData_.size()
<< endl;
*/
size_t startXId = 0;
size_t startYId = 0;
for (size_t i = startId; i < endId; ++i) {
size_t ind = shuffle[i];
size_t startXDataId = ind * numFeatures_;
size_t startYDataId = ind * numClasses_;
size_t endXDataId = startXDataId + numFeatures_;
size_t endYDataId = startYDataId + numClasses_;
/*
cerr << "i=" << i
<< " " << ind
<< " " << startXDataId << "-" << endXDataId
<< " " << startYDataId << "-" << endYDataId
<< endl;
*/
std::copy(xData_.begin() + startXDataId,
xData_.begin() + endXDataId,
x.begin() + startXId);
std::copy(yData_.begin() + startYDataId,
yData_.begin() + endYDataId,
y.begin() + startYId);
startXId += numFeatures_;
startYId += numClasses_;
}
xt.set(x);
yt.set(y);
}
void SGD::UpdateModel() {
for (auto& param : graph_.params()) {
using namespace thrust::placeholders;
Element(_1 -= eta_ * _2, param.val(), param.grad());
}
}
} // namespace

View File

@ -1,43 +1,48 @@
#pragma once
#include <memory>
#include <iostream>
#include "expression_graph.h"
#include "thrust_functions.h"
#include <map>
#include <boost/any.hpp>
#include "tensor_operators.h"
namespace marian {
class SGD {
class Sgd {
public:
SGD(ExpressionGraph& g, float eta,
std::vector<float>& xData, size_t numFeatures,
std::vector<float>& yData, size_t numClasses,
size_t epochs, size_t batchSize);
void Run();
Sgd(float eta=0.1) : eta_(eta) {}
void operator()(ExpressionGraph& graph, int batchSize) {
graph.backprop(batchSize);
for(auto& param : graph.params())
Element(_1 -= eta_ * _2, param.val(), param.grad());
}
private:
ExpressionGraph& graph_;
const float eta_;
std::vector<float>& xData_;
const size_t numFeatures_;
std::vector<float>& yData_;
const size_t numClasses_;
const size_t epochs_;
const size_t maxBatchSize_;
std::vector<size_t> CreateShuffle(size_t numExamples) const;
void PrepareBatch(
size_t startId,
size_t endId,
size_t batchSize,
const std::vector<size_t> &shuffle,
Tensor& xt,
Tensor& yt);
void UpdateModel();
float eta_;
};
} // namespace marian
class Adagrad {
public:
Adagrad(float eta=0.1) : eta_(eta) {}
void operator()(ExpressionGraph& graph, int batchSize) {
if(history_.size() < graph.params().size())
for(auto& param : graph.params())
history_.emplace_back(Tensor(param.grad().shape(), 0));
graph.backprop(batchSize);
auto it = history_.begin();
for(auto& param : graph.params()) {
Element(_1 -= eta_ / Sqrt(_2) * _3, param.val(), *it, param.grad());
Element(_1 += _2 * _2, *it, param.grad());
it++;
}
}
private:
float eta_;
std::vector<Tensor> history_;
};
}

View File

@ -1,8 +1,6 @@
#include <fstream>
#include "tensor.h"
using namespace std;
namespace marian {
void Tensor::set(const std::vector<float>& data)

View File

@ -21,7 +21,7 @@ int main(int argc, char** argv) {
std::vector<Expr> Y;
std::vector<Expr> H;
ExpressionGraph g(0);
ExpressionGraph g;
for (int t = 0; t < num_inputs; ++t) {
X.emplace_back(g.input(shape={batch_size, input_size}));
@ -39,10 +39,9 @@ int main(int argc, char** argv) {
string sourceLine, targetLine;
while (getline(sourceFile, sourceLine)) {
getline(targetFile, targetLine);
std::vector<size_t> sourceIds = sourceVocab.ProcessSentence(sourceLine);
std::vector<size_t> targetIds = sourceVocab.ProcessSentence(targetLine);
getline(targetFile, targetLine);
std::vector<size_t> sourceIds = sourceVocab.ProcessSentence(sourceLine);
std::vector<size_t> targetIds = sourceVocab.ProcessSentence(targetLine);
}
std::cerr << "Building RNN..." << std::endl;

View File

@ -85,6 +85,19 @@ namespace thrust
return compose(unary_operator<unary_tanh>(), _1);
}
template<typename T>
struct unary_sqrt : public thrust::unary_function<T,T> {
__host__ __device__
T operator()(const T &x) const { return sqrtf(x); }
};
template<typename Eval>
__host__ __device__
actor<composite<unary_operator<unary_sqrt>, actor<Eval>>>
Sqrt(const actor<Eval> &_1) {
return compose(unary_operator<unary_sqrt>(), _1);
}
template<typename T1, typename T2>
__host__ __device__
actor<composite<binary_operator<thrust::maximum>, actor<T1>, actor<T2>>>

View File

@ -11,12 +11,12 @@ int main(int argc, char** argv) {
int numofdata;
vector<float> trainImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
vector<float>trainLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
vector<float> trainLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
using namespace marian;
using namespace keywords;
ExpressionGraph g(0);
ExpressionGraph g;
Expr x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
Expr y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
@ -24,16 +24,13 @@ int main(int argc, char** argv) {
Expr w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE}), "w");
Expr b = named(g.param(shape={1, LABEL_SIZE}), "b");
std::vector<Expr*> params;
params.push_back(&w);
params.push_back(&b);
auto scores = dot(x, w) + b;
auto lr = softmax_fast(scores);
auto cost = named(-mean(sum(y * log(lr), axis=1), axis=0), "cost");
cerr << "lr=" << lr.Debug() << endl;
SGD opt(g, 0.9, trainImages, IMAGE_SIZE, trainLabels, LABEL_SIZE, 3, 24);
opt.Run();
Adagrad opt;
opt(g, 300);
return 0;
}

View File

@ -32,7 +32,7 @@ int main(int argc, char** argv) {
int num_inputs = 8;
int num_outputs = 6;
ExpressionGraph g(0);
ExpressionGraph g;
std::vector<Expr*> X(num_inputs+1); // For the stop symbol.
std::vector<Expr*> Y(num_outputs);
std::vector<Expr*> H(num_inputs+1); // For the stop symbol.

View File

@ -10,7 +10,7 @@ const size_t IMAGE_SIZE = 784;
const size_t LABEL_SIZE = 10;
int BATCH_SIZE = 10000;
ExpressionGraph build_graph(int cudaDevice) {
ExpressionGraph build_graph() {
std::cerr << "Loading model params...";
NpzConverter converter("../scripts/test_model_single/model.npz");
@ -22,7 +22,7 @@ ExpressionGraph build_graph(int cudaDevice) {
std::cerr << "Building model...";
ExpressionGraph g(cudaDevice);
ExpressionGraph g;
auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
@ -52,7 +52,7 @@ int main(int argc, char** argv) {
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
std::cerr << "Done." << std::endl;
ExpressionGraph g = build_graph(0);
ExpressionGraph g = build_graph();
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
Tensor yt({BATCH_SIZE, LABEL_SIZE});

View File

@ -56,8 +56,7 @@ int main(int argc, char** argv) {
std::cerr << "\tDone." << std::endl;
ExpressionGraph g(0);
ExpressionGraph g;
auto x = g.input(shape={whatevs, IMAGE_SIZE}, name="X");
auto y = g.input(shape={whatevs, LABEL_SIZE}, name="Y");