Hieu Hoang 2016-09-15 17:08:53 +01:00
commit 5d114049cf
7 changed files with 150 additions and 117 deletions

View File

@@ -11,14 +11,15 @@ Installation
Requirements:
* g++ with C++11
* g++ with c++11
* CUDA and CuDNN
* Boost (>= 1.56)
Exporting some paths for CuDNN may be required (add them, for example, to your `.bashrc` file):
export PATH=$PATH:$HOME/.local/bin:/usr/local/cuda/bin
export LIBRARY_PATH=$LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cudnn-5/lib64
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cudnn-5/lib64
export LIBRARY_PATH=$LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:/usr/local/cudnn-5/lib64
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:/usr/local/cudnn-5/lib64
export CPATH=$CPATH:/usr/local/cudnn-5/include
Compilation with `cmake > 3.5`:
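A typical out-of-source build (a sketch; the exact targets depend on the project's `CMakeLists.txt`):
mkdir -p build && cd build
cmake ..
make -j4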

View File

@@ -153,20 +153,6 @@ inline Expr sum(Expr a, Args ...args) {
template <typename ...Args>
inline Expr softmax(Expr a, Args ...args) {
Expr e = exp(a);
#if 0
ChainPtr n = a.node();
auto print_shape = [n]() -> Shape {
std::cerr << "Shape: ";
for (auto val : n->val().shape()) {
std::cerr << val << " ";
}
std::cerr << std::endl;
return {1,1};
};
using namespace keywords;
Expr one = ones(shape={1, 1}, lazy_shape=print_shape);
#endif
return e / sum(e, args...);
}
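Note that computing softmax directly as `exp(a) / sum(exp(a))`, as above, can overflow for large activations. A common remedy is to subtract the row maximum before exponentiating; softmax is shift-invariant, so the result is unchanged. A standalone C++ sketch for a single row (plain vectors, not the `Expr` graph API):
#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax over one row of activations.
// Subtracting the maximum keeps every exp() argument <= 0,
// avoiding overflow, and leaves the result unchanged.
std::vector<float> softmax_stable(const std::vector<float>& a) {
  float m = *std::max_element(a.begin(), a.end());
  std::vector<float> p(a.size());
  float sum = 0.0f;
  for (size_t i = 0; i < a.size(); ++i) {
    p[i] = std::exp(a[i] - m);
    sum += p[i];
  }
  for (float& v : p) v /= sum;
  return p;
}
The `softmax_fast` that this commit switches to elsewhere presumably fuses these steps into a single kernel, though its implementation is not part of this diff.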

View File

@@ -162,11 +162,10 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
// For each row, the Jacobian times vector is given by:
// J * dy = p .* (dy - avg*1)
// where avg = p'*dy and p is the softmax output (probabilities).
Tensor result = adj_;
Tensor result(adj_.shape());
thrust::copy(adj_.begin(), adj_.end(), result.begin());
SubtractMean(&result, val_);
// beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
// to sum gradients from different graph parts.
Prod(a_->grad(), adj_, result, false, false, 1.0);
Element(_1 += _2 * _3, a_->grad(), val_, result);
}
};
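To make the comment above concrete: for one row with softmax output p and incoming gradient dy, the Jacobian is J = diag(p) - p*p', so J*dy = p .* (dy - avg) with avg = p'*dy. A minimal CPU sketch in plain C++ (for illustration only, independent of the Tensor/thrust machinery above):
#include <vector>

// One row of the softmax backward pass: grad = p .* (dy - avg),
// where avg = p' * dy. This is exactly J * dy for
// J = diag(p) - p * p'.
std::vector<float> softmax_grad_row(const std::vector<float>& p,
                                    const std::vector<float>& dy) {
  float avg = 0.0f;
  for (size_t i = 0; i < p.size(); ++i) avg += p[i] * dy[i];
  std::vector<float> grad(p.size());
  for (size_t i = 0; i < p.size(); ++i) grad[i] = p[i] * (dy[i] - avg);
  return grad;
}
The new `Element(_1 += _2 * _3, ...)` form accumulates into `a_->grad()` (note the `+=`), matching the `beta = 1.0` convention described in the old gemm comment: gradients arriving from different parts of the graph are summed rather than overwritten.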

View File

@@ -10,25 +10,42 @@
namespace marian {
void zeros(Tensor t) {
std::vector<float> vals(t.size(), 0.0f);
thrust::copy(vals.begin(), vals.end(), t.begin());
t.set(0.f);
}
void ones(Tensor t) {
std::vector<float> vals(t.size(), 1.0f);
thrust::copy(vals.begin(), vals.end(), t.begin());
t.set(1.0f);
}
void randreal(Tensor t) {
template <class Distribution>
void distribution(Tensor t, float a=0.0, float b=0.1) {
std::random_device device;
std::default_random_engine engine(device());
std::uniform_real_distribution<> dist(0, 0.01);
Distribution dist(a, b);
auto gen = std::bind(dist, engine);
std::vector<float> vals(t.size());
std::generate(begin(vals), end(vals), gen);
thrust::copy(vals.begin(), vals.end(), t.begin());
t << vals;
}
std::function<void(Tensor)> normal(float mean = 0.0, float std = 0.1) {
return [mean, std](Tensor t) {
distribution<std::normal_distribution<float>>(t, mean, std);
};
}
std::function<void(Tensor)> uniform(float a = 0.0, float b = 0.1) {
return [a, b](Tensor t) {
distribution<std::uniform_real_distribution<float>>(t, a, b);
};
}
std::function<void(Tensor)> from_vector(const std::vector<float>& v) {
return [&v](Tensor t) {
t << v;
};
}
} // namespace marian
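The refactor above replaces the ad-hoc fill routines with factories that return `std::function<void(Tensor)>`, so distribution parameters are bound at graph-construction time rather than hard-coded. A usage sketch, following the `param(..., init=...)` keyword style used later in this commit (the shapes and distribution parameters here are illustrative):
using namespace keywords;
auto w = param(shape={784, 10}, name="W0", init=normal(0.0, 0.05));
auto b = param(shape={1, 10}, name="b0", init=uniform(-0.1, 0.1));
Note that `from_vector` captures its vector by reference (`[&v]`), so the caller must keep that vector alive until the initializer actually runs.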

View File

@@ -9,9 +9,10 @@ int main(int argc, char** argv) {
/*auto images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg);*/
/*auto labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numImg);*/
#if 1
using namespace marian;
using namespace keywords;
Expr x = input(shape={1, 2});
Expr y = input(shape={1, 2});
@@ -19,7 +20,7 @@ int main(int argc, char** argv) {
//Expr b = param(shape={1, 2}, name="b0");
std::cerr << "Building model...";
auto predict = softmax(dot(x, w),
auto predict = softmax_fast(dot(x, w),
axis=1, name="pred");
auto graph = -mean(sum(y * log(predict), axis=1),
axis=0, name="cost");
@@ -41,75 +42,80 @@ int main(int argc, char** argv) {
std::cerr << graph.val().Debug() << std::endl;
std::cerr << w.grad().Debug() << std::endl;
//std::cerr << b.grad().Debug() << std::endl;
#else
// using namespace marian;
// using namespace keywords;
//
// const size_t BATCH_SIZE = 500;
// const size_t IMAGE_SIZE = 784;
// const size_t LABEL_SIZE = 10;
//
// Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X");
// Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y");
//
// Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0");
// Expr b = param(shape={1, LABEL_SIZE}, name="b0");
//
// Expr z = dot(x, w) + b;
// Expr lr = softmax(z, axis=1, name="pred");
// Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
// //cerr << "x=" << Debug(lr.val().shape()) << endl;
//
// int numofdata;
// //vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
// //vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
// vector<float> images = datasets::mnist::ReadImages("../examples/mnist/train-images-idx3-ubyte", numofdata, IMAGE_SIZE);
// vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/train-labels-idx1-ubyte", numofdata, LABEL_SIZE);
// cerr << "images=" << images.size() << " labels=" << labels.size() << endl;
// cerr << "numofdata=" << numofdata << endl;
//
// size_t startInd = 0;
// size_t startIndData = 0;
// while (startInd < numofdata) {
// size_t batchSize = (startInd + BATCH_SIZE < numofdata) ? BATCH_SIZE : numofdata - startInd;
// cerr << "startInd=" << startInd
// << " startIndData=" << startIndData
// << " batchSize=" << batchSize << endl;
//
// Tensor tx({numofdata, IMAGE_SIZE}, 1);
// Tensor ty({numofdata, LABEL_SIZE}, 1);
//
// tx.Load(images.begin() + startIndData, images.begin() + startIndData + batchSize * IMAGE_SIZE);
// ty.Load(labels.begin() + startInd, labels.begin() + startInd + batchSize);
//
// //cerr << "tx=" << Debug(tx.shape()) << endl;
// //cerr << "ty=" << Debug(ty.shape()) << endl;
//
// x = tx;
// y = ty;
//
// cerr << "x=" << Debug(x.val().shape()) << endl;
// cerr << "y=" << Debug(y.val().shape()) << endl;
//
//
// graph.forward(batchSize);
//
// cerr << "w=" << Debug(w.val().shape()) << endl;
// cerr << "b=" << Debug(b.val().shape()) << endl;
// std::cerr << "z: " << Debug(z.val().shape()) << endl;
// std::cerr << "lr: " << Debug(lr.val().shape()) << endl;
// std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ;
//
// //std::cerr << "scores=" << scores.val().Debug() << endl;
// //std::cerr << "lr=" << lr.val().Debug() << endl;
//
// //graph.backward();
//
// //std::cerr << graph["pred"].val()[0] << std::endl;
//
// startInd += batchSize;
// startIndData += batchSize * IMAGE_SIZE;
// }
using namespace marian;
using namespace keywords;
using namespace std;
const size_t BATCH_SIZE = 500;
const size_t IMAGE_SIZE = 784;
const size_t LABEL_SIZE = 10;
Expr x = input(shape={whatevs, IMAGE_SIZE}, name="X");
Expr y = input(shape={whatevs, LABEL_SIZE}, name="Y");
Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0");
Expr b = param(shape={1, LABEL_SIZE}, name="b0");
Expr z = dot(x, w) + b;
Expr lr = softmax(z, axis=1, name="pred");
Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
//cerr << "x=" << Debug(lr.val().shape()) << endl;
int numofdata;
//vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
//vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/train-images-idx3-ubyte", numofdata, IMAGE_SIZE);
vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/train-labels-idx1-ubyte", numofdata, LABEL_SIZE);
cerr << "images=" << images.size() << " labels=" << labels.size() << endl;
cerr << "numofdata=" << numofdata << endl;
size_t startInd = 0;
size_t startIndData = 0;
while (startInd < numofdata) {
size_t batchSize = (startInd + BATCH_SIZE < numofdata) ? BATCH_SIZE : numofdata - startInd;
cerr << "startInd=" << startInd
<< " startIndData=" << startIndData
<< " batchSize=" << batchSize << endl;
Tensor tx({numofdata, IMAGE_SIZE}, 1);
Tensor ty({numofdata, LABEL_SIZE}, 1);
tx.set(images.begin() + startIndData, images.begin() + startIndData + batchSize * IMAGE_SIZE);
ty.set(labels.begin() + startInd, labels.begin() + startInd + batchSize);
//cerr << "tx=" << Debug(tx.shape()) << endl;
//cerr << "ty=" << Debug(ty.shape()) << endl;
x = tx;
y = ty;
cerr << "x=" << Debug(x.val().shape()) << endl;
cerr << "y=" << Debug(y.val().shape()) << endl;
graph.forward(batchSize);
cerr << "w=" << Debug(w.val().shape()) << endl;
cerr << "b=" << Debug(b.val().shape()) << endl;
std::cerr << "z: " << Debug(z.val().shape()) << endl;
std::cerr << "lr: " << Debug(lr.val().shape()) << endl;
std::cerr << "Log-likelihood: " << graph.val().Debug() << endl;
//std::cerr << "scores=" << scores.val().Debug() << endl;
//std::cerr << "lr=" << lr.val().Debug() << endl;
graph.backward();
std::cerr << w.grad().Debug() << std::endl;
//std::cerr << graph["pred"].val()[0] << std::endl;
startInd += batchSize;
startIndData += batchSize * IMAGE_SIZE;
}
#endif
return 0;
}
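A note on the loop bookkeeping above: `startIndData` always equals `startInd * IMAGE_SIZE`, since each image occupies `IMAGE_SIZE` consecutive floats in the flattened vector. Equivalently, with a single cursor and `std::min` for the final partial batch (a sketch, same constants as above):
#include <algorithm>

for (size_t startInd = 0; startInd < size_t(numofdata); ) {
  size_t batchSize = std::min(BATCH_SIZE, size_t(numofdata) - startInd);
  size_t startIndData = startInd * IMAGE_SIZE;  // floats consumed so far
  // load images[startIndData .. startIndData + batchSize * IMAGE_SIZE)
  // and labels for examples [startInd .. startInd + batchSize)
  startInd += batchSize;
}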

View File

@@ -9,7 +9,7 @@ using namespace keywords;
int main(int argc, char** argv) {
cudaSetDevice(0);
cudaSetDevice(1);
const size_t IMAGE_SIZE = 784;
const size_t LABEL_SIZE = 10;
@@ -20,7 +20,6 @@ int main(int argc, char** argv) {
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
std::cerr << "Done." << std::endl;
std::cerr << "Loading model params...";
NpzConverter converter("../scripts/test_model_single/model.npz");
@@ -36,11 +35,11 @@ int main(int argc, char** argv) {
auto y = input(shape={whatevs, LABEL_SIZE});
auto w = param(shape={IMAGE_SIZE, LABEL_SIZE},
init=[wData](Tensor t) { t.set(wData); });
init=from_vector(wData));
auto b = param(shape={1, LABEL_SIZE},
init=[bData](Tensor t) { t.set(bData); });
init=from_vector(bData));
auto probs = softmax(dot(x, w) + b, axis=1);
auto probs = softmax_fast(dot(x, w) + b, axis=1);
auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
std::cerr << "Done." << std::endl;

View File

@@ -21,22 +21,42 @@ int main(int argc, char** argv) {
std::cerr << "\tDone." << std::endl;
std::cerr << "Loading model params...";
<<<<<<< HEAD
NpzConverter converter("../scripts/test_model_single/model.npz");
=======
NpzConverter converter("../scripts/test_model_multi/model.npz");
>>>>>>> eba5b462257a9949fd124378c53e9bf7b357b1d3
std::vector<float> wData;
Shape wShape;
converter.Load("weights", wData, wShape);
std::vector<float> wData1;
Shape wShape1;
converter.Load("weights1", wData1, wShape1);
std::vector<float> bData1;
Shape bShape1;
converter.Load("bias1", bData1, bShape1);
std::vector<float> wData2;
Shape wShape2;
converter.Load("weights2", wData2, wShape2);
std::vector<float> bData2;
Shape bShape2;
converter.Load("bias2", bData2, bShape2);
std::vector<float> bData;
Shape bShape;
converter.Load("bias", bData, bShape);
auto initW = [wData](Tensor t) {
t.set(wData);
auto initW1 = [wData1](Tensor t) {
t.set(wData1);
};
auto initB = [bData](Tensor t) {
t.set(bData);
auto initB1 = [bData1](Tensor t) {
t.set(bData1);
};
auto initW2 = [wData2](Tensor t) {
t.set(wData2);
};
auto initB2 = [bData2](Tensor t) {
t.set(bData2);
};
std::cerr << "\tDone." << std::endl;
@@ -45,11 +65,15 @@ int main(int argc, char** argv) {
auto x = input(shape={whatevs, IMAGE_SIZE}, name="X");
auto y = input(shape={whatevs, LABEL_SIZE}, name="Y");
auto w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=initW);
auto b = param(shape={1, LABEL_SIZE}, name="b0", init=initB);
auto w1 = param(shape={IMAGE_SIZE, 100}, name="W0", init=initW1);
auto b1 = param(shape={1, 100}, name="b0", init=initB1);
auto w2 = param(shape={100, LABEL_SIZE}, name="W1", init=initW2);
auto b2 = param(shape={1, LABEL_SIZE}, name="b1", init=initB2);
std::cerr << "Building model...";
auto predict = softmax(dot(x, w) + b, axis=1, name="pred");
auto layer1 = tanh(dot(x, w1) + b1);
auto layer2 = softmax(dot(layer1, w2) + b2, axis=1, name="layer2");
auto predict = layer2;
std::cerr << "Done." << std::endl;
@@ -77,6 +101,7 @@ int main(int argc, char** argv) {
if (testLabels[startId * LABEL_SIZE + i + j]) correct = j;
if (results[i + j] > results[i + predicted]) predicted = j;
}
/*std::cerr << "CORRECT: " << correct << " PREDICTED: " << predicted << std::endl;*/
acc += (correct == predicted);
}
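The accuracy loop above takes, for each row, the first index with the maximal score and compares it against the one-hot position in `testLabels`. The same per-row argmax written with `std::max_element`, which likewise returns the first maximum on ties (`argmax_row` is a hypothetical helper, assuming the same flat row-major layout):
#include <algorithm>
#include <cstddef>

// Predicted class for one row of labelSize scores: index of the maximum.
size_t argmax_row(const float* row, size_t labelSize) {
  return std::distance(row, std::max_element(row, row + labelSize));
}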