Tidying up

This commit is contained in:
Marcin Junczys-Dowmunt 2016-09-14 23:36:05 +02:00
parent aab15d66e6
commit 3c730f053d

View File

@ -12,40 +12,31 @@ int main(int argc, char** argv) {
const size_t IMAGE_SIZE = 784;
const size_t LABEL_SIZE = 10;
int numofdata;
std::cerr << "Loading test set...";
std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
std::cerr << "\tDone." << std::endl;
std::cerr << "Done." << std::endl;
const size_t BATCH_SIZE = testLabels.size();
std::cerr << "Loading model params...";
NpzConverter converter("../scripts/test_model/model.npz");
std::vector<float> wData;
Shape wShape;
std::vector<float> wData, bData;
Shape wShape, bShape;
converter.Load("weights", wData, wShape);
std::vector<float> bData;
Shape bShape;
converter.Load("bias", bData, bShape);
auto initW = [wData](Tensor t) {
t.set(wData);
};
auto initB = [bData](Tensor t) {
t.set(bData);
};
std::cerr << "\tDone." << std::endl;
std::cerr << "Done." << std::endl;
auto x = input(shape={whatevs, IMAGE_SIZE}, name="X");
auto y = input(shape={whatevs, LABEL_SIZE}, name="Y");
auto w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=initW);
auto b = param(shape={1, LABEL_SIZE}, name="b0", init=initB);
auto w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0",
init=[wData](Tensor t) { t.set(wData); });
auto b = param(shape={1, LABEL_SIZE}, name="b0",
init=[bData](Tensor t) {t.set(bData); });
std::cerr << "Building model...";
auto predict = softmax(dot(x, w) + b,
@ -61,7 +52,7 @@ int main(int argc, char** argv) {
x = xt << testImages;
y = yt << testLabels;
graph.forward(numofdata);
graph.forward(BATCH_SIZE);
auto results = predict.val();
graph.backward();
@ -85,7 +76,7 @@ int main(int argc, char** argv) {
//}
//std::cerr << ")" << std::endl;
}
std::cerr << "ACC: " << float(acc)/numofdata << std::endl;
std::cerr << "Accuracy: " << float(acc)/BATCH_SIZE << std::endl;
return 0;
}