move unary and binary operators to their own header files

This commit is contained in:
Hieu Hoang 2016-09-20 12:06:30 +01:00
parent d18d299009
commit 8677f99597
3 changed files with 120 additions and 630 deletions

View File

@ -23,6 +23,8 @@
#include "node.h"
#include "tensor_operators.h"
#include "node_operators_unary.h"
#include "node_operators_binary.h"
namespace marian {
@ -109,524 +111,8 @@ struct ParamNode : public Node {
bool initialized_;
};
/// Common base for graph nodes that are computed from exactly one operand.
struct UnaryNodeOp : public Node {
  /// The single operand this node was constructed from.
  ChainPtr a_;

  /// Forwards the operand's shape as `keywords::shape`, followed by any
  /// caller-supplied arguments, to the `Node` constructor.
  template <typename ...Args>
  UnaryNodeOp(ChainPtr a, Args ...args)
    : Node(keywords::shape=a->shape(), //@TODO: Check keywords?
           args...),
      a_(a)
  {}
};
/// Applies `Sigma` element-wise to the operand.
struct LogitNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  LogitNodeOp(Args ...args)
    : UnaryNodeOp(args...)
  {}

  /// val_ = Sigma(a).
  void forward() {
    Element(_1 = Sigma(_2), val_, a_->val());
  }

  /// Accumulates adj_ * val_ * (1 - val_) into the operand's gradient,
  /// reusing the output computed by forward().
  void backward() {
    Element(_1 += _2 * _3 * (1.0f - _3), a_->grad(), adj_, val_);
  }

  /// Intentionally empty.
  void check() {
  }

  /// Returns the Graphviz node declaration followed by the incoming edge.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("logit")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
/// Applies `Tanh` element-wise to the operand.
struct TanhNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  TanhNodeOp(Args ...args)
    : UnaryNodeOp(args...)
  {}

  /// val_ = Tanh(a).
  void forward() {
    Element(_1 = Tanh(_2), val_, a_->val());
  }

  /// Accumulates adj_ * (1 - val_ * val_) into the operand's gradient,
  /// reusing the output computed by forward().
  void backward() {
    Element(_1 += _2 * (1.0f - (_3 * _3)), a_->grad(), adj_, val_);
  }

  /// Returns the Graphviz node declaration followed by the incoming edge.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("tanh")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
/// Applies `ReLU` element-wise to the operand.
struct ReLUNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  ReLUNodeOp(Args ...args)
    : UnaryNodeOp(args...)
  {}

  /// val_ = ReLU(a).
  void forward() {
    Element(_1 = ReLU(_2), val_, a_->val());
  }

  /// Accumulates adj_ * ReLUback(a) into the operand's gradient; note that
  /// the operand's value (not this node's output) feeds `ReLUback`.
  void backward() {
    Element(_1 += _2 * ReLUback(_3), a_->grad(), adj_, a_->val());
  }

  /// Returns the Graphviz node declaration followed by the incoming edge.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("ReLU")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
// @TODO: slow and probably buggy
/// Dropout over the operand with a fixed rate argument of 0.5 and a seed
/// that starts at the construction time and advances on every forward pass.
struct DropoutNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  DropoutNodeOp(Args ...args)
    : UnaryNodeOp(args...),
      p_(0.5),
      seed_(time(0))
  {}

  /// Fills val_ via `Dropout`, then bumps the seed for the next call.
  void forward() {
    // Earlier element-wise variant, kept for reference:
    //   Element(_1 = Bernoulli(p_, (size_t)this) * _2, val_, a_->val())
    Dropout(val_, a_->val(), p_, seed_++);
  }

  /// Passes the adjoint through wherever this node's output is non-zero.
  void backward() {
    Element(_1 += _2 * (_3 != 0.0f), // transform non-zero to 1
            a_->grad(), adj_, val_);
  }

  /// Returns the Graphviz node declaration followed by the incoming edge.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("dropout")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }

