mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
debug
This commit is contained in:
parent
6974ceb9d1
commit
a573eecf5c
20
src/test.cu
20
src/test.cu
@ -24,11 +24,15 @@ int main(int argc, char** argv) {
|
||||
Expr z = dot(x, w) + b;
|
||||
Expr lr = softmax(z, axis=1, name="pred");
|
||||
Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
||||
//cerr << "lr=" << Debug(lr.val().shape()) << endl;
|
||||
//cerr << "x=" << Debug(lr.val().shape()) << endl;
|
||||
|
||||
int numofdata;
|
||||
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
||||
vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
|
||||
//vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
||||
//vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte", numofdata, LABEL_SIZE);
|
||||
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/train-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
||||
vector<float> labels = datasets::mnist::ReadLabels("../examples/mnist/train-labels-idx1-ubyte", numofdata, LABEL_SIZE);
|
||||
|
||||
|
||||
cerr << "images=" << images.size() << " labels=" << labels.size() << endl;
|
||||
cerr << "numofdata=" << numofdata << endl;
|
||||
|
||||
@ -38,14 +42,20 @@ int main(int argc, char** argv) {
|
||||
tx.Load(images);
|
||||
ty.Load(labels);
|
||||
|
||||
cerr << "tx=" << Debug(tx.shape()) << endl;
|
||||
cerr << "ty=" << Debug(ty.shape()) << endl;
|
||||
//cerr << "tx=" << Debug(tx.shape()) << endl;
|
||||
//cerr << "ty=" << Debug(ty.shape()) << endl;
|
||||
|
||||
x = tx;
|
||||
y = ty;
|
||||
|
||||
cerr << "x=" << Debug(x.val().shape()) << endl;
|
||||
cerr << "y=" << Debug(y.val().shape()) << endl;
|
||||
|
||||
|
||||
graph.forward(500);
|
||||
|
||||
cerr << "w=" << Debug(w.val().shape()) << endl;
|
||||
cerr << "b=" << Debug(b.val().shape()) << endl;
|
||||
std::cerr << "z: " << Debug(z.val().shape()) << endl;
|
||||
std::cerr << "lr: " << Debug(lr.val().shape()) << endl;
|
||||
std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ;
|
||||
|
Loading…
Reference in New Issue
Block a user