mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
debug
This commit is contained in:
parent
81f6f51f6f
commit
9643f52aa5
@ -172,11 +172,13 @@ class TensorImpl {
|
|||||||
strm << "shape=" << marian::Debug(shape_) << std::endl;
|
strm << "shape=" << marian::Debug(shape_) << std::endl;
|
||||||
|
|
||||||
// values
|
// values
|
||||||
/*
|
|
||||||
size_t totSize = GetTotalSize(shape());
|
size_t totSize = GetTotalSize(shape());
|
||||||
std::vector<Float> values(totSize);
|
std::vector<Float> values(totSize);
|
||||||
thrust::copy(data_.begin(), data_.end(), values.begin());
|
thrust::copy(data_.begin(), data_.end(), values.begin());
|
||||||
*/
|
|
||||||
|
for (size_t i = 0; i < totSize; ++i) {
|
||||||
|
strm << values[i] << " ";
|
||||||
|
}
|
||||||
return strm.str();
|
return strm.str();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
14
src/test.cu
14
src/test.cu
@ -21,10 +21,10 @@ int main(int argc, char** argv) {
|
|||||||
Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0");
|
Expr w = param(shape={IMAGE_SIZE, LABEL_SIZE}, name="W0");
|
||||||
Expr b = param(shape={1, LABEL_SIZE}, name="b0");
|
Expr b = param(shape={1, LABEL_SIZE}, name="b0");
|
||||||
|
|
||||||
auto scores = dot(x, w) + b;
|
Expr scores = dot(x, w) + b;
|
||||||
auto lr = softmax(scores, axis=1, name="pred");
|
Expr lr = softmax(scores, axis=1, name="pred");
|
||||||
auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
||||||
cerr << "lr=" << lr.Debug() << endl;
|
cerr << "lr=" << Debug(lr.val().shape()) << endl;
|
||||||
|
|
||||||
int numofdata;
|
int numofdata;
|
||||||
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numofdata, IMAGE_SIZE);
|
||||||
@ -38,8 +38,8 @@ int main(int argc, char** argv) {
|
|||||||
tx.Load(images);
|
tx.Load(images);
|
||||||
ty.Load(labels);
|
ty.Load(labels);
|
||||||
|
|
||||||
cerr << "tx=" << tx.Debug() << endl;
|
cerr << "tx=" << Debug(tx.shape()) << endl;
|
||||||
cerr << "ty=" << ty.Debug() << endl;
|
cerr << "ty=" << Debug(ty.shape()) << endl;
|
||||||
|
|
||||||
x = tx;
|
x = tx;
|
||||||
y = ty;
|
y = ty;
|
||||||
@ -50,6 +50,8 @@ int main(int argc, char** argv) {
|
|||||||
std::cerr << "lr: " << Debug(lr.val().shape()) << endl;
|
std::cerr << "lr: " << Debug(lr.val().shape()) << endl;
|
||||||
std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ;
|
std::cerr << "Log-likelihood: " << Debug(graph.val().shape()) << endl ;
|
||||||
|
|
||||||
|
std::cerr << "scores=" << scores.val().Debug() << endl;
|
||||||
|
|
||||||
graph.backward();
|
graph.backward();
|
||||||
|
|
||||||
//std::cerr << graph["pred"].val()[0] << std::endl;
|
//std::cerr << graph["pred"].val()[0] << std::endl;
|
||||||
|
Loading…
Reference in New Issue
Block a user