mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
Tidying up
This commit is contained in:
parent
aab15d66e6
commit
3c730f053d
@ -12,40 +12,31 @@ int main(int argc, char** argv) {
|
||||
|
||||
const size_t IMAGE_SIZE = 784;
|
||||
const size_t LABEL_SIZE = 10;
|
||||
int numofdata;
|
||||
|
||||
std::cerr << "Loading test set...";
|
||||
std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
||||
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
|
||||
std::cerr << "\tDone." << std::endl;
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
const size_t BATCH_SIZE = testLabels.size();
|
||||
|
||||
std::cerr << "Loading model params...";
|
||||
NpzConverter converter("../scripts/test_model/model.npz");
|
||||
|
||||
std::vector<float> wData;
|
||||
Shape wShape;
|
||||
std::vector<float> wData, bData;
|
||||
Shape wShape, bShape;
|
||||
converter.Load("weights", wData, wShape);
|
||||
|
||||
std::vector<float> bData;
|
||||
Shape bShape;
|
||||
converter.Load("bias", bData, bShape);
|
||||
|
||||
auto initW = [wData](Tensor t) {
|
||||
t.set(wData);
|
||||
};
|
||||
|
||||
auto initB = [bData](Tensor t) {
|
||||
t.set(bData);
|
||||
};
|
||||
|
||||
std::cerr << "\tDone." << std::endl;
|
||||
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
auto x = input(shape={whatevs, IMAGE_SIZE}, name="X");
|
||||
auto y = input(shape={whatevs, LABEL_SIZE}, name="Y");
|
||||
|
||||
auto w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0", init=initW);
|
||||
auto b = param(shape={1, LABEL_SIZE}, name="b0", init=initB);
|
||||
auto w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0",
|
||||
init=[wData](Tensor t) { t.set(wData); });
|
||||
auto b = param(shape={1, LABEL_SIZE}, name="b0",
|
||||
init=[bData](Tensor t) {t.set(bData); });
|
||||
|
||||
std::cerr << "Building model...";
|
||||
auto predict = softmax(dot(x, w) + b,
|
||||
@ -61,7 +52,7 @@ int main(int argc, char** argv) {
|
||||
x = xt << testImages;
|
||||
y = yt << testLabels;
|
||||
|
||||
graph.forward(numofdata);
|
||||
graph.forward(BATCH_SIZE);
|
||||
auto results = predict.val();
|
||||
graph.backward();
|
||||
|
||||
@ -85,7 +76,7 @@ int main(int argc, char** argv) {
|
||||
//}
|
||||
//std::cerr << ")" << std::endl;
|
||||
}
|
||||
std::cerr << "ACC: " << float(acc)/numofdata << std::endl;
|
||||
std::cerr << "Accuracy: " << float(acc)/BATCH_SIZE << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user