mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-03 20:13:47 +03:00
xor
This commit is contained in:
parent
8f25a1d4bf
commit
6935334b64
23
src/test.cu
23
src/test.cu
@ -12,6 +12,7 @@ int main(int argc, char** argv) {
|
||||
using namespace marian;
|
||||
using namespace keywords;
|
||||
|
||||
/*
|
||||
Expr x = input(shape={whatevs, 784}, name="X");
|
||||
Expr y = input(shape={whatevs, 10}, name="Y");
|
||||
|
||||
@ -53,6 +54,28 @@ int main(int argc, char** argv) {
|
||||
graph.backward();
|
||||
|
||||
//std::cerr << graph["pred"].val()[0] << std::endl;
|
||||
*/
|
||||
|
||||
Expr x = input(shape={whatevs, 2}, name="X");
|
||||
Expr y = input(shape={whatevs, 2}, name="Y");
|
||||
|
||||
Expr w = param(shape={2, 1}, name="W0");
|
||||
Expr b = param(shape={1, 1}, name="b0");
|
||||
|
||||
Expr n5 = dot(x, w);
|
||||
Expr n6 = n5 + b;
|
||||
Expr lr = softmax(n6, axis=1, name="pred");
|
||||
cerr << "lr=" << lr.Debug() << endl;
|
||||
|
||||
Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost");
|
||||
|
||||
Tensor tx({4, 2}, 1);
|
||||
Tensor ty({4, 1}, 1);
|
||||
cerr << "tx=" << tx.Debug() << endl;
|
||||
cerr << "ty=" << ty.Debug() << endl;
|
||||
|
||||
tx.Load("../examples/xor/train.txt");
|
||||
ty.Load("../examples/xor/label.txt");
|
||||
|
||||
#if 0
|
||||
hook0(graph);
|
||||
|
Loading…
Reference in New Issue
Block a user