diff --git a/marian/.project b/marian/.project
index 215485f6..e5b195c5 100644
--- a/marian/.project
+++ b/marian/.project
@@ -85,6 +85,11 @@
1
PARENT-1-PROJECT_LOC/src/marian.h
+
+ mnist.h
+ 1
+ PARENT-1-PROJECT_LOC/src/mnist.h
+
tensor.cu
1
diff --git a/src/tensor.cu b/src/tensor.cu
index 95c684b3..6048bb43 100644
--- a/src/tensor.cu
+++ b/src/tensor.cu
@@ -81,7 +81,12 @@ void Tensor::Load(const std::string &path)
}
strm.close();
+ Load(hostData);
+}
+void Tensor::Load(const std::vector &values)
+{
+ pimpl_->set(values);
}
}
diff --git a/src/tensor.h b/src/tensor.h
index e8ff92bf..a801cd2a 100644
--- a/src/tensor.h
+++ b/src/tensor.h
@@ -245,6 +245,7 @@ class Tensor {
}
void Load(const std::string &path);
+ void Load(const std::vector &values);
};
diff --git a/src/test.cu b/src/test.cu
index ed43f052..1e07bdb3 100644
--- a/src/test.cu
+++ b/src/test.cu
@@ -27,6 +27,11 @@ int main(int argc, char** argv) {
Tensor tx({500, 784}, 1);
Tensor ty({500, 10}, 1);
+
+ int numImg, imgSize;
+ vector images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg, imgSize);
+ vector labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");
+
cerr << "tx=" << tx.Debug() << endl;
cerr << "ty=" << ty.Debug() << endl;
@@ -56,6 +61,7 @@ int main(int argc, char** argv) {
//std::cerr << graph["pred"].val()[0] << std::endl;
*/
+ // XOR
Expr x = input(shape={whatevs, 2}, name="X");
Expr y = input(shape={whatevs, 2}, name="Y");
@@ -77,6 +83,7 @@ int main(int argc, char** argv) {
tx.Load("../examples/xor/train.txt");
ty.Load("../examples/xor/label.txt");
+
#if 0
hook0(graph);
graph.autodiff();