This commit is contained in:
Roman Grundkiewicz 2016-09-13 16:12:27 +02:00
commit 12f1f47842
3 changed files with 32 additions and 10 deletions

View File

@ -6,7 +6,6 @@ with operator overloading.
In honour of Marian Rejewski, a Polish mathematician and
cryptologist.
Installation
------------
@ -25,6 +24,7 @@ Exporting some paths for CuDNN may be required (put it, for example, in your `.b
Compilation with `cmake > 3.5`:
mkdir build
cd build
cmake ..
make -j

View File

@ -5,6 +5,7 @@
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <numeric>
#include <sstream>
#include "definitions.h"
#include "exception.h"
@ -139,6 +140,17 @@ class TensorImpl {
void set(value_type value) {
thrust::fill(data_.begin(), data_.end(), value);
}
std::string Debug() const
{
std::stringstream strm;
assert(shape_.size());
strm << "shape=" << shape_[0];
for (size_t i = 1; i < shape_.size(); ++i) {
strm << "x" << shape_[i];
}
return strm.str();
}
};
template <typename Type>
@ -214,6 +226,12 @@ class Tensor {
operator bool() {
return pimpl_ != nullptr;
}
std::string Debug() const
{
return pimpl_->Debug();
}
};
}
}

View File

@ -1,23 +1,27 @@
#include "marian.h"
using namespace std;
int main(int argc, char** argv) {
using namespace marian;
using namespace keywords;
auto x = input(shape={whatevs, 784}, name="X");
auto y = input(shape={whatevs, 10}, name="Y");
Expr x = input(shape={whatevs, 784}, name="X");
Expr y = input(shape={whatevs, 10}, name="Y");
auto w = param(shape={784, 10}, name="W0");
auto b = param(shape={1, 10}, name="b0");
Expr w = param(shape={784, 10}, name="W0");
Expr b = param(shape={1, 10}, name="b0");
auto lr = softmax(dot(x, w) + b, axis=1, name="pred");
auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
Expr lr = softmax(dot(x, w) + b, axis=1, name="pred");
Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
Tensor tx({500, 784}, 1);
Tensor ty({500, 10}, 1);
cerr << "tx=" << tx.Debug();
cerr << "ty=" << ty.Debug();
x = tx;
y = ty;
@ -46,4 +50,4 @@ int main(int argc, char** argv) {
//opt.run();
return 0;
}
}