mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
Merge branch 'master' of http://github.com/emjotde/Marian
This commit is contained in:
commit
12f1f47842
@ -6,7 +6,6 @@ with operator overloading.
|
||||
In honour of Marian Rejewski, a Polish mathematician and
|
||||
cryptologist.
|
||||
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
@ -25,6 +24,7 @@ Exporting some paths for CuDNN may be required (put it, for example, in your `.b
|
||||
Compilation with `cmake > 3.5`:
|
||||
|
||||
mkdir build
|
||||
cd build
|
||||
cmake ..
|
||||
make -j
|
||||
|
||||
|
18
src/tensor.h
18
src/tensor.h
@ -5,6 +5,7 @@
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/functional.h>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "definitions.h"
|
||||
#include "exception.h"
|
||||
@ -139,6 +140,17 @@ class TensorImpl {
|
||||
void set(value_type value) {
|
||||
thrust::fill(data_.begin(), data_.end(), value);
|
||||
}
|
||||
|
||||
std::string Debug() const
|
||||
{
|
||||
std::stringstream strm;
|
||||
assert(shape_.size());
|
||||
strm << "shape=" << shape_[0];
|
||||
for (size_t i = 1; i < shape_.size(); ++i) {
|
||||
strm << "x" << shape_[i];
|
||||
}
|
||||
return strm.str();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
@ -214,6 +226,12 @@ class Tensor {
|
||||
operator bool() {
|
||||
return pimpl_ != nullptr;
|
||||
}
|
||||
|
||||
std::string Debug() const
|
||||
{
|
||||
return pimpl_->Debug();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
16
src/test.cu
16
src/test.cu
@ -1,22 +1,26 @@
|
||||
|
||||
#include "marian.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
|
||||
using namespace marian;
|
||||
using namespace keywords;
|
||||
|
||||
auto x = input(shape={whatevs, 784}, name="X");
|
||||
auto y = input(shape={whatevs, 10}, name="Y");
|
||||
Expr x = input(shape={whatevs, 784}, name="X");
|
||||
Expr y = input(shape={whatevs, 10}, name="Y");
|
||||
|
||||
auto w = param(shape={784, 10}, name="W0");
|
||||
auto b = param(shape={1, 10}, name="b0");
|
||||
Expr w = param(shape={784, 10}, name="W0");
|
||||
Expr b = param(shape={1, 10}, name="b0");
|
||||
|
||||
auto lr = softmax(dot(x, w) + b, axis=1, name="pred");
|
||||
auto graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
||||
Expr lr = softmax(dot(x, w) + b, axis=1, name="pred");
|
||||
Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
||||
|
||||
Tensor tx({500, 784}, 1);
|
||||
Tensor ty({500, 10}, 1);
|
||||
cerr << "tx=" << tx.Debug();
|
||||
cerr << "ty=" << ty.Debug();
|
||||
|
||||
x = tx;
|
||||
y = ty;
|
||||
|
Loading…
Reference in New Issue
Block a user