Andre Martins 2016-09-16 18:36:30 +01:00
commit 8f1a7382ec
20 changed files with 2508 additions and 239 deletions

View File

@@ -20,3 +20,14 @@ endif(Boost_FOUND)
include_directories(${marian_SOURCE_DIR}/src)
add_subdirectory(src)
# add a target to generate API documentation with Doxygen
find_package(Doxygen)
if(DOXYGEN_FOUND)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/Doxyfile.in ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile @ONLY)
add_custom_target(doc
${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
COMMENT "Generating API documentation with Doxygen" VERBATIM
)
endif(DOXYGEN_FOUND)

Doxyfile.in (new file, 2303 lines)

File diff suppressed because it is too large.

View File

@@ -1,5 +1,3 @@
The MIT License (MIT)
Copyright (c) 2016 Marcin Junczys-Dowmunt
Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -9,8 +7,8 @@ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,

View File

@@ -29,3 +29,6 @@ Compilation with `cmake > 3.5`:
cmake ..
make -j
To build the API documentation with Doxygen, first `cd` into the build directory, then run:
make doc

View File

@@ -5,7 +5,6 @@ cuda_add_library(marian_lib
cnpy/cnpy.cpp
exception.cpp
expression_graph.cu
sgd.cu
tensor.cu
tensor_operators.cu
expression_operators.cu

View File

@@ -1,8 +1,6 @@
#include <sstream>
#include "expression_graph.h"
using namespace std;
namespace marian {
Expr::Expr(ExpressionGraphPtr g, Chainable<Tensor>* chainable)
@@ -32,19 +30,10 @@ Expr::operator ChainPtr() {
std::string Expr::Debug() const
{
stringstream strm;
std::stringstream strm;
const Shape &shape = pimpl_->shape();
strm << marian::Debug(shape);
return strm.str();
}
///////////////////////////////////////////////////////
ExpressionGraph::ExpressionGraph(int cudaDevice)
: stack_(new ChainableStack)
{
std::srand (time(NULL));
cudaSetDevice(0);
}
}

View File

@@ -38,9 +38,14 @@ class Expr {
class ExpressionGraph {
public:
ExpressionGraph(int cudaDevice);
ExpressionGraph() : stack_(new ChainableStack) {}
void forward(size_t batchSize) {
void backprop(int batchSize) {
forward(batchSize);
backward();
}
void forward(int batchSize) {
for(auto&& v : *stack_) {
v->allocate(batchSize);
}
@@ -48,6 +53,16 @@ class ExpressionGraph {
v->forward();
}
void backward() {
for(auto&& v : *stack_)
v->set_zero_adjoint();
typedef typename ChainableStack::reverse_iterator It;
stack_->back()->init_dependent();
for(It it = stack_->rbegin(); it != stack_->rend(); ++it)
(*it)->backward();
}
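This is the standard reverse-mode sweep over the tape: `set_zero_adjoint` clears every adjoint, `init_dependent` seeds the final node (the cost) with adjoint 1, and the reverse iteration lets each node $v_i$ push its adjoint onto each of its inputs $v_j$,

$$\bar v_j \mathrel{+}= \bar v_i \,\frac{\partial v_i}{\partial v_j},$$

so every node is processed only after all of the nodes that consume it.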
std::string graphviz() {
std::stringstream ss;
ss << "digraph ExpressionGraph {" << std::endl;
@@ -60,19 +75,13 @@ class ExpressionGraph {
return ss.str();
}
void backward() {
for(auto&& v : *stack_)
v->set_zero_adjoint();
typedef typename ChainableStack::reverse_iterator It;
stack_->back()->init_dependent();
for(It it = stack_->rbegin(); it != stack_->rend(); ++it)
(*it)->backward();
}
/*********************************************************/
template <typename ...Args>
inline Expr input(Args ...args) {
return Expr(this, new InputNode(args...));
Expr e(this, new InputNode(args...));
inputs_.emplace_back(e);
return e;
}
template <typename ...Args>
@@ -117,14 +126,20 @@ class ExpressionGraph {
named_.emplace(name, e);
}
std::vector<Expr>& inputs() {
return inputs_;
}
std::vector<Expr>& params() {
return params_;
}
private:
ChainableStackPtr stack_;
std::map<std::string, Expr> named_;
std::vector<Expr> params_;
std::vector<Expr> inputs_;
};
}
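Taken together with the relocated `backward()`, the new `backprop(int batchSize)` is simply `forward(batchSize)` followed by `backward()`. A minimal sketch of the intended call pattern, assuming the keyword-argument style (`shape=`, `axis=`) and the `named`/`operator[]` helpers used in the training examples later in this commit; the sizes here are hypothetical:

```cpp
using namespace marian;
using namespace keywords;

ExpressionGraph g;
auto x    = named(g.input(shape={whatevs, 4}), "x");
auto w    = named(g.param(shape={4, 2}), "w");
auto cost = named(-mean(sum(dot(x, w), axis=1), axis=0), "cost");

Tensor xt({10, 4});  // hypothetical batch of 10 rows
g["x"] = xt;         // bind the input tensor
g.backprop(10);      // equivalent to g.forward(10); g.backward();
```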

View File

@@ -29,7 +29,7 @@ Expr operator-(Expr a) {
return Expr(a.graph(), new NegNodeOp(a));
};
Expr softmax_fast(Expr a) {
Expr softmax(Expr a) {
return Expr(a.graph(), new SoftmaxNodeOp(a));
}

View File

@@ -72,12 +72,12 @@ inline Expr sum(Expr a, Args ...args) {
// inefficient
template <typename ...Args>
Expr softmax(Expr a, Args ...args) {
Expr softmax_slow(Expr a, Args ...args) {
Expr e = exp(a);
return e / sum(e, args...);
}
Expr softmax_fast(Expr a);
Expr softmax(Expr a);
// inefficient
template <typename ...Args>
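The renames make the fused implementation the default: `softmax_slow` above is the literal definition

$$\operatorname{softmax}(a)_i = \frac{e^{a_i}}{\sum_j e^{a_j}},$$

composed from `exp`, `sum`, and division nodes, while the `softmax` declared here (the former `softmax_fast`) maps to the single `SoftmaxNodeOp` whose `gSoftMax` kernel lives in `tensor_operators.cu`.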

View File

@@ -1,140 +0,0 @@
#include <ctime>
#include <algorithm>
#include <vector>
#include "sgd.h"
#include "thrust_functions.h"
using namespace std;
namespace marian {
SGD::SGD(ExpressionGraph& g, float eta,
std::vector<float>& xData, size_t numFeatures,
std::vector<float>& yData, size_t numClasses,
size_t epochs, size_t batchSize)
: graph_(g),
eta_(eta),
xData_(xData),
numFeatures_(numFeatures),
yData_(yData),
numClasses_(numClasses),
epochs_(epochs),
maxBatchSize_(batchSize)
{}
void SGD::Run()
{
size_t numExamples = xData_.size()/ numFeatures_;
Tensor xt({(int)maxBatchSize_, (int)numExamples}, 0.0f);
Tensor yt({(int)maxBatchSize_, (int)numClasses_}, 0.0f);
vector<size_t> shuffle = CreateShuffle(numExamples);
//vector<size_t> shuffle;
for (size_t numEpoch = 0; numEpoch < epochs_; ++numEpoch) {
std::cerr << "Starting epoch #" << numEpoch << std::endl;
size_t startId = 0;
while (startId < numExamples) {
size_t batchSize = std::min(maxBatchSize_, numExamples - startId);
size_t endId = startId + batchSize;
PrepareBatch(startId, endId, batchSize, shuffle, xt, yt);
graph_["x"] = xt;
graph_["y"] = yt;
graph_.forward(maxBatchSize_);
graph_.backward();
UpdateModel();
startId += maxBatchSize_;
}
}
}
std::vector<size_t> SGD::CreateShuffle(size_t numExamples) const {
vector<size_t> ret(numExamples);
std::iota(ret.begin(), ret.end(), 0);
std::random_shuffle ( ret.begin(), ret.end() );
/*
cerr << "shuffled" << endl;
for (size_t i = 0; i < ret.size(); ++i) {
cerr << ret[i] << " ";
}
*/
return ret;
}
void SGD::PrepareBatch(
size_t startId,
size_t endId,
size_t batchSize,
const std::vector<size_t> &shuffle,
Tensor& xt,
Tensor& yt) {
/*
std::vector<float> x(xData_.begin() + startId * numFeatures_,
xData_.begin() + endId * numFeatures_);
std::vector<float> y(yData_.begin() + startId * numClasses_,
yData_.begin() + endId * numClasses_);
*/
std::vector<float> x(batchSize * numFeatures_);
std::vector<float> y(batchSize * numClasses_);
//cerr << "batchSize=" << batchSize << endl;
/*
cerr << "startId=" << startId
<< " " << endId
<< " " << batchSize
<< endl;
cerr << "numExamples=" << shuffle.size() << endl;
cerr << "numFeatures_=" << numFeatures_ << " " << numClasses_ << endl;
cerr << "sizes=" << x.size()
<< " " << y.size()
<< " " << xData_.size()
<< " " << yData_.size()
<< endl;
*/
size_t startXId = 0;
size_t startYId = 0;
for (size_t i = startId; i < endId; ++i) {
size_t ind = shuffle[i];
size_t startXDataId = ind * numFeatures_;
size_t startYDataId = ind * numClasses_;
size_t endXDataId = startXDataId + numFeatures_;
size_t endYDataId = startYDataId + numClasses_;
/*
cerr << "i=" << i
<< " " << ind
<< " " << startXDataId << "-" << endXDataId
<< " " << startYDataId << "-" << endYDataId
<< endl;
*/
std::copy(xData_.begin() + startXDataId,
xData_.begin() + endXDataId,
x.begin() + startXId);
std::copy(yData_.begin() + startYDataId,
yData_.begin() + endYDataId,
y.begin() + startYId);
startXId += numFeatures_;
startYId += numClasses_;
}
xt.set(x);
yt.set(y);
}
void SGD::UpdateModel() {
for (auto& param : graph_.params()) {
using namespace thrust::placeholders;
Element(_1 -= eta_ * _2, param.val(), param.grad());
}
}
} // namespace

View File

@@ -1,43 +1,51 @@
#pragma once
#include <memory>
#include <iostream>
#include "expression_graph.h"
#include "thrust_functions.h"
#include <map>
#include <boost/any.hpp>
#include "tensor_operators.h"
namespace marian {
class SGD {
class Sgd {
public:
SGD(ExpressionGraph& g, float eta,
std::vector<float>& xData, size_t numFeatures,
std::vector<float>& yData, size_t numClasses,
size_t epochs, size_t batchSize);
void Run();
Sgd(float eta=0.1) : eta_(eta) {}
void operator()(ExpressionGraph& graph, int batchSize) {
graph.backprop(batchSize);
for(auto& param : graph.params())
Element(_1 -= eta_ * _2,
param.val(), param.grad());
}
private:
ExpressionGraph& graph_;
const float eta_;
std::vector<float>& xData_;
const size_t numFeatures_;
std::vector<float>& yData_;
const size_t numClasses_;
const size_t epochs_;
const size_t maxBatchSize_;
std::vector<size_t> CreateShuffle(size_t numExamples) const;
void PrepareBatch(
size_t startId,
size_t endId,
size_t batchSize,
const std::vector<size_t> &shuffle,
Tensor& xt,
Tensor& yt);
void UpdateModel();
float eta_;
};
} // namespace marian
class Adagrad {
public:
Adagrad(float eta=0.1) : eta_(eta) {}
void operator()(ExpressionGraph& graph, int batchSize) {
float fudgeFactor = 1e-6;
graph.backprop(batchSize);
if(history_.size() < graph.params().size())
for(auto& param : graph.params())
history_.emplace_back(Tensor(param.grad().shape(), 0));
auto it = history_.begin();
for(auto& param : graph.params()) {
Element(_1 += _2 * _2, *it, param.grad());
Element(_1 -= eta_ / (fudgeFactor + Sqrt(_2)) * _3,
param.val(), *it, param.grad());
it++;
}
}
private:
float eta_;
std::vector<Tensor> history_;
};
}
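In update-rule form, writing $g$ for a parameter's gradient, $h$ for its accumulated squared gradient, and $\epsilon = 10^{-6}$ for the code's `fudgeFactor`: `Sgd` applies $\theta \leftarrow \theta - \eta\, g$, while the `Adagrad` loop above applies, elementwise,

$$h \leftarrow h + g^2, \qquad \theta \leftarrow \theta - \frac{\eta}{\epsilon + \sqrt{h}}\, g,$$

using the `Sqrt` actor added to `thrust_functions.h` in this commit.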

View File

@@ -1,8 +1,6 @@
#include <fstream>
#include "tensor.h"
using namespace std;
namespace marian {
void Tensor::set(const std::vector<float>& data)

View File

@@ -1,5 +1,7 @@
#include "tensor_operators.h"
using namespace std;
namespace marian {
__global__ void gSubtractMean(float* out, float* weights,
@@ -53,6 +55,7 @@ void SubtractMean(Tensor* Out, Tensor &Weights) {
cudaStreamSynchronize(0);
}
///////////////////////////////////////////////////////
__global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
@@ -97,6 +100,35 @@ void Softmax(Tensor* Out) {
gSoftMax<<<blocks, threads, shared>>>(Out->data(), m, k);
cudaStreamSynchronize(0);
}
///////////////////////////////////////////////////////
__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols) {
size_t row = blockIdx.x;
size_t startInd = row * cols;
float maxScore = -99999;
size_t maxInd;
for (size_t col = 0; col < cols; ++col) {
size_t ind = startInd + col;
float score = data[ind];
if (score > maxScore) {
maxScore = score;
maxInd = col;
}
}
out[row] = maxInd;
}
void Argmax(Tensor* Out, const Tensor* In) {
size_t m = In->shape()[0];
size_t k = In->shape()[1];
int blocks = m; //std::min(MAX_BLOCKS, (int) m);
int threads = k; //std::min(MAX_THREADS, (int) k);
//int shared = sizeof(float) * threads * 2;
gArgMax<<<blocks, threads>>>(Out->data(), In->data(), m, k);
cudaStreamSynchronize(0);
}
///////////////////////////////////////////////////////
Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
bool transA, bool transB, Float beta) {
@@ -137,4 +169,4 @@ Tensor Prod(Tensor C, const Tensor A, const Tensor B,
return temp;
}
}
}
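Two notes on the new argmax, plus a host-side sketch. The launch sets `threads = k`, but `gArgMax` never reads `threadIdx`, so all `k` threads in a block redundantly compute the same row maximum (one thread per block suffices, and that is what `testArgMax` in `test.cu` later in this commit launches); also, the `-99999` sentinel assumes no score ever falls below it. A minimal usage sketch with the same data as `testArgMax`, assuming `In` is rows by columns and `Out` is preallocated with one slot per row:

```cpp
using namespace marian;

Tensor scores({4, 2});                         // 4 rows, 2 columns
scores.set({29, 19, 49, 39, 79, 99, 79, 39});  // row-major scores
Tensor best({4, 1});                           // one argmax index per row, stored as float
Argmax(&best, &scores);                        // best now holds {0, 0, 1, 0}
```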

View File

@@ -151,6 +151,10 @@ __global__ void gSoftMax(float* softMaxP, size_t rows, size_t cols);
void Softmax(Tensor* Out);
__global__ void gArgMax(float *out, const float *data, size_t rows, size_t cols);
void Argmax(Tensor* Out, const Tensor* In);
Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
bool transA, bool transB, Float beta);

View File

@@ -3,7 +3,51 @@
#include "mnist.h"
#include "vocab.h"
#include "tensor_operators.h"
using namespace std;
///////////////////////////////////////////////////////
string output(const std::vector<float> &vec)
{
stringstream strm;
for (size_t i = 0; i < vec.size(); ++i) {
strm << vec[i] << " ";
}
return strm.str();
}
void testArgMax()
{
using namespace std;
using namespace marian;
std::vector<float> hVec({29,19, 49,39, 79,99, 79,39});
cerr << "hVec =" << output(hVec) << endl;
thrust::device_vector<float> dVec(8);
thrust::copy(hVec.begin(), hVec.end(), dVec.begin());
float *data = thrust::raw_pointer_cast(dVec.data());
thrust::device_vector<float> dLabel(4);
float *labelPtr = thrust::raw_pointer_cast(dLabel.data());
gArgMax<<<4, 1, sizeof(float)>>>(labelPtr, data, 4, 2);
std::vector<float> hVec2(8);
thrust::copy(dVec.begin(), dVec.end(), hVec2.begin());
cerr << "hVec2=" << output(hVec2) << endl;
std::vector<float> hLabel(4);
thrust::copy(dLabel.begin(), dLabel.end(), hLabel.begin());
cerr << "hLabel=" << output(hLabel) << endl;
exit(0);
}
///////////////////////////////////////////////////////
int main(int argc, char** argv) {
//testArgMax();
using namespace std;
using namespace marian;
@@ -21,7 +65,7 @@ int main(int argc, char** argv) {
std::vector<Expr> Y;
std::vector<Expr> H;
ExpressionGraph g(0);
ExpressionGraph g;
for (int t = 0; t < num_inputs; ++t) {
X.emplace_back(g.input(shape={batch_size, input_size}));
@@ -39,10 +83,9 @@ int main(int argc, char** argv) {
string sourceLine, targetLine;
while (getline(sourceFile, sourceLine)) {
getline(targetFile, targetLine);
std::vector<size_t> sourceIds = sourceVocab.ProcessSentence(sourceLine);
std::vector<size_t> targetIds = sourceVocab.ProcessSentence(targetLine);
getline(targetFile, targetLine);
std::vector<size_t> sourceIds = sourceVocab.ProcessSentence(sourceLine);
std::vector<size_t> targetIds = sourceVocab.ProcessSentence(targetLine);
}
std::cerr << "Building RNN..." << std::endl;
@@ -57,10 +100,10 @@ int main(int argc, char** argv) {
std::cerr << "Building output layer..." << std::endl;
std::vector<Expr> Yp;
Yp.emplace_back(softmax_fast(dot(H[0], Why) + by));
Yp.emplace_back(softmax(dot(H[0], Why) + by));
Expr cross_entropy = sum(Y[0] * log(Yp[0]), axis=1);
for (int t = 1; t < num_inputs; ++t) {
Yp.emplace_back(softmax_fast(dot(H[t], Why) + by));
Yp.emplace_back(softmax(dot(H[t], Why) + by));
cross_entropy = cross_entropy + sum(Y[t] * log(Yp[t]), axis=1);
}
auto graph = -mean(cross_entropy, axis=0, name="cost");
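Spelled out, the cost assembled above is the batch-averaged cross-entropy, summed over time steps and output classes:

$$\mathrm{cost} = -\frac{1}{N}\sum_{n=1}^{N}\sum_{t}\sum_{i} Y^{(n)}_{t,i}\,\log \hat{Y}^{(n)}_{t,i},$$

where the sum over $i$ is `sum(..., axis=1)`, the sum over $t$ is the explicit loop, and the $1/N$ average is `mean(..., axis=0)`.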

View File

@@ -85,6 +85,19 @@ namespace thrust
return compose(unary_operator<unary_tanh>(), _1);
}
template<typename T>
struct unary_sqrt : public thrust::unary_function<T,T> {
__host__ __device__
T operator()(const T &x) const { return sqrtf(x); }
};
template<typename Eval>
__host__ __device__
actor<composite<unary_operator<unary_sqrt>, actor<Eval>>>
Sqrt(const actor<Eval> &_1) {
return compose(unary_operator<unary_sqrt>(), _1);
}
template<typename T1, typename T2>
__host__ __device__
actor<composite<binary_operator<thrust::maximum>, actor<T1>, actor<T2>>>

View File

@@ -11,12 +11,12 @@ int main(int argc, char** argv) {
int numofdata;
vector<float> trainImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
vector<float>trainLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
vector<float> trainLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
using namespace marian;
using namespace keywords;
ExpressionGraph g(0);
ExpressionGraph g;
Expr x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
Expr y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
@@ -24,16 +24,13 @@ int main(int argc, char** argv) {
Expr w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE}), "w");
Expr b = named(g.param(shape={1, LABEL_SIZE}), "b");
std::vector<Expr*> params;
params.push_back(&w);
params.push_back(&b);
auto scores = dot(x, w) + b;
auto lr = softmax_fast(scores);
auto lr = softmax(scores);
auto cost = named(-mean(sum(y * log(lr), axis=1), axis=0), "cost");
cerr << "lr=" << lr.Debug() << endl;
SGD opt(g, 0.9, trainImages, IMAGE_SIZE, trainLabels, LABEL_SIZE, 3, 24);
opt.Run();
Adagrad opt;
opt(g, 300);
return 0;
}

View File

@@ -7,8 +7,7 @@
using namespace marian;
using namespace keywords;
ExpressionGraph build_graph(int cuda_device,
int source_vocabulary_size,
ExpressionGraph build_graph(int source_vocabulary_size,
int target_vocabulary_size,
int embedding_size,
int hidden_size,
@@ -21,7 +20,7 @@ ExpressionGraph build_graph(int cuda_device,
int num_inputs = num_source_tokens;
int num_outputs = num_target_tokens;
ExpressionGraph g(cuda_device);
ExpressionGraph g;
std::vector<Expr> X, Y, H, S;
// We're including the stop symbol here.
@@ -83,10 +82,10 @@ ExpressionGraph build_graph(int cuda_device,
// Softmax layer and cost function.
std::vector<Expr> Yp;
Yp.emplace_back(named(softmax_fast(dot(h0_d, Why) + by), "pred"));
Yp.emplace_back(named(softmax(dot(h0_d, Why) + by), "pred"));
Expr cross_entropy = sum(Y[0] * log(Yp[0]), axis=1);
for (int t = 1; t <= num_outputs; ++t) {
Yp.emplace_back(named(softmax_fast(dot(S[t-1], Why) + by), "pred"));
Yp.emplace_back(named(softmax(dot(S[t-1], Why) + by), "pred"));
cross_entropy = cross_entropy + sum(Y[t] * log(Yp[t]), axis=1);
}
auto cost = named(-mean(cross_entropy, axis=0), "cost");
@@ -153,8 +152,7 @@ int main(int argc, char** argv) {
// Build the encoder-decoder computation graph.
int embedding_size = 50;
int hidden_size = 100;
ExpressionGraph g = build_graph(0, // cuda device.
source_vocab.Size(),
ExpressionGraph g = build_graph(source_vocab.Size(),
target_vocab.Size(),
embedding_size,
hidden_size,
@@ -253,7 +251,6 @@ int main(int argc, char** argv) {
ss << "Y" << t;
g[ss.str()] = Yt;
}
#endif
std::cerr << "Printing the computation graph..." << std::endl;

View File

@@ -10,7 +10,7 @@ const size_t IMAGE_SIZE = 784;
const size_t LABEL_SIZE = 10;
int BATCH_SIZE = 10000;
ExpressionGraph build_graph(int cudaDevice) {
ExpressionGraph build_graph() {
std::cerr << "Loading model params...";
NpzConverter converter("../scripts/test_model_single/model.npz");
@@ -22,7 +22,7 @@ ExpressionGraph build_graph(int cudaDevice) {
std::cerr << "Building model...";
ExpressionGraph g(cudaDevice);
ExpressionGraph g;
auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
@@ -32,7 +32,7 @@ ExpressionGraph build_graph(int cudaDevice) {
init=from_vector(bData)), "b");
auto probs = named(
softmax_fast(dot(x, w) + b), //, axis=1),
softmax(dot(x, w) + b), //, axis=1),
"probs"
);
@@ -52,7 +52,7 @@ int main(int argc, char** argv) {
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
std::cerr << "Done." << std::endl;
ExpressionGraph g = build_graph(0);
ExpressionGraph g = build_graph();
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
Tensor yt({BATCH_SIZE, LABEL_SIZE});

View File

@@ -56,8 +56,7 @@ int main(int argc, char** argv) {
std::cerr << "\tDone." << std::endl;
ExpressionGraph g(0);
ExpressionGraph g;
auto x = g.input(shape={whatevs, IMAGE_SIZE}, name="X");
auto y = g.input(shape={whatevs, LABEL_SIZE}, name="Y");
@@ -69,7 +68,7 @@ int main(int argc, char** argv) {
std::cerr << "Building model...";
auto layer1 = tanh(dot(x, w1) + b1);
auto layer2 = softmax(dot(layer1, w2) + b2, axis=1, name="layer2");
auto layer2 = softmax(dot(layer1, w2) + b2);
auto predict = layer2;
std::cerr << "Done." << std::endl;