From 6935334b64bcbd351b883d7a010742225d1b86ab Mon Sep 17 00:00:00 2001 From: Hieu Hoang Date: Tue, 13 Sep 2016 18:33:44 +0200 Subject: [PATCH] xor --- src/test.cu | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/test.cu b/src/test.cu index c2b0d62e..ed43f052 100644 --- a/src/test.cu +++ b/src/test.cu @@ -12,6 +12,7 @@ int main(int argc, char** argv) { using namespace marian; using namespace keywords; + /* Expr x = input(shape={whatevs, 784}, name="X"); Expr y = input(shape={whatevs, 10}, name="Y"); @@ -53,7 +54,29 @@ int main(int argc, char** argv) { graph.backward(); //std::cerr << graph["pred"].val()[0] << std::endl; + */ + Expr x = input(shape={whatevs, 2}, name="X"); + Expr y = input(shape={whatevs, 2}, name="Y"); + + Expr w = param(shape={2, 1}, name="W0"); + Expr b = param(shape={1, 1}, name="b0"); + + Expr n5 = dot(x, w); + Expr n6 = n5 + b; + Expr lr = softmax(n6, axis=1, name="pred"); + cerr << "lr=" << lr.Debug() << endl; + + Expr graph = -mean(sum(y * log(lr), axis=1), axis=0, name="cost"); + + Tensor tx({4, 2}, 1); + Tensor ty({4, 1}, 1); + cerr << "tx=" << tx.Debug() << endl; + cerr << "ty=" << ty.Debug() << endl; + + tx.Load("../examples/xor/train.txt"); + ty.Load("../examples/xor/label.txt"); + #if 0 hook0(graph); graph.autodiff();