mirror of https://github.com/marian-nmt/marian.git
split unary & binary operators into their own file
parent 8b55721206
commit 2a9d3de35a

@@ -22,6 +22,8 @@
// SOFTWARE.

#include "node.h"
#include "node_operators_unary.h"
#include "node_operators_binary.h"
#include "tensor_operators.h"

namespace marian {

@@ -107,435 +109,8 @@ struct ParamNode : public Node {
  bool initialized_;
};

struct UnaryNodeOp : public Node {
  ChainPtr a_;

  template <typename ...Args>
  UnaryNodeOp(ChainPtr a, Args ...args)
    : Node(keywords::shape=a->shape(), //@TODO: Check keywords?
           args...), a_(a) {}

  void backward_numeric() {
    backward();
  }
};

struct LogitNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  LogitNodeOp(Args ...args)
    : UnaryNodeOp(args...) { }

  void forward() {
    Element(_1 = Sigma(_2),
            val_, a_->val());
  }

  void backward() {
    Element(_1 += _2 * _3 * (1 - _3),
            a_->grad(), adj_, val_);
  }

  void check() {
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct TanhNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  TanhNodeOp(Args ...args)
    : UnaryNodeOp(args...) { }

  void forward() {
    Element(_1 = Tanh(_2),
            val_, a_->val());
  }

  void backward() {
    Element(_1 += _2 * (1 - _3 * _3),
            a_->grad(), adj_, val_);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"tanh\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct SoftmaxNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  SoftmaxNodeOp(Args ...args)
    : UnaryNodeOp(args...) { }

  void forward() {
    // B = softmax(A).
    thrust::copy(a_->val().begin(), a_->val().end(), val_.begin());
    // Safe version of softmax.
    Softmax(&val_);
  }

  void backward() {
    // For each row, the Jacobian times vector is given by:
    // J * dy = p .* (dy - avg*1)
    // where avg = p'*dy and p is the softmax output (probabilities).
    //
    // For more information, see sec. 2.5 of the following reference:
    // André F. T. Martins and Ramon Astudillo.
    // "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label
    // Classification." ICML 2016.
    // http://jmlr.org/proceedings/papers/v48/martins16.pdf

    SoftmaxGrad(a_->grad(), adj_, val_);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"softmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct ArgmaxNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  ArgmaxNodeOp(ChainPtr a, Args ...args)
    : UnaryNodeOp(a, keywords::shape=newShape(a), args...) { }

  void forward() {
    // B = argmax(A), computed row-wise.
    Argmax(&val_, &a_->val());
  }

  void backward() {
  }

  Shape newShape(ChainPtr a) {
    Shape shape = a->shape();
    shape[1] = 1;
    return shape;
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct LogNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  LogNodeOp(Args ...args)
    : UnaryNodeOp(args...) {}

  void forward() {
    Element(_1 = Log(_2), val_, a_->val());
  }

  void backward() {
    Element(_1 += _2 * 1.f / _3,
            a_->grad(), adj_, a_->val());
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"log\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct ExpNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  ExpNodeOp(Args ...args)
    : UnaryNodeOp(args...) { }

  void forward() {
    Element(_1 = Exp(_2), val_, a_->val());
  }

  void backward() {
    Element(_1 += _2 * Exp(_3),
            a_->grad(), adj_, a_->val());
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"exp\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct NegNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  NegNodeOp(Args ...args)
    : UnaryNodeOp(args...) { }

  void forward() {
    Element(_1 = -_2, val_, a_->val());
  }

  void backward() {
    Element(_1 += -_2, a_->grad(), adj_);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

/******************************************************/

struct BinaryNodeOp : public Node {
  ChainPtr a_;
  ChainPtr b_;

  template <typename ...Args>
  BinaryNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : Node(args...), a_(a), b_(b) {}

  void backward_numeric() {
    backward();
  }
};

/*** Matrix Product ***/

struct DotNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  DotNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b,
                   keywords::shape=newShape(a, b),
                   args...) { }

  Shape newShape(ChainPtr a, ChainPtr b) {
    Shape shape1 = a->shape();
    Shape shape2 = b->shape();
    UTIL_THROW_IF2(shape1[1] != shape2[0],
                   "matrix product requires dimensions to match");
    shape1[1] = shape2[1];
    return shape1;
  }

  void forward() {
    // C = A*B
    Prod(val_, a_->val(), b_->val(), false, false);
  }

  void backward() {
    // D is the adjoint, the matrix of derivatives
    // df/dA += D*B.T
    // df/dB += A.T*D
    // beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
    // to sum gradients from different graph parts
    Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
    Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"×\", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct PlusNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  PlusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }

  void forward() {
    Element(_1 = _2 + _3,
            val_, a_->val(), b_->val());
  }

  void backward() {
    Element(_1 += _2,
            a_->grad(), adj_);
    Element(_1 += _2,
            b_->grad(), adj_);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"+\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct MinusNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  MinusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }

  void forward() {
    Element(_1 = _2 - _3,
            val_, a_->val(), b_->val());
  }

  void backward() {
    Element(_1 += _2,
            a_->grad(), adj_);
    Element(_1 -= _2,
            b_->grad(), adj_);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct MultNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  MultNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }

  void forward() {
    Element(_1 = _2 * _3,
            val_, a_->val(), b_->val());
  }

  void backward() {
    Element(_1 += _2 * _3,
            a_->grad(), adj_, b_->val());
    Element(_1 += _2 * _3,
            b_->grad(), adj_, a_->val());
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"•\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct DivNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  DivNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }

  void forward() {
    Element(_1 = _2 / _3,
            val_, a_->val(), b_->val());
  }

  void backward() {
    Element(_1 += _2 * 1.0f / _3,
            a_->grad(), adj_, b_->val());
    Element(_1 -= _2 * _3 / (_4 * _4),
            b_->grad(), adj_, a_->val(), b_->val());
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"÷\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

// Cross-entropy node. It computes -b*log(softmax(a)), summing rowwise.
struct CrossEntropyNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  CrossEntropyNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b,
                   keywords::shape=newShape(a, b),
                   args...) { }

  Shape newShape(ChainPtr a, ChainPtr b) {
    Shape shape1 = a->shape();
    Shape shape2 = b->shape();
    UTIL_THROW_IF2(shape1[0] != shape2[0] || shape1[1] != shape2[1],
                   "cross entropy requires dimensions to match");
    shape1[1] = 1;
    return shape1;
  }

  // We're caching the softmax probabilities here because we'll need them for
  // the backward computation.
  void forward() {
    // C = -dot(B, log(softmax(A))).
    if (probs_) {
      probs_.set(0.0);
    } else {
      probs_.allocate(a_->val().shape(), 0.0);
    }
    thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
    Softmax(&probs_); // Safe version of softmax.
    Tensor result(a_->val().shape());
    Element(_1 = -_2 * Log(_3), result, b_->val(), probs_);
    SumRowwise(result, val_);
  }

  // @TODO: In most cases it's wasteful to compute the derivative with respect
  // to the second input which is typically an input node in the computation
  // graph. In general the backward functions can skip the computation of
  // gradients wrt input nodes.
  void backward() {
    // For each row, the first input derivative is given by adj * (p - y),
    // where y is the gold label distribution (e.g. one hot vector) and
    // p is the softmax output (probabilities).
    // The second input derivative is -adj*log(p).
    Tensor result(probs_.shape());

    // Compute first input derivative.
    Element(_1 = _2 - _3, result, probs_, b_->val());
    ScaleRowwise(result, adj_);
    Element(_1 += _2, a_->grad(), result);

    // Compute second input derivative.
    Element(_1 = -Log(_2), result, probs_); // @TODO: use a cached log here.
    ScaleRowwise(result, adj_);
    Element(_1 += _2, b_->grad(), result);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"cross_entropy\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }

protected:
  Tensor probs_;
};

}

src/node_operators_binary.h (new file)

@@ -0,0 +1,241 @@
#include "node.h"
#include "tensor_operators.h"

namespace marian {

struct BinaryNodeOp : public Node {
  ChainPtr a_;
  ChainPtr b_;

  template <typename ...Args>
  BinaryNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : Node(args...), a_(a), b_(b) {}

  void backward_numeric() {
    backward();
  }
};

/*** Matrix Product ***/

struct DotNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  DotNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b,
                   keywords::shape=newShape(a, b),
                   args...) { }

  Shape newShape(ChainPtr a, ChainPtr b) {
    Shape shape1 = a->shape();
    Shape shape2 = b->shape();
    UTIL_THROW_IF2(shape1[1] != shape2[0],
                   "matrix product requires dimensions to match");
    shape1[1] = shape2[1];
    return shape1;
  }

  void forward() {
    // C = A*B
    Prod(val_, a_->val(), b_->val(), false, false);
  }

  void backward() {
    // D is the adjoint, the matrix of derivatives
    // df/dA += D*B.T
    // df/dB += A.T*D
    // beta set to 1.0 in gemm, C = alpha * dot(A,B) + beta * C
    // to sum gradients from different graph parts
    Prod(a_->grad(), adj_, b_->val(), false, true, 1.0);
    Prod(b_->grad(), a_->val(), adj_, true, false, 1.0);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"×\", style=\"filled\", fillcolor=\"orange\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};
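
// Aside (editorial sketch, not part of this commit or of marian's API):
// the backward pass above relies on the gemm gradient identities
//   df/dA = D * B^T   and   df/dB = A^T * D,
// accumulated with beta = 1.0 so that gradients from different parts of the
// graph sum up. The same accumulation on plain row-major arrays, with
// hypothetical names; A: m x k, B: k x n, adjoint D: m x n:

#include <vector>

void dotBackwardSketch(const std::vector<float>& A, const std::vector<float>& B,
                       const std::vector<float>& D,
                       std::vector<float>& gradA, std::vector<float>& gradB,
                       int m, int k, int n) {
  // gradA[i][j] += sum_c D[i][c] * B[j][c]   (i.e. D * B^T)
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < k; ++j)
      for (int c = 0; c < n; ++c)
        gradA[i * k + j] += D[i * n + c] * B[j * n + c];
  // gradB[i][j] += sum_r A[r][i] * D[r][j]   (i.e. A^T * D)
  for (int i = 0; i < k; ++i)
    for (int j = 0; j < n; ++j)
      for (int r = 0; r < m; ++r)
        gradB[i * n + j] += A[r * k + i] * D[r * n + j];
}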

struct PlusNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  PlusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }

  void forward() {
    Element(_1 = _2 + _3,
            val_, a_->val(), b_->val());
  }

  void backward() {
    Element(_1 += _2,
            a_->grad(), adj_);
    Element(_1 += _2,
            b_->grad(), adj_);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"+\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct MinusNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  MinusNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }

  void forward() {
    Element(_1 = _2 - _3,
            val_, a_->val(), b_->val());
  }

  void backward() {
    Element(_1 += _2,
            a_->grad(), adj_);
    Element(_1 -= _2,
            b_->grad(), adj_);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct MultNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  MultNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }

  void forward() {
    Element(_1 = _2 * _3,
            val_, a_->val(), b_->val());
  }

  void backward() {
    Element(_1 += _2 * _3,
            a_->grad(), adj_, b_->val());
    Element(_1 += _2 * _3,
            b_->grad(), adj_, a_->val());
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"•\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct DivNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  DivNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b, keywords::shape=a->shape(), args...) { }

  void forward() {
    Element(_1 = _2 / _3,
            val_, a_->val(), b_->val());
  }

  void backward() {
    Element(_1 += _2 * 1.0f / _3,
            a_->grad(), adj_, b_->val());
    Element(_1 -= _2 * _3 / (_4 * _4),
            b_->grad(), adj_, a_->val(), b_->val());
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"÷\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};
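
// Aside (editorial sketch, not marian API): DivNodeOp::backward() is the
// elementwise quotient rule, d(a/b)/da = 1/b and d(a/b)/db = -a/b^2, which
// is exactly what the two Element(...) calls above compute:

#include <vector>

void divBackwardSketch(const std::vector<float>& a, const std::vector<float>& b,
                       const std::vector<float>& adj,
                       std::vector<float>& gradA, std::vector<float>& gradB) {
  for (size_t i = 0; i < a.size(); ++i) {
    gradA[i] += adj[i] / b[i];                  // _1 += _2 * 1.0f / _3
    gradB[i] -= adj[i] * a[i] / (b[i] * b[i]);  // _1 -= _2 * _3 / (_4 * _4)
  }
}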

// Cross-entropy node. It computes -b*log(softmax(a)), summing rowwise.
struct CrossEntropyNodeOp : public BinaryNodeOp {
  template <typename ...Args>
  CrossEntropyNodeOp(ChainPtr a, ChainPtr b, Args ...args)
    : BinaryNodeOp(a, b,
                   keywords::shape=newShape(a, b),
                   args...) { }

  Shape newShape(ChainPtr a, ChainPtr b) {
    Shape shape1 = a->shape();
    Shape shape2 = b->shape();
    UTIL_THROW_IF2(shape1[0] != shape2[0] || shape1[1] != shape2[1],
                   "cross entropy requires dimensions to match");
    shape1[1] = 1;
    return shape1;
  }

  // We're caching the softmax probabilities here because we'll need them for
  // the backward computation.
  void forward() {
    // C = -dot(B, log(softmax(A))).
    if (probs_) {
      probs_.set(0.0);
    } else {
      probs_.allocate(a_->val().shape(), 0.0);
    }
    thrust::copy(a_->val().begin(), a_->val().end(), probs_.begin());
    Softmax(&probs_); // Safe version of softmax.
    Tensor result(a_->val().shape());
    Element(_1 = -_2 * Log(_3), result, b_->val(), probs_);
    SumRowwise(result, val_);
  }

  // @TODO: In most cases it's wasteful to compute the derivative with respect
  // to the second input which is typically an input node in the computation
  // graph. In general the backward functions can skip the computation of
  // gradients wrt input nodes.
  void backward() {
    // For each row, the first input derivative is given by adj * (p - y),
    // where y is the gold label distribution (e.g. one hot vector) and
    // p is the softmax output (probabilities).
    // The second input derivative is -adj*log(p).
    Tensor result(probs_.shape());

    // Compute first input derivative.
    Element(_1 = _2 - _3, result, probs_, b_->val());
    ScaleRowwise(result, adj_);
    Element(_1 += _2, a_->grad(), result);

    // Compute second input derivative.
    Element(_1 = -Log(_2), result, probs_); // @TODO: use a cached log here.
    ScaleRowwise(result, adj_);
    Element(_1 += _2, b_->grad(), result);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"cross_entropy\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl;
    ss << "\"" << b_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }

protected:
  Tensor probs_;
};
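
// Aside (editorial sketch, not marian API): the first-input gradient above is
// adj * (p - y) because, for a row of logits a with p = softmax(a),
//   d/da_j [ -sum_i y_i * log p_i ] = p_j * (sum_i y_i) - y_j,
// which reduces to p_j - y_j when y sums to 1 (e.g. a one-hot label row).
// One row of both gradients on plain vectors, assuming probs = softmax(a):

#include <cmath>
#include <vector>

void crossEntropyBackwardRowSketch(const std::vector<float>& probs,
                                   const std::vector<float>& y, float adj,
                                   std::vector<float>& gradA,
                                   std::vector<float>& gradY) {
  for (size_t i = 0; i < probs.size(); ++i) {
    gradA[i] += adj * (probs[i] - y[i]);     // first input: adj * (p - y)
    gradY[i] += adj * -std::log(probs[i]);   // second input: -adj * log(p)
  }
}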

}

src/node_operators_unary.h (new file)

@@ -0,0 +1,203 @@
#include "node.h"
#include "tensor_operators.h"

namespace marian {

struct UnaryNodeOp : public Node {
  ChainPtr a_;

  template <typename ...Args>
  UnaryNodeOp(ChainPtr a, Args ...args)
    : Node(keywords::shape=a->shape(), //@TODO: Check keywords?
           args...), a_(a) {}

  void backward_numeric() {
    backward();
  }
};
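
// Aside (editorial sketch): backward_numeric() is currently a stub that just
// delegates to backward(). A numeric gradient check would typically compare
// the analytic gradient against a central finite difference,
//   df/dx_i ≈ (f(x + h*e_i) - f(x - h*e_i)) / (2h).
// A generic sketch of that check, independent of marian's types:

#include <functional>
#include <vector>

std::vector<float> numericGradSketch(
    const std::function<float(const std::vector<float>&)>& f,
    std::vector<float> x, float h = 1e-3f) {
  std::vector<float> g(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    float orig = x[i];
    x[i] = orig + h; float fp = f(x);   // f(x + h*e_i)
    x[i] = orig - h; float fm = f(x);   // f(x - h*e_i)
    x[i] = orig;
    g[i] = (fp - fm) / (2.0f * h);
  }
  return g;
}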

struct LogitNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  LogitNodeOp(Args ...args)
    : UnaryNodeOp(args...) { }

  void forward() {
    Element(_1 = Sigma(_2),
            val_, a_->val());
  }

  void backward() {
    Element(_1 += _2 * _3 * (1 - _3),
            a_->grad(), adj_, val_);
  }

  void check() {
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"logit\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};
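
// Aside (editorial sketch): the backward pass above reuses the cached forward
// output val_, since sigma'(x) = sigma(x) * (1 - sigma(x)); no exp() has to
// be recomputed. The same pair on plain floats:

#include <cmath>

float sigmaSketch(float x) { return 1.0f / (1.0f + std::exp(-x)); }
float sigmaGradSketch(float adj, float v) {   // v = sigma(x), cached forward value
  return adj * v * (1.0f - v);
}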

struct TanhNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  TanhNodeOp(Args ...args)
    : UnaryNodeOp(args...) { }

  void forward() {
    Element(_1 = Tanh(_2),
            val_, a_->val());
  }

  void backward() {
    Element(_1 += _2 * (1 - _3 * _3),
            a_->grad(), adj_, val_);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"tanh\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct SoftmaxNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  SoftmaxNodeOp(Args ...args)
    : UnaryNodeOp(args...) { }

  void forward() {
    // B = softmax(A).
    thrust::copy(a_->val().begin(), a_->val().end(), val_.begin());
    // Safe version of softmax.
    Softmax(&val_);
  }

  void backward() {
    // For each row, the Jacobian times vector is given by:
    // J * dy = p .* (dy - avg*1)
    // where avg = p'*dy and p is the softmax output (probabilities).
    //
    // For more information, see sec. 2.5 of the following reference:
    // André F. T. Martins and Ramon Astudillo.
    // "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label
    // Classification." ICML 2016.
    // http://jmlr.org/proceedings/papers/v48/martins16.pdf

    SoftmaxGrad(a_->grad(), adj_, val_);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"softmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};
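
// Aside (editorial sketch, not the SoftmaxGrad implementation): the comment
// above gives the row-wise Jacobian-vector product J * dy = p .* (dy - avg*1)
// with avg = p' * dy. Spelled out for one row on plain vectors:

#include <vector>

void softmaxBackwardRowSketch(const std::vector<float>& p,
                              const std::vector<float>& adj,
                              std::vector<float>& grad) {
  float avg = 0.0f;                       // avg = p' * dy
  for (size_t i = 0; i < p.size(); ++i)
    avg += p[i] * adj[i];
  for (size_t i = 0; i < p.size(); ++i)
    grad[i] += p[i] * (adj[i] - avg);     // p .* (dy - avg)
}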

struct ArgmaxNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  ArgmaxNodeOp(ChainPtr a, Args ...args)
    : UnaryNodeOp(a, keywords::shape=newShape(a), args...) { }

  void forward() {
    // B = argmax(A), computed row-wise.
    Argmax(&val_, &a_->val());
  }

  void backward() {
  }

  Shape newShape(ChainPtr a) {
    Shape shape = a->shape();
    shape[1] = 1;
    return shape;
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"argmax\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct LogNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  LogNodeOp(Args ...args)
    : UnaryNodeOp(args...) {}

  void forward() {
    Element(_1 = Log(_2), val_, a_->val());
  }

  void backward() {
    Element(_1 += _2 * 1.f / _3,
            a_->grad(), adj_, a_->val());
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"log\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct ExpNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  ExpNodeOp(Args ...args)
    : UnaryNodeOp(args...) { }

  void forward() {
    Element(_1 = Exp(_2), val_, a_->val());
  }

  void backward() {
    Element(_1 += _2 * Exp(_3),
            a_->grad(), adj_, a_->val());
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"exp\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

struct NegNodeOp : public UnaryNodeOp {
  template <typename ...Args>
  NegNodeOp(Args ...args)
    : UnaryNodeOp(args...) { }

  void forward() {
    Element(_1 = -_2, val_, a_->val());
  }

  void backward() {
    Element(_1 += -_2, a_->grad(), adj_);
  }

  virtual std::string graphviz() {
    std::stringstream ss;
    ss << "\"" << this << "\" [shape=\"box\", label=\"-\", style=\"filled\", fillcolor=\"yellow\"]" << std::endl;
    ss << "\"" << a_ << "\" -> \"" << this << "\"" << std::endl << std::endl;
    return ss.str();
  }
};

}