mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
Speed up validation; add batch.
This commit is contained in:
parent
74626d347f
commit
4b730f7645
@ -9,11 +9,12 @@ using namespace keywords;
|
|||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
const size_t IMAGE_SIZE = 784;
|
const size_t IMAGE_SIZE = 784;
|
||||||
const size_t LABEL_SIZE = 10;
|
const size_t LABEL_SIZE = 10;
|
||||||
|
const size_t BATCH_SIZE = 24;
|
||||||
int numofdata;
|
int numofdata;
|
||||||
|
|
||||||
std::cerr << "Loading test set...";
|
std::cerr << "Loading test set...";
|
||||||
std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
||||||
std::vector<float>testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
|
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
|
||||||
std::cerr << "\tDone." << std::endl;
|
std::cerr << "\tDone." << std::endl;
|
||||||
|
|
||||||
std::cerr << "Loading model params...";
|
std::cerr << "Loading model params...";
|
||||||
@ -48,29 +49,61 @@ int main(int argc, char** argv) {
|
|||||||
auto predict = softmax(scores, axis=1, name="pred");
|
auto predict = softmax(scores, axis=1, name="pred");
|
||||||
std::cerr << "\tDone." << std::endl;
|
std::cerr << "\tDone." << std::endl;
|
||||||
|
|
||||||
Tensor xt({numofdata, IMAGE_SIZE});
|
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
|
||||||
xt.Load(testImages);
|
/* xt.Load(testImages); */
|
||||||
|
/* x = xt; */
|
||||||
predict.forward(numofdata);
|
|
||||||
|
|
||||||
auto results = predict.val();
|
|
||||||
|
|
||||||
size_t acc = 0;
|
size_t acc = 0;
|
||||||
|
size_t startId = 0;
|
||||||
|
size_t endId = startId + BATCH_SIZE;
|
||||||
|
|
||||||
for (size_t i = 0; i < testLabels.size(); i += LABEL_SIZE) {
|
while (endId < numofdata) {
|
||||||
size_t correct = 0;
|
std::vector<float> tmp(testImages.begin() + (startId * IMAGE_SIZE),
|
||||||
size_t predicted = 0;
|
testImages.begin() + (endId * IMAGE_SIZE));
|
||||||
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
xt.Load(tmp);
|
||||||
if (testLabels[i+j]) correct = j;
|
x = xt;
|
||||||
if (results[i + j] > results[i + predicted]) predicted = j;
|
|
||||||
|
predict.forward(BATCH_SIZE);
|
||||||
|
|
||||||
|
thrust::host_vector<float> results(predict.val().begin(), predict.val().begin() + LABEL_SIZE * BATCH_SIZE);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < BATCH_SIZE * LABEL_SIZE; i += LABEL_SIZE) {
|
||||||
|
size_t correct = 0;
|
||||||
|
size_t predicted = 0;
|
||||||
|
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
||||||
|
if (testLabels[startId * LABEL_SIZE + i + j]) correct = j;
|
||||||
|
if (results[i + j] > results[i + predicted]) predicted = j;
|
||||||
|
}
|
||||||
|
acc += (correct == predicted);
|
||||||
}
|
}
|
||||||
acc += (correct == predicted);
|
|
||||||
std::cerr << "corect: " << correct << " | " << predicted << "(";
|
startId += BATCH_SIZE;
|
||||||
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
endId += BATCH_SIZE;
|
||||||
std::cerr << results[i+j] << " ";
|
|
||||||
}
|
|
||||||
std::cerr << std::endl;
|
|
||||||
}
|
}
|
||||||
|
if (endId != numofdata) {
|
||||||
|
endId = numofdata;
|
||||||
|
if (endId - startId >= 0) {
|
||||||
|
std::vector<float> tmp(testImages.begin() + (startId * IMAGE_SIZE),
|
||||||
|
testImages.begin() + (endId * IMAGE_SIZE));
|
||||||
|
xt.Load(tmp);
|
||||||
|
x = xt;
|
||||||
|
|
||||||
|
predict.forward(endId - startId);
|
||||||
|
|
||||||
|
thrust::host_vector<float> results(predict.val().begin(), predict.val().begin() + LABEL_SIZE * (endId - startId));
|
||||||
|
|
||||||
|
for (size_t i = 0; i < (endId - startId) * LABEL_SIZE; i += LABEL_SIZE) {
|
||||||
|
size_t correct = 0;
|
||||||
|
size_t predicted = 0;
|
||||||
|
for (size_t j = 0; j < LABEL_SIZE; ++j) {
|
||||||
|
if (testLabels[startId * LABEL_SIZE + i + j]) correct = j;
|
||||||
|
if (results[i + j] > results[i + predicted]) predicted = j;
|
||||||
|
}
|
||||||
|
acc += (correct == predicted);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::cerr << "ACC: " << float(acc)/numofdata << std::endl;
|
std::cerr << "ACC: " << float(acc)/numofdata << std::endl;
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
|
Loading…
Reference in New Issue
Block a user