This commit is contained in:
Hieu Hoang 2016-09-13 19:13:49 +02:00
parent 482c0df90a
commit 93eb3ca7ab
2 changed files with 11 additions and 4 deletions

View File

@ -153,7 +153,11 @@ class TensorImpl {
}
void set(const std::vector<Float> &values) {
thrust::copy(values.begin(), values.end(), data_.begin());
size_t totSize = std::accumulate(shape().begin(), shape().end(),
1, std::multiplies<int>());
std::cerr << "totSize=" << totSize << " " << values.size() << std::endl;
assert(totSize == values.size());
thrust::copy(values.begin(), values.end(), data_.begin());
}
std::string Debug() const

View File

@ -12,7 +12,7 @@ int main(int argc, char** argv) {
using namespace marian;
using namespace keywords;
/*
Expr x = input(shape={whatevs, 784}, name="X");
Expr y = input(shape={whatevs, 10}, name="Y");
@ -31,6 +31,8 @@ int main(int argc, char** argv) {
int numImg, imgSize;
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg, imgSize);
vector<int> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");
tx.Load(images);
//ty.Load(labels);
cerr << "tx=" << tx.Debug() << endl;
cerr << "ty=" << ty.Debug() << endl;
@ -59,9 +61,10 @@ int main(int argc, char** argv) {
graph.backward();
//std::cerr << graph["pred"].val()[0] << std::endl;
*/
// XOR
/*
Expr x = input(shape={whatevs, 2}, name="X");
Expr y = input(shape={whatevs, 2}, name="Y");
@ -82,7 +85,7 @@ int main(int argc, char** argv) {
tx.Load("../examples/xor/train.txt");
ty.Load("../examples/xor/label.txt");
*/
#if 0
hook0(graph);