 private:
  float p_;   // rate argument handed to Dropout()
  int seed_;  // seed argument handed to Dropout(); post-incremented per forward()
};
/// Softmax of the operand, computed in place on a copy of its value.
struct SoftmaxNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  SoftmaxNodeOp(Args ...args)
    : UnaryNodeOp(args...)
  {}

  /// B = softmax(A): copy A into val_, then normalise val_ in place.
  void forward() {
    thrust::copy(a_->val().begin(), a_->val().end(), val_.begin());
    Softmax(&val_);  // Safe version of softmax.
  }

  /// Delegates to `SoftmaxGrad`.
  ///
  /// Per row, the Jacobian-vector product is
  ///   J * dy = p .* (dy - avg*1),
  /// with avg = p'*dy and p the softmax output (probabilities).
  ///
  /// See sec. 2.5 of:
  ///   André F. T. Martins and Ramon Astudillo.
  ///   "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label
  ///   Classification." ICML 2016.
  ///   http://jmlr.org/proceedings/papers/v48/martins16.pdf
  void backward() {
    SoftmaxGrad(a_->grad(), adj_, val_);
  }

  /// Returns the Graphviz node declaration followed by the incoming edge.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("softmax")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
/// Argmax of the operand; the result keeps dimension 0 and has dimension 1
/// collapsed to size 1. No gradient is propagated.
struct ArgmaxNodeOp : public UnaryNodeOp {
  /// NOTE(review): UnaryNodeOp also passes `keywords::shape=a->shape()` ahead
  /// of these arguments, so the shape keyword is supplied twice; which one
  /// takes effect depends on the keyword lookup — confirm.
  template <typename ...Args>
  ArgmaxNodeOp(ChainPtr a, Args ...args)
    : UnaryNodeOp(a, keywords::shape=newShape(a), args...) { }

  /// val_ = argmax(A).
  void forward() {
    Argmax(&val_, &a_->val());
  }

  /// Intentionally empty: nothing is accumulated into the operand's gradient.
  void backward() {
  }

  /// Returns the operand's shape with dimension 1 set to 1.
  ///
  /// Declared `static` because it is invoked from the constructor's
  /// mem-initializer list before the base class has been initialized;
  /// calling a non-static member function at that point is undefined
  /// behaviour. Existing `obj.newShape(a)` call syntax keeps working.
  static Shape newShape(ChainPtr a) {
    Shape shape = a->shape();
    shape[1] = 1;
    return shape;
  }

  /// Returns the Graphviz node declaration followed by the incoming edge.
  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label="
       << label("argmax") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};
/// Applies `Log` element-wise to the operand.
struct LogNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  LogNodeOp(Args ...args)
    : UnaryNodeOp(args...)
  {}

  /// val_ = Log(a).
  void forward() {
    Element(_1 = Log(_2),
            val_, a_->val());
  }

  /// Accumulates adj_ * (1 / a) into the operand's gradient.
  void backward() {
    Element(_1 += _2 * (1.f / _3), a_->grad(), adj_, a_->val());
  }

  /// Returns the Graphviz node declaration followed by the incoming edge.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("log")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
/// Applies `Exp` element-wise to the operand.
struct ExpNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  ExpNodeOp(Args ...args)
    : UnaryNodeOp(args...)
  {}

  /// val_ = Exp(a).
  void forward() {
    Element(_1 = Exp(_2),
            val_, a_->val());
  }

  /// Accumulates adj_ * Exp(a) into the operand's gradient; the exponential
  /// is recomputed from the operand's value rather than read from val_.
  void backward() {
    Element(_1 += _2 * Exp(_3), a_->grad(), adj_, a_->val());
  }

  /// Returns the Graphviz node declaration followed by the incoming edge.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("exp")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
/// Element-wise negation of the operand.
struct NegNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  NegNodeOp(Args ...args)
    : UnaryNodeOp(args...)
  {}

  /// val_ = -a.
  void forward() {
    Element(_1 = -_2,
            val_, a_->val());
  }

  /// Accumulates -adj_ into the operand's gradient.
  void backward() {
    Element(_1 += -_2,
            a_->grad(), adj_);
  }

  /// Returns the Graphviz node declaration followed by the incoming edge.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("-")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
/******************************************************/
/// Common base for graph nodes that are computed from two operands.
struct BinaryNodeOp : public Node {
  ChainPtr a_;  ///< First operand.
  ChainPtr b_;  ///< Second operand.

