mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-04 14:04:24 +03:00
renamed num of data to BATCH_SIZE
This commit is contained in:
parent
3c730f053d
commit
73e1d5f96a
@ -12,13 +12,13 @@ int main(int argc, char** argv) {
|
||||
|
||||
const size_t IMAGE_SIZE = 784;
|
||||
const size_t LABEL_SIZE = 10;
|
||||
int BATCH_SIZE = 10000;
|
||||
|
||||
std::cerr << "Loading test set...";
|
||||
std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
||||
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
|
||||
std::vector<float> testImages = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", BATCH_SIZE, IMAGE_SIZE);
|
||||
std::vector<float> testLabels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", BATCH_SIZE, LABEL_SIZE);
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
const size_t BATCH_SIZE = testLabels.size();
|
||||
|
||||
std::cerr << "Loading model params...";
|
||||
NpzConverter converter("../scripts/test_model/model.npz");
|
||||
@ -43,19 +43,18 @@ int main(int argc, char** argv) {
|
||||
axis=1, name="pred");
|
||||
auto graph = -mean(sum(y * log(predict), axis=1),
|
||||
axis=0, name="cost");
|
||||
|
||||
std::cerr << "Done." << std::endl;
|
||||
|
||||
Tensor xt({numofdata, IMAGE_SIZE});
|
||||
Tensor yt({numofdata, LABEL_SIZE});
|
||||
Tensor xt({BATCH_SIZE, IMAGE_SIZE});
|
||||
Tensor yt({BATCH_SIZE, LABEL_SIZE});
|
||||
|
||||
x = xt << testImages;
|
||||
y = yt << testLabels;
|
||||
|
||||
graph.forward(BATCH_SIZE);
|
||||
auto results = predict.val();
|
||||
graph.backward();
|
||||
|
||||
auto results = predict.val();
|
||||
std::vector<float> resultsv(results.size());
|
||||
resultsv << results;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user