Mirror of https://github.com/marian-nmt/marian.git, synced 2024-11-04 14:04:24 +03:00

Merge ../Marian

This commit is contained in:
commit 1f308d2004
@@ -3,6 +3,7 @@
 import sys
 import os
 import numpy as np
+import time
 from keras.datasets import mnist
 from keras.utils import np_utils
 from keras.models import Sequential
@@ -15,7 +16,11 @@ def softmax(x):
 
 def baseline_model(pixels_count, classes_count):
     model = Sequential()
-    model.add(Dense(100, input_dim=pixels_count, init='normal', activation='tanh'))
+    model.add(Dense(2000, input_dim=pixels_count, init='normal', activation='tanh'))
+    model.add(Dense(2000, init='normal', activation='tanh'))
+    model.add(Dense(2000, init='normal', activation='tanh'))
+    model.add(Dense(2000, init='normal', activation='tanh'))
+    model.add(Dense(2000, init='normal', activation='tanh'))
     model.add(Dense(classes_count, input_dim=100, init='normal', activation='softmax'))
     model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
     return model
@@ -52,21 +57,24 @@ if __name__ == "__main__":
     # Build the model
     model = baseline_model(pixels_count, classes_count)
     # Fit the model
-    model.fit(X_train, y_train, validation_data=(X_test, y_test), nb_epoch=10, batch_size=200, verbose=2)
+    start = time.time();
+    model.fit(X_train, y_train, nb_epoch=10, batch_size=200, verbose=2)
+    print "Time elapsed", time.time() - start, "s"
     # Final evaluation of the model
     scores = model.evaluate(X_test, y_test, verbose=0)
-    print("Baseline Error: %.2f%%" % (100 - scores[1] * 100))
+    print("Accuracy: %.2f%%" % (scores[1] * 100))
 
     ### Weight and bias matrices - we extract them from the model
 
     # weights_ones = np.ones((pixels_count, classes_count))
     # print weights_ones.shape
 
-    weights1, bias1, weights2, bias2 = model.get_weights()
+    #weights1, bias1, weights2, bias2 = model.get_weights()
     ### Save model to npz files
-    if not os.path.exists("test_model_multi"):
-        os.makedirs("test_model_multi")
+    #if not os.path.exists("test_model_multi"):
+    #    os.makedirs("test_model_multi")
+    # np.savez("test_model_multi/model", *model)
-    np.savez("test_model_multi/model", weights1 = weights1, bias1 = bias1, weights2 = weights2, bias2 = bias2)
+    #np.savez("test_model_multi/model", weights1 = weights1, bias1 = bias1, weights2 = weights2, bias2 = bias2)
 
-    print "Model saved! Check test_model_multi directory"
+    #print "Model saved! Check test_model_multi directory"
@@ -21,30 +21,25 @@ cuda_add_executable(
 target_link_libraries(marian marian_lib)
 
 cuda_add_executable(
-  train_mnist
-  train_mnist.cu
+  mnist_benchmark
+  mnist_benchmark.cu
 )
 
-target_link_libraries(train_mnist marian_lib)
-
-cuda_add_executable(
-  validate_mnist
-  validate_mnist.cu
-)
 cuda_add_executable(
   validate_mnist_batch
   validate_mnist_batch.cu
 )
 
 cuda_add_executable(
   validate_encoder_decoder
   validate_encoder_decoder.cu
 )
 
-target_link_libraries(validate_mnist marian_lib)
+target_link_libraries(mnist_benchmark marian_lib)
 target_link_libraries(validate_mnist_batch marian_lib)
 target_link_libraries(validate_encoder_decoder marian_lib)
 
-foreach(exec marian train_mnist validate_mnist validate_mnist_batch validate_encoder_decoder)
+foreach(exec marian mnist_benchmark validate_mnist_batch validate_encoder_decoder)
   target_link_libraries(${exec} ${EXT_LIBS} cuda cudnn)
   cuda_add_cublas_to_target(${exec})
   set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
src/mnist_benchmark.cu (new file, 164 lines)
@@ -0,0 +1,164 @@
#include <algorithm>
#include <chrono>
#include <boost/timer/timer.hpp>

#include "marian.h"
#include "mnist.h"
#include "npz_converter.h"
#include "optimizers.h"

using namespace marian;
using namespace keywords;

const size_t IMAGE_SIZE = 784;
const size_t LABEL_SIZE = 10;
int BATCH_SIZE = 200;

ExpressionGraph build_graph(const std::vector<int>& dims) {
  std::cerr << "Building model... ";
  boost::timer::cpu_timer timer;

  ExpressionGraph g;
  auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
  auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");

  std::vector<Expr> layers, weights, biases;
  for(int i = 0; i < dims.size() - 1; ++i) {
    int in = dims[i];
    int out = dims[i + 1];

    if(i == 0) {
      layers.emplace_back(x);
    }
    else {
      layers.emplace_back(tanh(dot(layers.back(), weights.back())) + biases.back());
    }

    weights.emplace_back(
      g.param(shape={in, out},
              init=normal()));
    biases.emplace_back(
      g.param(shape={1, out},
              init=normal()));
  }

  auto probs = named(
    softmax(dot(layers.back(), weights.back()) + biases.back()),
    "probs"
  );

  auto cost = -mean(sum(y * log(probs), axis=1), axis=0);
  auto costreg = named(
    cost, "cost"
  );

  std::cerr << timer.format(5, "%ws") << std::endl;
  return g;
}

void shuffle(std::vector<float>& x, std::vector<float>& y, size_t dimx, size_t dimy) {
  std::srand(std::time(0));
  std::vector<size_t> ind;
  for(size_t i = 0; i < x.size() / dimx; ++i) {
    ind.push_back(i);
  }

  std::random_shuffle(ind.begin(), ind.end());

  std::vector<float> xShuffled(x.size());
  std::vector<float> yShuffled(y.size());

  int j = 0;
  for(auto i : ind) {
    std::copy(x.begin() + j * dimx, x.begin() + j * dimx + dimx, xShuffled.begin() + i * dimx);
    std::copy(y.begin() + j * dimy, y.begin() + j * dimy + dimy, yShuffled.begin() + i * dimy);
    j++;
  }

  x = xShuffled;
  y = yShuffled;
}

int main(int argc, char** argv) {
  int trainRows, testRows;

  std::cerr << "Loading train set...";
  std::vector<float> trainImages = datasets::mnist::ReadImages("../examples/mnist/train-images-idx3-ubyte", trainRows, IMAGE_SIZE);
  std::vector<float> trainLabels = datasets::mnist::ReadLabels("../examples/mnist/train-labels-idx1-ubyte", trainRows, LABEL_SIZE);
  std::cerr << "Done." << std::endl;

  std::cerr << "Loading test set...";
  std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", testRows, IMAGE_SIZE);
  std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", testRows, LABEL_SIZE);
  std::cerr << "Done." << std::endl;

  ExpressionGraph g = build_graph({IMAGE_SIZE, 2000, 2000, 2000, 2000, 2000, LABEL_SIZE});

  Tensor xt({BATCH_SIZE, IMAGE_SIZE});
  Tensor yt({BATCH_SIZE, LABEL_SIZE});

  boost::timer::cpu_timer total;
  Adam opt;
  for(int i = 1; i <= 10; ++i) {
    boost::timer::cpu_timer timer;
    shuffle(trainImages, trainLabels, IMAGE_SIZE, LABEL_SIZE);
    float cost = 0;
    for(int j = 0; j < trainRows / BATCH_SIZE; j++) {
      size_t xBatch = IMAGE_SIZE * BATCH_SIZE;
      auto xbegin = trainImages.begin() + j * xBatch;
      auto xend = xbegin + xBatch;
      xt.set(xbegin, xend);

      size_t yBatch = LABEL_SIZE * BATCH_SIZE;
      auto ybegin = trainLabels.begin() + j * yBatch;
      auto yend = ybegin + yBatch;
      yt.set(ybegin, yend);

      g["x"] = xt;
      g["y"] = yt;

      opt(g, BATCH_SIZE);
      cost += g["cost"].val()[0];
    }
    std::cerr << "Epoch: " << i << " - Cost: " << cost / trainRows * BATCH_SIZE << " - " << timer.format(3, "%ws") << std::endl;
  }
  std::cerr << "Total: " << total.format(3, "%ws") << std::endl;

  std::vector<float> results;
  for(int j = 0; j < testRows / BATCH_SIZE; j++) {
    size_t xBatch = IMAGE_SIZE * BATCH_SIZE;
    auto xbegin = testImages.begin() + j * xBatch;
    auto xend = xbegin + xBatch;
    xt.set(xbegin, xend);
    yt.set(0);

    g["x"] = xt;
    g["y"] = yt;

    g.forward(BATCH_SIZE);

    std::vector<float> bResults;
    bResults << g["probs"].val();
    results.insert(results.end(), bResults.begin(), bResults.end());
  }

  size_t acc = 0;
  for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
    size_t correct = 0;
    size_t proposed = 0;
    for (size_t j = 0; j < LABEL_SIZE; ++j) {
      if (testLabels[i + j])
        correct = j;
      if (results[i + j] > results[i + proposed])
        proposed = j;
    }
    acc += (correct == proposed);
  }
  std::cerr << "Accuracy: " << float(acc) / testRows << std::endl;

  return 0;
}
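The evaluation loop above recovers the true class as the index of the 1 in each one-hot label row and the prediction as a hand-rolled argmax over the matching row of probs. A minimal standalone sketch of the same computation, using std::max_element instead; the two toy rows are made up for illustration and are not benchmark output:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const size_t LABEL_SIZE = 10;
  // Two toy examples: true classes 3 and 7, stored one-hot.
  std::vector<float> labels(2 * LABEL_SIZE, 0.0f);
  labels[3] = 1.0f;
  labels[LABEL_SIZE + 7] = 1.0f;
  // Fake softmax outputs: first row predicts 3 (right), second predicts 2 (wrong).
  std::vector<float> results(2 * LABEL_SIZE, 0.05f);
  results[3] = 0.9f;
  results[LABEL_SIZE + 2] = 0.9f;

  size_t acc = 0;
  for (size_t i = 0; i < labels.size(); i += LABEL_SIZE) {
    // argmax over one row of each matrix
    size_t correct  = std::max_element(labels.begin() + i,  labels.begin() + i + LABEL_SIZE)  - (labels.begin() + i);
    size_t proposed = std::max_element(results.begin() + i, results.begin() + i + LABEL_SIZE) - (results.begin() + i);
    acc += (correct == proposed);
  }
  std::cout << "Accuracy: " << float(acc) / 2 << std::endl;  // prints 0.5
  return 0;
}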
@@ -114,7 +114,7 @@ struct LogitNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -138,7 +138,7 @@ struct TanhNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"tanh\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -180,7 +180,7 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"softmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -209,7 +209,7 @@ struct ArgmaxNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -232,7 +232,7 @@ struct LogNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"log\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -255,7 +255,7 @@ struct ExpNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"exp\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -277,7 +277,7 @@ struct NegNodeOp : public UnaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -330,8 +330,8 @@ struct DotNodeOp : public BinaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"×\", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
-    ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
+    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -357,8 +357,8 @@ struct PlusNodeOp : public BinaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"+\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
-    ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
+    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -384,8 +384,8 @@ struct MinusNodeOp : public BinaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
-    ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
+    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
 
@@ -411,8 +411,8 @@ struct MultNodeOp : public BinaryNodeOp {
   virtual std::string graphviz() {
     std::stringstream ss;
     ss << "\"" << this << "\" [shape=\"box\", label=\"•\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
-    ss << "\"" << &*a_ << "\" -> \"" << this << "\"" << std::endl;
-    ss << "\"" << &*b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
+    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
+    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
     return ss.str();
   };
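Each graphviz() method emits one DOT node statement, naming the node by its own address, plus one edge statement per operand. Wrapped in a digraph (section further below shows a commented-out g.graphviz() call), the emitted text looks roughly like the following; the addresses are made up for illustration:

digraph ExpressionGraph {
  "0x1a2b40" [shape="box", label="×", style="filled", fillcolor="orange"]
  "0x1a2a10" -> "0x1a2b40"
  "0x1a2a90" -> "0x1a2b40"

  "0x1a2c00" [shape="box", label="tanh", style="filled", fillcolor="yellow"]
  "0x1a2b40" -> "0x1a2c00"
}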
@@ -6,9 +6,15 @@
 
 namespace marian {
 
+// @TODO: Modify the computation graph to group all parameters in a single matrix object.
+// This will allow performing a single large SGD update per batch. Currently there
+// are as many updates as there are parameters.
+
+// @TODO: Implement Element(...) with multiple functors to compact kernel calls.
+
 class Sgd {
   public:
-    Sgd(float eta=0.001) : eta_(eta) {}
+    Sgd(float eta=0.01) : eta_(eta) {}
 
     void operator()(ExpressionGraph& graph, int batchSize) {
       graph.backprop(batchSize);
@@ -25,7 +31,7 @@ class Sgd {
 // @TODO: Add serialization for historic gradients and parameters
 class Adagrad {
   public:
-    Adagrad(float eta=0.001, float eps=10e-8)
+    Adagrad(float eta=0.01, float eps=1e-8)
     : eta_(eta), eps_(eps) {}
 
     void operator()(ExpressionGraph& graph, int batchSize) {
@@ -55,7 +61,7 @@ class Adagrad {
 // https://arxiv.org/pdf/1412.6980v8.pdf
 class Adam {
   public:
-    Adam(float eta=0.001, float beta1=0.999, float beta2=0.999, float eps=10e-8)
+    Adam(float eta=0.001, float beta1=0.9, float beta2=0.999, float eps=1e-8)
     : eta_(eta), beta1_(beta1), beta2_(beta2), eps_(eps), t_(0) {}
 
     void operator()(ExpressionGraph& graph, int batchSize) {
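For reference, the update rule the Adam class implements, from the paper linked above; the new defaults eta=0.001, beta1=0.9, beta2=0.999, eps=1e-8 match the paper's recommendations (the old beta1=0.999 did not). With g_t the gradient at step t:

\[
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t) \\
\theta_t &= \theta_{t-1} - \eta\, \hat{m}_t / \bigl(\sqrt{\hat{v}_t} + \epsilon\bigr)
\end{aligned}
\]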
@@ -33,13 +33,13 @@ void distribution(Tensor t, float a, float b) {
   t << vals;
 }
 
-std::function<void(Tensor)> normal(float mean = 0.0, float std = 0.1) {
+std::function<void(Tensor)> normal(float mean = 0.0, float std = 0.05) {
   return [mean, std](Tensor t) {
     distribution<std::normal_distribution<float>>(t, mean, std);
   };
 }
 
-std::function<void(Tensor)> uniform(float a = 0.0, float b = 0.1) {
+std::function<void(Tensor)> uniform(float a = 0.0, float b = 0.05) {
   return [a, b](Tensor t) {
     distribution<std::uniform_real_distribution<float>>(t, a, b);
   };
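normal() and uniform() are initializer factories: they capture their parameters and return a std::function that fills a tensor when invoked, which is why graph code can write init=normal() and defer the actual sampling. A minimal self-contained sketch of the same pattern, with std::vector<float> standing in for Tensor (an assumption for illustration; the real Tensor API differs):

#include <functional>
#include <iostream>
#include <random>
#include <vector>

// Factory: captures (mean, stddev), returns a filler to call later.
std::function<void(std::vector<float>&)> normal(float mean = 0.0f, float stddev = 0.05f) {
  return [mean, stddev](std::vector<float>& t) {
    std::mt19937 gen(std::random_device{}());
    std::normal_distribution<float> dist(mean, stddev);
    for (auto& v : t) v = dist(gen);
  };
}

int main() {
  std::vector<float> param(4);
  normal()(param);  // fill with N(0, 0.05) samples, mirroring the new default
  for (float v : param) std::cout << v << " ";
  std::cout << std::endl;
  return 0;
}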
@@ -4,6 +4,14 @@ using namespace std;
 
 namespace marian {
 
+// @TODO: handle this better, maybe per thread?
+static cublasHandle_t create_handle() {
+  cublasHandle_t cublasHandle;
+  cublasCreate(&cublasHandle);
+  return cublasHandle;
+}
+cublasHandle_t cublasHandle = create_handle();
+
 __global__ void gSubtractMean(float* out, float* weights,
                               size_t rows, size_t cols) {
   for(int bid = 0; bid < rows; bid += gridDim.x) {
@@ -212,10 +220,7 @@ Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
 Tensor Prod(Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta) {
-
-  cublasHandle_t cublasHandle;
-  cublasCreate(&cublasHandle);
   Tensor temp = Prod(cublasHandle, C, A, B, transA, transB, beta);
-  cublasDestroy(cublasHandle);
   return temp;
 }
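The change above trades a cublasCreate/cublasDestroy pair on every Prod call for a single global handle created at static-initialization time. One way to answer the "maybe per thread?" @TODO is a lazily created thread_local handle; a hypothetical sketch, not part of this commit:

#include <cublas_v2.h>

// One cuBLAS handle per calling thread, created on first use.
cublasHandle_t threadHandle() {
  thread_local cublasHandle_t handle = [] {
    cublasHandle_t h;
    cublasCreate(&h);  // never destroyed in this sketch; a real version
                       // would pair it with cublasDestroy at thread exit
    return h;
  }();
  return handle;
}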
@@ -1,34 +0,0 @@
-
-#include "marian.h"
-#include "mnist.h"
-#include "optimizers.h"
-
-int main(int argc, char** argv) {
-  const size_t IMAGE_SIZE = 784;
-  const size_t LABEL_SIZE = 10;
-  int numofdata;
-
-  std::vector<float> trainImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
-  std::vector<float> trainLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
-
-  using namespace marian;
-  using namespace keywords;
-
-  ExpressionGraph g;
-
-  Expr x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
-  Expr y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
-
-  Expr w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE}), "w");
-  Expr b = named(g.param(shape={1, LABEL_SIZE}), "b");
-
-  auto scores = dot(x, w) + b;
-  auto lr = softmax(scores);
-  auto cost = named(-mean(sum(y * log(lr), axis=1), axis=0), "cost");
-  std::cerr << "lr=" << lr.Debug() << std::endl;
-
-  Adagrad opt;
-  opt(g, 300);
-
-  return 0;
-}
@@ -1,90 +0,0 @@
-
-#include "marian.h"
-#include "mnist.h"
-#include "npz_converter.h"
-#include "optimizers.h"
-
-using namespace marian;
-using namespace keywords;
-
-const size_t IMAGE_SIZE = 784;
-const size_t LABEL_SIZE = 10;
-int BATCH_SIZE = 10000;
-
-ExpressionGraph build_graph() {
-  std::cerr << "Loading model params...";
-  NpzConverter converter("../scripts/test_model_single/model.npz");
-
-  std::vector<float> wData, bData;
-  Shape wShape, bShape;
-  converter.Load("weights", wData, wShape);
-  converter.Load("bias", bData, bShape);
-  std::cerr << "Done." << std::endl;
-
-  std::cerr << "Building model...";
-
-  ExpressionGraph g;
-  auto x = named(g.input(shape={whatevs, IMAGE_SIZE}), "x");
-  auto y = named(g.input(shape={whatevs, LABEL_SIZE}), "y");
-
-  auto w = named(g.param(shape={IMAGE_SIZE, LABEL_SIZE},
-                         init=from_vector(wData)), "w");
-  auto b = named(g.param(shape={1, LABEL_SIZE},
-                         init=from_vector(bData)), "b");
-
-  auto probs = named(
-    softmax(dot(x, w) + b),
-    "probs"
-  );
-
-  auto cost = named(
-    -mean(sum(y * log(probs), axis=1), axis=0),
-    "cost"
-  );
-
-  std::cerr << "Done." << std::endl;
-  return g;
-}
-
-int main(int argc, char** argv) {
-  std::cerr << "Loading test set...";
-  std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE);
-  std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
-  std::cerr << "Done." << std::endl;
-
-  ExpressionGraph g = build_graph();
-
-  Tensor xt({BATCH_SIZE, IMAGE_SIZE});
-  Tensor yt({BATCH_SIZE, LABEL_SIZE});
-
-  g["x"] = (xt << testImages);
-  g["y"] = (yt << testLabels);
-
-  Adam opt;
-  for(size_t j = 0; j < 20; ++j) {
-    for(size_t i = 0; i < 60; ++i) {
-      opt(g, BATCH_SIZE);
-    }
-    std::cerr << g["cost"].val()[0] << std::endl;
-  }
-
-  //std::cout << g.graphviz() << std::endl;
-
-  std::vector<float> results;
-  results << g["probs"].val();
-
-  size_t acc = 0;
-  for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
-    size_t correct = 0;
-    size_t proposed = 0;
-    for (size_t j = 0; j < LABEL_SIZE; ++j) {
-      if (testLabels[i+j]) correct = j;
-      if (results[i + j] > results[i + proposed]) proposed = j;
-    }
-    acc += (correct == proposed);
-  }
-  std::cerr << "Cost: " << g["cost"].val()[0] << " - Accuracy: " << float(acc) / BATCH_SIZE << std::endl;
-  return 0;
-}