  /// Stores both operands; every remaining argument goes to `Node` untouched
  /// (unlike UnaryNodeOp, no shape keyword is added here).
  template <typename ...Args>
  BinaryNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : Node(args...),
      a_(a),
      b_(b)
  {}
};
/*** Matrix Product ***/
/// Matrix product C = A * B of the two operands.
struct DotNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  DotNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b,
                   keywords::shape=newShape(a, b),
                   args...) { }

  /// Computes the result shape: A's shape with dimension 1 replaced by B's
  /// dimension 1. Throws (via UTIL_THROW_IF2) when A's dimension 1 differs
  /// from B's dimension 0.
  ///
  /// Declared `static` because it is invoked from the constructor's
  /// mem-initializer list before the base class has been initialized;
  /// calling a non-static member function at that point is undefined
  /// behaviour. Existing `obj.newShape(a, b)` call syntax keeps working.
  static Shape newShape(ChainPtr a, ChainPtr b) {
    Shape shape1 = a->shape();
    Shape shape2 = b->shape();
    UTIL_THROW_IF2(shape1[1] != shape2[0],
                   "matrix product requires dimensions to match");
    shape1[1] = shape2[1];
    return shape1;
  }

  void forward() {
    // C = A*B
    Prod(val_, a_->val(), b_->val(), false, false);
  }

  void backward() {
    // D is the adjoint, the matrix of derivatives
    // df/dA += D*B.T
    // df/dB += A.T*D
    // beta set to 1.0 in gemm, C = dot(A,B) + beta * C
    // to sum gradients from different graph parts
    Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
    Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
  }

  /// Returns the Graphviz node declaration followed by one edge per operand.
  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=" << label("×")
       << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};
/// Element-wise sum of the two operands; the result takes the first
/// operand's shape.
struct PlusNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  PlusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...)
  {}

  /// val_ = a + b.
  void forward() {
    Element(_1 = _2 + _3, val_, a_->val(), b_->val());
  }

  /// The adjoint is accumulated unchanged into both operands' gradients.
  void backward() {
    Element(_1 += _2, a_->grad(), adj_);
    Element(_1 += _2, b_->grad(), adj_);
  }

  /// Returns the Graphviz node declaration followed by one edge per operand.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("+")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    dot << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
/// Fused `ReLU(a + b)`; the result takes the first operand's shape.
struct ReLUPlusNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  ReLUPlusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...)
  {}

  /// val_ = ReLU(a + b).
  void forward() {
    Element(_1 = ReLU(_2 + _3), val_, a_->val(), b_->val());
  }

  /// Both operands receive the same contribution: adj_ * ReLUback(a + b).
  void backward() {
    Element(_1 += _2 * ReLUback(_3 + _4), a_->grad(), adj_, a_->val(), b_->val());
    Element(_1 += _2 * ReLUback(_3 + _4), b_->grad(), adj_, a_->val(), b_->val());
  }

  /// Returns the Graphviz node declaration followed by one edge per operand.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("ReLU<br/>+")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    dot << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
/// Element-wise difference a - b; the result takes the first operand's shape.
struct MinusNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  MinusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...)
  {}

  /// val_ = a - b.
  void forward() {
    Element(_1 = _2 - _3, val_, a_->val(), b_->val());
  }

  /// The adjoint is added to a's gradient and subtracted from b's.
  void backward() {
    Element(_1 += _2, a_->grad(), adj_);
    Element(_1 -= _2, b_->grad(), adj_);
  }

  /// Returns the Graphviz node declaration followed by one edge per operand.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("-")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    dot << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
/// Element-wise product a * b; the result takes the first operand's shape.
struct MultNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  MultNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...)
  {}

  /// val_ = a * b.
  void forward() {
    Element(_1 = _2 * _3, val_, a_->val(), b_->val());
  }

  /// Each operand's gradient receives the adjoint times the other operand.
  void backward() {
    Element(_1 += _2 * _3, a_->grad(), adj_, b_->val());
    Element(_1 += _2 * _3, b_->grad(), adj_, a_->val());
  }

  /// Returns the Graphviz node declaration followed by one edge per operand.
  virtual std::string graphviz() {
    std::stringstream dot;
    // NOTE(review): the label text is an empty string here — confirm whether
    // an operator symbol was intended.
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    dot << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
/// Element-wise quotient a / b; the result takes the first operand's shape.
struct DivNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  DivNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...)
  {}

  /// val_ = a / b.
  void forward() {
    Element(_1 = _2 / _3, val_, a_->val(), b_->val());
  }

  /// d/da contributes adj_ / b; d/db contributes -adj_ * a / (b * b).
  void backward() {
    Element(_1 += _2 * 1.0f / _3, a_->grad(), adj_, b_->val());
    Element(_1 -= _2 * _3 / (_4 * _4), b_->grad(), adj_, a_->val(), b_->val());
  }

  /// Returns the Graphviz node declaration followed by one edge per operand.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("÷")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    dot << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
