mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
Merge branch 'master' of https://github.com/emjotde/marian
This commit is contained in:
commit
d62b899ed2
@ -85,6 +85,11 @@
|
|||||||
<type>1</type>
|
<type>1</type>
|
||||||
<locationURI>PARENT-1-PROJECT_LOC/src/marian.h</locationURI>
|
<locationURI>PARENT-1-PROJECT_LOC/src/marian.h</locationURI>
|
||||||
</link>
|
</link>
|
||||||
|
<link>
|
||||||
|
<name>mnist.h</name>
|
||||||
|
<type>1</type>
|
||||||
|
<locationURI>PARENT-1-PROJECT_LOC/src/mnist.h</locationURI>
|
||||||
|
</link>
|
||||||
<link>
|
<link>
|
||||||
<name>tensor.cu</name>
|
<name>tensor.cu</name>
|
||||||
<type>1</type>
|
<type>1</type>
|
||||||
|
@ -59,23 +59,34 @@ inline std::vector<T> Tokenize( const std::string &input
|
|||||||
|
|
||||||
void Tensor::Load(const std::string &path)
|
void Tensor::Load(const std::string &path)
|
||||||
{
|
{
|
||||||
|
size_t totSize = std::accumulate(pimpl_->shape().begin(), pimpl_->shape().end(),
|
||||||
|
1, std::multiplies<int>());
|
||||||
|
cerr << "totSize=" << totSize << endl;
|
||||||
|
std::vector<float> hostData(totSize);
|
||||||
|
|
||||||
fstream strm;
|
fstream strm;
|
||||||
strm.open(path.c_str());
|
strm.open(path.c_str());
|
||||||
|
|
||||||
size_t lineNum = 0;
|
|
||||||
string line;
|
string line;
|
||||||
|
size_t ind = 0;
|
||||||
while ( getline (strm, line) )
|
while ( getline (strm, line) )
|
||||||
{
|
{
|
||||||
cerr << line << '\n';
|
cerr << line << '\n';
|
||||||
vector<Float> toks = Tokenize<Float>(line);
|
vector<Float> toks = Tokenize<Float>(line);
|
||||||
for (size_t i = 0; i < toks.size(); ++i) {
|
for (size_t i = 0; i < toks.size(); ++i) {
|
||||||
pimpl_->set(toks[i], lineNum, i);
|
hostData[ind] = toks[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
++lineNum;
|
++ind;
|
||||||
}
|
}
|
||||||
strm.close();
|
strm.close();
|
||||||
|
|
||||||
|
Load(hostData);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Tensor::Load(const std::vector<float> &values)
|
||||||
|
{
|
||||||
|
pimpl_->set(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
11
src/tensor.h
11
src/tensor.h
@ -152,10 +152,12 @@ class TensorImpl {
|
|||||||
thrust::fill(data_.begin(), data_.end(), value);
|
thrust::fill(data_.begin(), data_.end(), value);
|
||||||
}
|
}
|
||||||
|
|
||||||
void set(value_type value, size_t x, size_t y) {
|
void set(const std::vector<Float> &values) {
|
||||||
assert(shape().size() == 2);
|
size_t totSize = std::accumulate(shape().begin(), shape().end(),
|
||||||
size_t sizeRow = sizeof(Float) * shape()[1];
|
1, std::multiplies<int>());
|
||||||
data_[x + sizeRow * y] = value;
|
std::cerr << "totSize=" << totSize << " " << values.size() << std::endl;
|
||||||
|
assert(totSize == values.size());
|
||||||
|
thrust::copy(values.begin(), values.end(), data_.begin());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string Debug() const
|
std::string Debug() const
|
||||||
@ -247,6 +249,7 @@ class Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Load(const std::string &path);
|
void Load(const std::string &path);
|
||||||
|
void Load(const std::vector<float> &values);
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
14
src/test.cu
14
src/test.cu
@ -12,7 +12,7 @@ int main(int argc, char** argv) {
|
|||||||
using namespace marian;
|
using namespace marian;
|
||||||
using namespace keywords;
|
using namespace keywords;
|
||||||
|
|
||||||
/*
|
|
||||||
Expr x = input(shape={whatevs, 784}, name="X");
|
Expr x = input(shape={whatevs, 784}, name="X");
|
||||||
Expr y = input(shape={whatevs, 10}, name="Y");
|
Expr y = input(shape={whatevs, 10}, name="Y");
|
||||||
|
|
||||||
@ -27,6 +27,13 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
Tensor tx({500, 784}, 1);
|
Tensor tx({500, 784}, 1);
|
||||||
Tensor ty({500, 10}, 1);
|
Tensor ty({500, 10}, 1);
|
||||||
|
|
||||||
|
int numImg, imgSize;
|
||||||
|
vector<float> images = datasets::mnist::ReadImages("../examples/mnist/t10k-images-idx3-ubyte", numImg, imgSize);
|
||||||
|
vector<int> labels = datasets::mnist::ReadLabels("../examples/mnist/t10k-labels-idx1-ubyte");
|
||||||
|
tx.Load(images);
|
||||||
|
//ty.Load(labels);
|
||||||
|
|
||||||
cerr << "tx=" << tx.Debug() << endl;
|
cerr << "tx=" << tx.Debug() << endl;
|
||||||
cerr << "ty=" << ty.Debug() << endl;
|
cerr << "ty=" << ty.Debug() << endl;
|
||||||
|
|
||||||
@ -54,8 +61,10 @@ int main(int argc, char** argv) {
|
|||||||
graph.backward();
|
graph.backward();
|
||||||
|
|
||||||
//std::cerr << graph["pred"].val()[0] << std::endl;
|
//std::cerr << graph["pred"].val()[0] << std::endl;
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
|
// XOR
|
||||||
|
/*
|
||||||
Expr x = input(shape={whatevs, 2}, name="X");
|
Expr x = input(shape={whatevs, 2}, name="X");
|
||||||
Expr y = input(shape={whatevs, 2}, name="Y");
|
Expr y = input(shape={whatevs, 2}, name="Y");
|
||||||
|
|
||||||
@ -76,6 +85,7 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
tx.Load("../examples/xor/train.txt");
|
tx.Load("../examples/xor/train.txt");
|
||||||
ty.Load("../examples/xor/label.txt");
|
ty.Load("../examples/xor/label.txt");
|
||||||
|
*/
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
hook0(graph);
|
hook0(graph);
|
||||||
|
Loading…
Reference in New Issue
Block a user