This commit is contained in:
Hieu Hoang 2016-09-13 15:19:39 +02:00
parent 7b9555b1de
commit 08055a2662
2 changed files with 24 additions and 2 deletions

View File

@ -5,6 +5,7 @@
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/functional.h> #include <thrust/functional.h>
#include <numeric> #include <numeric>
#include <sstream>
#include "definitions.h" #include "definitions.h"
#include "exception.h" #include "exception.h"
@ -139,6 +140,17 @@ class TensorImpl {
void set(value_type value) { void set(value_type value) {
thrust::fill(data_.begin(), data_.end(), value); thrust::fill(data_.begin(), data_.end(), value);
} }
std::string Debug() const
{
std::stringstream strm;
assert(shape_.size());
strm << "shape=" << shape_[0];
for (size_t i = 1; i < shape_.size(); ++i) {
strm << "x" << shape_[i];
}
return strm.str();
}
}; };
template <typename Type> template <typename Type>
@ -214,6 +226,12 @@ class Tensor {
operator bool() { operator bool() {
return pimpl_ != nullptr; return pimpl_ != nullptr;
} }
std::string Debug() const
{
return pimpl_->Debug();
}
}; };
} }

View File

@ -1,6 +1,8 @@
#include "marian.h" #include "marian.h"
using namespace std;
int main(int argc, char** argv) { int main(int argc, char** argv) {
using namespace marian; using namespace marian;
@ -17,6 +19,8 @@ int main(int argc, char** argv) {
Tensor tx({500, 784}, 1); Tensor tx({500, 784}, 1);
Tensor ty({500, 10}, 1); Tensor ty({500, 10}, 1);
cerr << "tx=" << tx.Debug();
cerr << "ty=" << ty.Debug();
x = tx; x = tx;
y = ty; y = ty;