set value in tensor

This commit is contained in:
Hieu Hoang 2016-09-13 18:55:48 +02:00
parent a8a664ca36
commit 8dcdf8f28a
2 changed files with 11 additions and 7 deletions

View File

@ -59,23 +59,29 @@ inline std::vector<T> Tokenize( const std::string &input
void Tensor::Load(const std::string &path) void Tensor::Load(const std::string &path)
{ {
size_t totSize = std::accumulate(pimpl_->shape().begin(), pimpl_->shape().end(),
1, std::multiplies<int>());
cerr << "totSize=" << totSize << endl;
std::vector<float> hostData(totSize);
fstream strm; fstream strm;
strm.open(path.c_str()); strm.open(path.c_str());
size_t lineNum = 0;
string line; string line;
size_t ind = 0;
while ( getline (strm, line) ) while ( getline (strm, line) )
{ {
cerr << line << '\n'; cerr << line << '\n';
vector<Float> toks = Tokenize<Float>(line); vector<Float> toks = Tokenize<Float>(line);
for (size_t i = 0; i < toks.size(); ++i) { for (size_t i = 0; i < toks.size(); ++i) {
pimpl_->set(toks[i], lineNum, i); hostData[ind] = toks[i];
} }
++lineNum; ++ind;
} }
strm.close(); strm.close();
} }
} }

View File

@ -152,10 +152,8 @@ class TensorImpl {
thrust::fill(data_.begin(), data_.end(), value); thrust::fill(data_.begin(), data_.end(), value);
} }
void set(value_type value, size_t x, size_t y) { void set(const std::vector<Float> &values) {
assert(shape().size() == 2); thrust::copy(values.begin(), values.end(), data_.begin());
size_t sizeRow = sizeof(Float) * shape()[1];
data_[x + sizeRow * y] = value;
} }
std::string Debug() const std::string Debug() const