// Cross-entropy node. It computes -b*log(softmax(a)), summing rowwise.
struct CrossEntropyNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  CrossEntropyNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b,
                   keywords::shape=newShape(a, b),
                   args...) { }

  /// Computes the result shape: A's shape with dimension 1 set to 1. Throws
  /// (via UTIL_THROW_IF2) unless dimensions 0 and 1 of both operands agree.
  ///
  /// Declared `static` because it is invoked from the constructor's
  /// mem-initializer list before the base class has been initialized;
  /// calling a non-static member function at that point is undefined
  /// behaviour. Existing `obj.newShape(a, b)` call syntax keeps working.
  static Shape newShape(ChainPtr a, ChainPtr b) {
    Shape shape1 = a->shape();
    Shape shape2 = b->shape();
    UTIL_THROW_IF2(shape1[0] != shape2[0] || shape1[1] != shape2[1],
                   "cross entropy requires dimensions to match");
    shape1[1] = 1;
    return shape1;
  }

  // We're caching the softmax probabilities here because we'll need them for
  // the backward computation.
  void forward() {
    // C = -dot(B, log(softmax(A))).
    // Reuse the probability buffer when it already exists; allocate it on
    // the first call.
    if (probs_) {
      probs_.set(0.0);
    } else {
      probs_.allocate(a_->val().shape(), 0.0);
    }
    thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
    Softmax(&probs_); // Safe version of softmax.
    // Element-wise -b * log(p), then reduced row by row into val_.
    Tensor result(a_->val().shape());
    Element(_1 = -_2 * Log(_3), result, b_->val(), probs_);
    SumRowwise(result, val_);
  }

  // @TODO: In most cases it's wasteful to compute the derivative with respect
  // to the second input which is typically an input node in the computation
  // graph. In general the backward functions can skip the computation of
  // gradients wrt input nodes.
  void backward() {
    // For each row, the first input derivative is given by adj * (p - y),
    // where y is the gold label distribution (e.g. one hot vector) and
    // p is the softmax output (probabilities).
    // The second input derivative is -adj*log(p).
    Tensor result(probs_.shape());
    // Compute first input derivative.
    Element(_1 = _2 - _3, result, probs_, b_->val());
    ScaleRowwise(result, adj_);
    Element(_1 += _2, a_->grad(), result);
    // Compute second input derivative.
    Element(_1 = -Log(_2), result, probs_); // @TODO: use a cached log here.
    ScaleRowwise(result, adj_);
    Element(_1 += _2, b_->grad(), result);
  }

  /// Returns the Graphviz node declaration followed by one edge per operand.
  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=" << label("x-ent")
       << ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }

 protected:
  Tensor probs_;  // softmax(a) cached by forward() for use in backward()
};
}

View File

@ -3,7 +3,6 @@
namespace marian {
struct BinaryNodeOp : public Node {
ChainPtr a_;
ChainPtr b_;
@ -11,18 +10,6 @@ struct BinaryNodeOp : public Node {
template <typename ...Args>
BinaryNodeOp(ChainPtr a, ChainPtr b, Args ...args)
: Node(args...), a_(a), b_(b) {}
// Placeholder for a numerical gradient check on binary nodes: it currently
// just runs the analytic backward() pass. `delta` is accepted but not used.
void backward_numeric(Float delta) {
using namespace std;
backward();
// Disabled debug dump of both operand gradients and the adjoint.
/*
cerr << "BinaryNodeOp::" << typeid(*this).name() << "::backward_numeric" << endl;
cerr << "a_->grad()=" << a_->grad().Debug() << endl;
cerr << "b_->grad()=" << b_->grad().Debug() << endl;
cerr << "adj_=" << adj_.Debug() << endl;
*/
}
};
/*** Matrix Product ***/
@ -52,7 +39,7 @@ struct DotNodeOp : public BinaryNodeOp {
// D is the adjoint, the matrix of derivatives
// df/dA += D*B.T
// df/dB += A.T*D
// beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
// beta set to 1.0 in gemm, C = dot(A,B) + beta * C
// to sum gradients from different graph parts
Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
@ -60,7 +47,8 @@ struct DotNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"×\", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("×")
<< ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
@ -87,7 +75,36 @@ struct PlusNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"+\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("+")
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
};
struct ReLUPlusNodeOp : public BinaryNodeOp {
template <typename ...Args>
ReLUPlusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
: BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }
void forward() {
Element(_1 = ReLU(_2 + _3),
val_, a_->val(), b_->val());
}
void backward() {
Element(_1 += _2 * ReLUback(_3 + _4),
a_->grad(), adj_, a_->val(), b_->val());
Element(_1 += _2 * ReLUback(_3 + _4),
b_->grad(), adj_, a_->val(), b_->val());
}
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("ReLU<br/>+")
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
@ -114,7 +131,8 @@ struct MinusNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("-")
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
@ -141,7 +159,8 @@ struct MultNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("")
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
@ -168,7 +187,8 @@ struct DivNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"÷\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("÷")
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
@ -233,8 +253,9 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"cross_entropy\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("x-ent")
<< ", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -244,5 +265,6 @@ struct CrossEntropyNodeOp : public BinaryNodeOp {
};
}

