This commit is contained in:
Hieu Hoang 2016-09-13 19:07:12 +02:00
parent e354ea7570
commit 482c0df90a
4 changed files with 18 additions and 0 deletions

View File

@ -85,6 +85,11 @@
<type>1</type>
<locationURI>PARENT-1-PROJECT_LOC/src/marian.h</locationURI>
</link>
<link>
<name>mnist.h</name>
<type>1</type>
<locationURI>PARENT-1-PROJECT_LOC/src/mnist.h</locationURI>
</link>
<link>
<name>tensor.cu</name>
<type>1</type>

View File

@ -81,7 +81,12 @@ void Tensor::Load(const std::string &path)
}
strm.close();
Load(hostData);
}
void Tensor::Load(const std::vector<float> &values)
{
pimpl_->set(values);
}
}

View File

@ -245,6 +245,7 @@ class Tensor {
}
void Load(const std::string &path);
void Load(const std::vector<float> &values);
};

View File

@ -27,6 +27,11 @@ int main(int argc, char** argv) {
Tensor tx({500, 784}, 1);
Tensor ty({500, 10}, 1);
int numImg, imgSize;
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg, imgSize);
vector<int> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");
cerr << "tx=" << tx.Debug() << endl;
cerr << "ty=" << ty.Debug() << endl;
@ -56,6 +61,7 @@ int main(int argc, char** argv) {
//std::cerr << graph["pred"].val()[0] << std::endl;
*/
// XOR
Expr x = input(shape={whatevs, 2}, name="X");
Expr y = input(shape={whatevs, 2}, name="Y");
@ -77,6 +83,7 @@ int main(int argc, char** argv) {
tx.Load("../examples/xor/train.txt");
ty.Load("../examples/xor/label.txt");
#if 0
hook0(graph);
graph.autodiff();