Speed up validation; add batching.

Tomasz Dwojak 2016-09-14 22:28:45 +01:00
parent 74626d347f
commit 4b730f7645


@@ -9,11 +9,12 @@ using namespace keywords;
int main(int argc, char** argv) {
const size_t IMAGE_SIZE = 784;
const size_t LABEL_SIZE = 10;
const size_t BATCH_SIZE = 24;
int numofdata;
std::cerr << "Loading test set...";
std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
std::vector<float>testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
std::cerr << "\tDone." << std::endl;
std::cerr << "Loading model params...";
@@ -48,29 +49,61 @@ int main(int argc, char** argv) {
auto predict = softmax(scores, axis=1, name="pred");
std::cerr << "\tDone." << std::endl;
Tensor xt({numofdata, IMAGE_SIZE});
xt.Load(testImages);
predict.forward(numofdata);
auto results = predict.val();
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
/* xt.Load(testImages); */
/* x = xt; */
size_t acc = 0;
size_t startId = 0;
size_t endId = startId + BATCH_SIZE;
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
size_t correct = 0;
size_t predicted = 0;
for (size_t j = 0; j < LABEL_SIZE; ++j) {
if (testLabels[i+j]) correct = j;
if (results[i + j] > results[i + predicted]) predicted = j;
while (endId < numofdata) {
std::vector<float> tmp(testImages.begin() + (startId * IMAGE_SIZE),
testImages.begin() + (endId * IMAGE_SIZE));
xt.Load(tmp);
x = xt;
predict.forward(BATCH_SIZE);
thrust::host_vector<float> results(predict.val().begin(), predict.val().begin() + LABEL_SIZE * BATCH_SIZE);
for (size_t i = 0; i < BATCH_SIZE * LABEL_SIZE; i += LABEL_SIZE) {
size_t correct = 0;
size_t predicted = 0;
for (size_t j = 0; j < LABEL_SIZE; ++j) {
if (testLabels[startId * LABEL_SIZE + i + j]) correct = j;
if (results[i + j] > results[i + predicted]) predicted = j;
}
acc += (correct == predicted);
}
acc += (correct == predicted);
std::cerr << "corect: " << correct << " | " << predicted << "(";
for (size_t j = 0; j < LABEL_SIZE; ++j) {
std::cerr << results[i+j] << " ";
}
std::cerr << ")" << std::endl;
startId += BATCH_SIZE;
endId += BATCH_SIZE;
}
if (endId != numofdata) {
endId = numofdata;
if (endId > startId) {
std::vector<float> tmp(testImages.begin() + (startId * IMAGE_SIZE),
testImages.begin() + (endId * IMAGE_SIZE));
xt.Load(tmp);
x = xt;
predict.forward(endId - startId);
thrust::host_vector<float> results(predict.val().begin(), predict.val().begin() + LABEL_SIZE * (endId - startId));
for (size_t i = 0; i < (endId - startId) * LABEL_SIZE; i += LABEL_SIZE) {
size_t correct = 0;
size_t predicted = 0;
for (size_t j = 0; j < LABEL_SIZE; ++j) {
if (testLabels[startId * LABEL_SIZE + i + j]) correct = j;
if (results[i + j] > results[i + predicted]) predicted = j;
}
acc += (correct == predicted);
}
}
}
std::cerr << "ACC: " << float(acc)/numofdata << std::endl;
return 0;
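
For reference, here is a minimal self-contained sketch of the batched evaluation pattern this commit introduces: slide a [startId, endId) window over the test set, run a forward pass on each slice, take the row-wise argmax of the scores, and compare it against the one-hot labels. forwardBatch() below is a hypothetical stand-in for the framework's predict.forward()/predict.val() pair, and the data sizes are dummies; only the windowing and accuracy arithmetic mirror the diff.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the model: returns LABEL_SIZE scores per example.
std::vector<float> forwardBatch(const std::vector<float>& /*images*/,
                                size_t batchSize, size_t labelSize) {
  return std::vector<float>(batchSize * labelSize, 0.0f);  // placeholder scores
}

int main() {
  const size_t IMAGE_SIZE = 784;
  const size_t LABEL_SIZE = 10;
  const size_t BATCH_SIZE = 24;
  const size_t numofdata = 100;  // dummy test-set size for the sketch
  std::vector<float> testImages(numofdata * IMAGE_SIZE, 0.0f);
  std::vector<float> testLabels(numofdata * LABEL_SIZE, 0.0f);  // one-hot label layout (dummy zeros here)

  size_t acc = 0;
  for (size_t startId = 0; startId < numofdata; startId += BATCH_SIZE) {
    // Clamp the window so the trailing partial batch is handled in the same loop.
    size_t endId = std::min(startId + BATCH_SIZE, numofdata);
    size_t curBatch = endId - startId;

    // Slice the flat image buffer for this window and score it.
    std::vector<float> batch(testImages.begin() + startId * IMAGE_SIZE,
                             testImages.begin() + endId * IMAGE_SIZE);
    std::vector<float> results = forwardBatch(batch, curBatch, LABEL_SIZE);

    // Row-wise argmax vs. the position of the one-hot label, as in the diff.
    for (size_t i = 0; i < curBatch * LABEL_SIZE; i += LABEL_SIZE) {
      size_t correct = 0;
      size_t predicted = 0;
      for (size_t j = 0; j < LABEL_SIZE; ++j) {
        if (testLabels[startId * LABEL_SIZE + i + j]) correct = j;
        if (results[i + j] > results[i + predicted]) predicted = j;
      }
      acc += (correct == predicted);
    }
  }
  std::cerr << "ACC: " << float(acc) / numofdata << std::endl;
  return 0;
}

Clamping endId with std::min folds the remainder batch into the main loop, which avoids repeating the scoring code the way the separate endId != numofdata branch above does.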