View File

@ -10,89 +10,6 @@ struct UnaryNodeOp : public Node {
UnaryNodeOp(ChainPtr a, Args ...args)
: Node(keywords::shape=a->shape(), //@TODO: Check keywords?
args...), a_(a) {}
// Debug helper: runs the analytic backward() pass, then recomputes the
// operand's gradient with a forward finite difference of step `delta`,
// prints both (plus their scaled difference) to stderr, and leaves the
// NUMERICAL gradient stored in a_->grad().
//
// Every buffer is sized from the operand's shape and val_/adj_ are indexed
// with that same size — assumes the node's output has as many elements as
// its input; TODO confirm for shape-changing nodes.
void backward_numeric(Float delta) {
using namespace std;
cerr << "UnaryNodeOp::" << typeid(*this).name() << "::backward_numeric()" << endl;
// NOTE(review): perturbing `input` below only affects forward() if Tensor
// copies share storage with a_->val() — confirm.
Tensor input = a_->val();
size_t totSize = GetTotalSize(input.shape());
// Snapshot the gradient accumulated so far, before this node contributes.
std::vector<float> preCalcGrad(totSize);
thrust::copy(a_->grad().begin(), a_->grad().end(), preCalcGrad.begin());
output("preCalcGrad", preCalcGrad);
// use df/dx to calc grad
backward();
//cerr << "orig a_->grad()=" << a_->grad().Debug() << endl;
// Snapshot the analytic result for the comparison at the end.
std::vector<float> diffGrad(totSize);
thrust::copy(a_->grad().begin(), a_->grad().end(), diffGrad.begin());
output("diffGrad", diffGrad);
// reset grad
thrust::copy(preCalcGrad.begin(), preCalcGrad.end(), a_->grad().begin());
//cerr << "reset a_->grad()=" << a_->grad().Debug() << endl;
// START CALC of numerical gradient
// new values
input.incr(delta);
forward();
//cerr << "input=" << input.Debug() << endl;
//cerr << "val_=" << val_.Debug() << endl;
std::vector<float> newVal(totSize);
thrust::copy(val_.begin(), val_.end(), newVal.begin());
//output("newVal", newVal);
// old values
// Undo the perturbation and recompute, so val_ ends up back at f(x).
input.incr(-delta);
forward();
//cerr << "input=" << input.Debug() << endl;
//cerr << "val_=" << val_.Debug() << endl;
std::vector<float> origVal(totSize);
thrust::copy(val_.begin(), val_.end(), origVal.begin());
//output("origVal", origVal);
// calc gradient
//cerr << "adj_=" << adj_.Debug() << endl;
std::vector<float> adjVec(totSize);
thrust::copy(adj_.begin(), adj_.end(), adjVec.begin());
std::vector<float> numericalGrad(totSize);
// Forward difference per element, scaled by the adjoint and added on top of
// the previously accumulated gradient.
for (size_t i = 0; i < totSize; ++i) {
numericalGrad[i] = preCalcGrad[i] + (adjVec[i] * (newVal[i] - origVal[i]) / delta);
}
output("numericalGrad", numericalGrad);
//cerr << "numeric a_->grad()=" << a_->grad().Debug() << endl;
// set grad results
thrust::copy(numericalGrad.begin(), numericalGrad.end(), a_->grad().begin());
// print out diff between diffGrad and numericalGrad
// NOTE(review): origGrad is filled below but never read afterwards.
std::vector<float> origGrad(totSize);
std::vector<float> diff(totSize);
thrust::copy(a_->grad().begin(), a_->grad().end(), origGrad.begin());
// NOTE(review): the reported difference is additionally divided by delta.
for (size_t i = 0; i < totSize; ++i) {
diff[i] = (diffGrad[i] - numericalGrad[i]) / delta;
}
output("diff", diff);
}
// Writes one diagnostic line to stderr: "<title> <count>:" followed by each
// value with a trailing space, then a newline.
void output(const std::string &title, const std::vector<float> &vec)
{
  std::cerr << title << " " << vec.size() << ":";
  for (const float value : vec) {
    std::cerr << value << " ";
  }
  std::cerr << std::endl;
}
};
struct LogitNodeOp : public UnaryNodeOp {
@ -106,7 +23,7 @@ struct LogitNodeOp : public UnaryNodeOp {
}
void backward() {
Element(_1 += _2 * _3 * (1 - _3),
Element(_1 += _2 * _3 * (1.0f - _3),
a_->grad(), adj_, val_);
}
@ -116,7 +33,8 @@ struct LogitNodeOp : public UnaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("logit")
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -134,19 +52,77 @@ struct TanhNodeOp : public UnaryNodeOp {
}
void backward() {
Element(_1 += _2 * (1 - _3 * _3),
Element(_1 += _2 * (1.0f - (_3 * _3)),
a_->grad(), adj_, val_);
}
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"tanh\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("tanh")
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
};
/// Applies `ReLU` element-wise to the operand.
struct ReLUNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  ReLUNodeOp(Args ...args)
    : UnaryNodeOp(args...)
  {}

  /// val_ = ReLU(a).
  void forward() {
    Element(_1 = ReLU(_2), val_, a_->val());
  }

  /// Accumulates adj_ * ReLUback(a) into the operand's gradient; note that
  /// the operand's value (not this node's output) feeds `ReLUback`.
  void backward() {
    Element(_1 += _2 * ReLUback(_3), a_->grad(), adj_, a_->val());
  }

  /// Returns the Graphviz node declaration followed by the incoming edge.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("ReLU")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }
};
// @TODO: slow and probably buggy
/// Dropout over the operand with a fixed rate argument of 0.5 and a seed
/// that starts at the construction time and advances on every forward pass.
struct DropoutNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  DropoutNodeOp(Args ...args)
    : UnaryNodeOp(args...),
      p_(0.5),
      seed_(time(0))
  {}

