From 28ba2628f9e8a684113a53e1eabb05d0a190fffb Mon Sep 17 00:00:00 2001
From: Marcin Junczys-Dowmunt
Date: Wed, 4 May 2016 11:07:44 +0100
Subject: [PATCH] more operators

---
 README.md    |  6 ++--
 src/marian.h | 98 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 src/test.cpp |  2 +-
 3 files changed, 103 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 8804e187..0fff5e0f 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,7 @@
 # Marian*
 
-Parallel Automatic Differentiation Library
+A C++ GPU-specific parallel automatic differentiation library
+with operator overloading.
 
-`*` = in honour of Marian Rejewski, a Polish mathematician and cryptologist.
+'*' = in honour of Marian Rejewski, a Polish mathematician and
+cryptologist.
diff --git a/src/marian.h b/src/marian.h
index 4a14e340..18d561c4 100644
--- a/src/marian.h
+++ b/src/marian.h
@@ -89,6 +91,8 @@ class Var {
     VimplPtr vimpl_;
 };
 
+///////////////////////////////////////////////////
+
 struct OpVimpl : public Vimpl {
   OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { }
 
@@ -108,6 +110,53 @@ inline Var log(const Var& a) {
   return Var(VimplPtr(new LogVimpl(a.vimpl())));
 }
 
+// Vimpl for exp(a). chain() adds adj_ * exp(a) to a's gradient,
+// since d/da exp(a) = exp(a).
+struct ExpVimpl : public OpVimpl {
+  ExpVimpl(VimplPtr a) : OpVimpl(std::exp(a->val()), a) { }
+
+  void chain() {
+    a_->grad() += adj_ * std::exp(a_->val());
+  }
+};
+
+// Returns a Var backed by an ExpVimpl built from a's Vimpl.
+inline Var exp(const Var& a) {
+  return Var(VimplPtr(new ExpVimpl(a.vimpl())));
+}
+
+// Vimpl for -a. chain() subtracts adj_ from a's gradient,
+// since d/da (-a) = -1.
+struct NegVimpl : public OpVimpl {
+  NegVimpl(VimplPtr a) : OpVimpl(-a->val(), a) { }
+
+  void chain() {
+    a_->grad() -= adj_;
+  }
+};
+
+// Unary minus: returns a Var backed by a NegVimpl built from a's Vimpl.
+inline Var operator-(const Var& a) {
+  return Var(VimplPtr(new NegVimpl(a.vimpl())));
+}
+
+// Vimpl for the logistic function s(a) = 1 / (1 + exp(-a)).
+// chain() adds adj_ * s(a) * (1 - s(a)) to a's gradient.
+// @TODO: take care of large exponents
+struct SigmaVimpl : public OpVimpl {
+  SigmaVimpl(VimplPtr a) : OpVimpl(1.f / (1.f + std::exp(-a->val())), a) { }
+
+  void chain() {
+    Tensor l = 1.f / (1.f + std::exp(-a_->val()));
+    a_->grad() += adj_ * l * (1 - l);
+  }
+};
+
+// Returns a Var backed by a SigmaVimpl built from a's Vimpl.
+inline Var sigma(const Var& a) {
+  return Var(VimplPtr(new SigmaVimpl(a.vimpl())));
+}
+
 ///////////////////////////////////////////////////
 
 
@@ -132,4 +181,53 @@ inline Var operator+(const Var& a, const Var& b) {
   return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl())));
 }
 
+// Vimpl for a - b. chain() adds adj_ to a's gradient and subtracts it
+// from b's, since d(a-b)/da = +1 and d(a-b)/db = -1.
+struct MinusVimplVV : public OpVimplVV {
+  MinusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() - b->val(), a, b) { }
+
+  void chain() {
+    a_->grad() += adj_;
+    b_->grad() -= adj_;
+  }
+};
+
+// Binary minus: returns a Var backed by a MinusVimplVV over both operands.
+inline Var operator-(const Var& a, const Var& b) {
+  return Var(VimplPtr(new MinusVimplVV(a.vimpl(), b.vimpl())));
+}
+
+// Vimpl for a * b. chain() adds adj_ * b to a's gradient and
+// adj_ * a to b's gradient (product rule).
+struct MultVimplVV : public OpVimplVV {
+  MultVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() * b->val(), a, b) { }
+
+  void chain() {
+    a_->grad() += adj_ * b_->val();
+    b_->grad() += adj_ * a_->val();
+  }
+};
+
+// Product: returns a Var backed by a MultVimplVV over both operands.
+inline Var operator*(const Var& a, const Var& b) {
+  return Var(VimplPtr(new MultVimplVV(a.vimpl(), b.vimpl())));
+}
+
+// Vimpl for a / b. chain() adds adj_ / b to a's gradient and subtracts
+// adj_ * a / b^2 from b's, since d(a/b)/db = -a / b^2.
+struct DivVimplVV : public OpVimplVV {
+  DivVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() / b->val(), a, b) { }
+
+  void chain() {
+    a_->grad() += adj_ / b_->val();
+    b_->grad() -= adj_ * (a_->val() / (b_->val() * b_->val()));
+  }
+};
+
+// Quotient: returns a Var backed by a DivVimplVV over both operands.
+inline Var operator/(const Var& a, const Var& b) {
+  return Var(VimplPtr(new DivVimplVV(a.vimpl(), b.vimpl())));
+}
+
+
 }
\ No newline at end of file
diff --git a/src/test.cpp b/src/test.cpp
index 728fb3ef..795d845d 100644
--- a/src/test.cpp
+++ b/src/test.cpp
@@ -29,7 +29,7 @@ int main(int argc, char** argv) {
   Var y1 = layer(10, x1);
   Var y2 = layer(rand() % 20 + 1, x2);
 
-  Var y = y1 + log(y2);
+  Var y = sigma(log(y1) / log(y2));
 
   set_zero_all_adjoints();
   y.calc_gradients();