This commit is contained in:
romang 2016-09-14 09:23:57 +02:00
commit d62b899ed2
4 changed files with 38 additions and 9 deletions

View File

@ -85,6 +85,11 @@
<type>1</type> <type>1</type>
<locationURI>PARENT-1-PROJECT_LOC/src/marian.h</locationURI> <locationURI>PARENT-1-PROJECT_LOC/src/marian.h</locationURI>
</link> </link>
<link>
<name>mnist.h</name>
<type>1</type>
<locationURI>PARENT-1-PROJECT_LOC/src/mnist.h</locationURI>
</link>
<link> <link>
<name>tensor.cu</name> <name>tensor.cu</name>
<type>1</type> <type>1</type>

View File

@ -59,23 +59,34 @@ inline std::vector<T> Tokenize( const std::string &input
void Tensor::Load(const std::string &path) void Tensor::Load(const std::string &path)
{ {
size_t totSize = std::accumulate(pimpl_->shape().begin(), pimpl_->shape().end(),
1, std::multiplies<int>());
cerr << "totSize=" << totSize << endl;
std::vector<float> hostData(totSize);
fstream strm; fstream strm;
strm.open(path.c_str()); strm.open(path.c_str());
size_t lineNum = 0;
string line; string line;
size_t ind = 0;
while ( getline (strm, line) ) while ( getline (strm, line) )
{ {
cerr << line << '\n'; cerr << line << '\n';
vector<Float> toks = Tokenize<Float>(line); vector<Float> toks = Tokenize<Float>(line);
for (size_t i = 0; i < toks.size(); ++i) { for (size_t i = 0; i < toks.size(); ++i) {
pimpl_->set(toks[i], lineNum, i); hostData[ind] = toks[i];
} }
++lineNum; ++ind;
} }
strm.close(); strm.close();
Load(hostData);
}
void Tensor::Load(const std::vector<float> &values)
{
pimpl_->set(values);
} }
} }

View File

@ -152,10 +152,12 @@ class TensorImpl {
thrust::fill(data_.begin(), data_.end(), value); thrust::fill(data_.begin(), data_.end(), value);
} }
void set(value_type value, size_t x, size_t y) { void set(const std::vector<Float> &values) {
assert(shape().size() == 2); size_t totSize = std::accumulate(shape().begin(), shape().end(),
size_t sizeRow = sizeof(Float) * shape()[1]; 1, std::multiplies<int>());
data_[x + sizeRow * y] = value; std::cerr << "totSize=" << totSize << " " << values.size() << std::endl;
assert(totSize == values.size());
thrust::copy(values.begin(), values.end(), data_.begin());
} }
std::string Debug() const std::string Debug() const
@ -247,6 +249,7 @@ class Tensor {
} }
void Load(const std::string &path); void Load(const std::string &path);
void Load(const std::vector<float> &values);
}; };

View File

@ -12,7 +12,7 @@ int main(int argc, char** argv) {
using namespace marian; using namespace marian;
using namespace keywords; using namespace keywords;
/*
Expr x = input(shape={whatevs, 784}, name="X"); Expr x = input(shape={whatevs, 784}, name="X");
Expr y = input(shape={whatevs, 10}, name="Y"); Expr y = input(shape={whatevs, 10}, name="Y");
@ -27,6 +27,13 @@ int main(int argc, char** argv) {
Tensor tx({500, 784}, 1); Tensor tx({500, 784}, 1);
Tensor ty({500, 10}, 1); Tensor ty({500, 10}, 1);
int numImg, imgSize;
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg, imgSize);
vector<int> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");
tx.Load(images);
//ty.Load(labels);
cerr << "tx=" << tx.Debug() << endl; cerr << "tx=" << tx.Debug() << endl;
cerr << "ty=" << ty.Debug() << endl; cerr << "ty=" << ty.Debug() << endl;
@ -54,8 +61,10 @@ int main(int argc, char** argv) {
graph.backward(); graph.backward();
//std::cerr << graph["pred"].val()[0] << std::endl; //std::cerr << graph["pred"].val()[0] << std::endl;
*/
// XOR
/*
Expr x = input(shape={whatevs, 2}, name="X"); Expr x = input(shape={whatevs, 2}, name="X");
Expr y = input(shape={whatevs, 2}, name="Y"); Expr y = input(shape={whatevs, 2}, name="Y");
@ -76,6 +85,7 @@ int main(int argc, char** argv) {
tx.Load("../examples/xor/train.txt"); tx.Load("../examples/xor/train.txt");
ty.Load("../examples/xor/label.txt"); ty.Load("../examples/xor/label.txt");
*/
#if 0 #if 0
hook0(graph); hook0(graph);