  /// Fills val_ via `Dropout`, then bumps the seed for the next call.
  void forward() {
    // Earlier element-wise variant, kept for reference:
    //   Element(_1 = Bernoulli(p_, (size_t)this) * _2, val_, a_->val())
    Dropout(val_, a_->val(), p_, seed_++);
  }

  /// Passes the adjoint through wherever this node's output is non-zero.
  void backward() {
    Element(_1 += _2 * (_3 != 0.0f), // transform non-zero to 1
            a_->grad(), adj_, val_);
  }

  /// Returns the Graphviz node declaration followed by the incoming edge.
  virtual std::string graphviz() {
    std::stringstream dot;
    dot << "\"" << this << "\" [shape=\"box\", label=" << label("dropout")
        << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    dot << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return dot.str();
  }

 private:
  float p_;   // rate argument handed to Dropout()
  int seed_;  // seed argument handed to Dropout(); post-incremented per forward()
};
struct SoftmaxNodeOp : public UnaryNodeOp {
template <typename ...Args>
SoftmaxNodeOp(Args ...args)
@ -175,7 +151,8 @@ struct SoftmaxNodeOp : public UnaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"softmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("softmax")
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -203,7 +180,8 @@ struct ArgmaxNodeOp : public UnaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label="
<< label("argmax") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -220,13 +198,14 @@ struct LogNodeOp : public UnaryNodeOp {
}
void backward() {
Element(_1 += _2 * 1.f / _3,
Element(_1 += _2 * (1.f / _3),
a_->grad(), adj_, a_->val());
}
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"log\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label="
<< label("log") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -249,7 +228,8 @@ struct ExpNodeOp : public UnaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"exp\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label=" << label("exp")
<< ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
@ -271,12 +251,14 @@ struct NegNodeOp : public UnaryNodeOp {
virtual std::string graphviz() {
std::stringstream ss;
ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << this << "\" [shape=\"box\", label="
<< label("-") << ", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
return ss.str();
};
};
}