diff --git a/src/a.out b/src/a.out
new file mode 100755
index 00000000..765a2aa9
Binary files /dev/null and b/src/a.out differ
diff --git a/src/mad.h b/src/mad.h
new file mode 100644
index 00000000..0c4b079e
--- /dev/null
+++ b/src/mad.h
@@ -0,0 +1,155 @@
+// mad.h -- minimal reverse-mode automatic differentiation for scalars.
+// Nodes live on a boost::pool arena and are recorded on a global tape
+// ("stack"); grad() replays the tape in reverse to accumulate adjoints.
+// NOTE(review): angle-bracket tokens (#include targets, std::vector<...>)
+// were stripped in the original patch; they are restored here -- confirm.
+#pragma once
+
+#include <cmath>
+#include <cstddef>
+#include <utility>
+#include <vector>
+
+#include <boost/pool/pool.hpp>
+
+namespace mad {
+
+// Placeholder scalar type; stands in for a real tensor.
+typedef float Tensor;
+
+// Arena for tape nodes (chunk size = 1 byte, so ordered_malloc(nbytes)
+// returns nbytes contiguous bytes). Nodes are never freed individually.
+boost::pool<> p(sizeof(char));
+
+// Base class of every node recorded on the tape.
+struct Chainable {
+  Chainable() { }
+  virtual ~Chainable() { }
+
+  virtual void chain() { }            // propagate adjoint to inputs
+  virtual void init_dependent() { }   // seed the output adjoint with 1
+  virtual void set_zero_adjoint() { } // reset adjoint before a backward pass
+
+  // All nodes are carved out of the arena above; nothing is deleted.
+  static inline void* operator new(size_t nbytes) {
+    return p.ordered_malloc(nbytes);
+  }
+};
+
+// Global tape: nodes in the order they were created (forward order).
+std::vector<Chainable*> stack;
+
+// A value node holding a (const) value and its adjoint.
+class Vimpl : public Chainable {
+  public:
+    Vimpl(const Tensor& t) : val_{std::move(t)}, adj_{0} {
+      stack.push_back(this); // register on the tape
+    }
+
+    ~Vimpl() {};
+
+    virtual void init_dependent() { adj_ = 1; }
+    virtual void set_zero_adjoint() { adj_ = 0; }
+
+    const Tensor& val() const { return val_; };
+    Tensor& adj() { return adj_; };
+
+  protected:
+    const Tensor val_;
+    Tensor adj_;
+};
+
+typedef Vimpl* VimplPtr;
+
+// Zero every adjoint on the tape (call before a fresh backward pass).
+static void set_zero_all_adjoints() {
+  for(auto&& v : stack)
+    v->set_zero_adjoint();
+}
+
+// Backward pass: seed v's adjoint with 1, then run chain() over the
+// tape in reverse creation order.
+static void grad(Chainable* v) {
+  typedef std::vector<Chainable*>::reverse_iterator It;
+  v->init_dependent();
+  for(It it = stack.rbegin(); it != stack.rend(); ++it) {
+    (*it)->chain();
+  }
+}
+
+// Lightweight user-facing handle to a tape node.
+class Var {
+  public:
+    Var() : vimpl_{nullptr} {}
+    Var(const Tensor& t) : vimpl_{new Vimpl{t}} {}
+    Var(const VimplPtr& vimpl) : vimpl_{vimpl} {}
+
+    const Tensor& val() const {
+      return vimpl_->val();
+    }
+
+    Tensor& adj() {
+      return vimpl_->adj();
+    }
+
+    VimplPtr vimpl() const {
+      return vimpl_;
+    }
+
+    // Run the backward pass with this variable as the output.
+    void grad() {
+      mad::grad(vimpl_);
+    }
+
+  private:
+    VimplPtr vimpl_;
+};
+
+// Node for a unary operation; remembers its single input.
+struct OpVimpl : public Vimpl {
+  OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { }
+
+  VimplPtr a_;
+};
+
+
+// d/dx log(x) = 1/x
+struct LogVimpl : public OpVimpl {
+  LogVimpl(VimplPtr a) : OpVimpl(std::log(a->val()), a) { }
+
+  void chain() {
+    a_->adj() += adj_ / a_->val();
+  }
+};
+
+inline Var log(const Var& a) {
+  return Var(VimplPtr(new LogVimpl(a.vimpl())));
+}
+
+///////////////////////////////////////////////////
+
+
+// Node for a binary operation; remembers both inputs.
+struct OpVimplVV : public Vimpl {
+  VimplPtr a_;
+  VimplPtr b_;
+
+  OpVimplVV(Tensor t, VimplPtr a, VimplPtr b)
+    : Vimpl(t), a_(a), b_(b) { }
+};
+
+// d/da (a+b) = 1, d/db (a+b) = 1
+struct PlusVimplVV : public OpVimplVV {
+  PlusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() + b->val(), a, b) { }
+
+  void chain() {
+    a_->adj() += adj_;
+    b_->adj() += adj_;
+  }
+};
+
+inline Var operator+(const Var& a, const Var& b) {
+  return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl())));
+}
+
+}
diff --git a/src/test.cpp b/src/test.cpp
new file mode 100644
index 00000000..6f29902c
--- /dev/null
+++ b/src/test.cpp
@@ -0,0 +1,25 @@
+#include <iostream>
+
+#include "mad.h"
+
+// Smoke test: y = x0 + x0 + log(x2) + x1; prints dy/dx_i via reverse mode.
+int main(int argc, char** argv) {
+
+  using namespace mad;
+  {
+    Var x0 = 1, x1 = 2, x2 = 3;
+
+    auto y = x0 + x0 + log(x2) + x1;
+
+    std::vector<Var> x = { x0, x1, x2 };
+
+
+    set_zero_all_adjoints();
+    y.grad();
+
+    std::cerr << "y = " << y.val() << std::endl;
+    // Expected: dy/dx_0 = 2, dy/dx_1 = 1, dy/dx_2 = 1/3.
+    for(std::size_t i = 0; i < x.size(); ++i)
+      std::cerr << "dy/dx_" << i << " = " << x[i].adj() << std::endl;
+  }
+}