very cool

This commit is contained in:
Marcin Junczys-Dowmunt 2016-05-03 22:57:28 +01:00
parent 8a5f319bfb
commit 6a7c9316fc
3 changed files with 157 additions and 0 deletions

BIN
src/a.out Executable file

Binary file not shown.

134
src/mad.h Normal file
View File

@ -0,0 +1,134 @@
#pragma once
#include <memory>
#include <functional>
#include <vector>
#include <cmath>
#include <boost/pool/pool.hpp>
namespace mad {
typedef float Tensor;
boost::pool<> p(sizeof(char));
struct Chainable {
Chainable() { }
virtual ~Chainable() { }
virtual void chain() { }
virtual void init_dependent() { }
virtual void set_zero_adjoint() { }
static inline void* operator new(size_t nbytes) {
return p.ordered_malloc(nbytes);
}
};
std::vector<Chainable*> stack;
class Vimpl : public Chainable {
public:
Vimpl(const Tensor& t) : val_{std::move(t)}, adj_{0} {
stack.push_back(this);
}
~Vimpl() {};
virtual void init_dependent() { adj_ = 1; }
virtual void set_zero_adjoint() { adj_ = 0; }
const Tensor& val() const { return val_; };
Tensor& adj() { return adj_; };
protected:
const Tensor val_;
Tensor adj_;
};
typedef Vimpl* VimplPtr;
static void set_zero_all_adjoints() {
for(auto&& v : stack)
v->set_zero_adjoint();
}
static void grad(Chainable* v) {
typedef std::vector<Chainable*>::reverse_iterator It;
v->init_dependent();
for(It it = stack.rbegin(); it != stack.rend(); ++it) {
(*it)->chain();
}
}
class Var {
public:
Var() : vimpl_{nullptr} {}
Var(const Tensor& t) : vimpl_{new Vimpl{t}} {}
Var(const VimplPtr& vimpl) : vimpl_{vimpl} {}
const Tensor& val() const {
return vimpl_->val();
}
Tensor& adj() {
return vimpl_->adj();
}
VimplPtr vimpl() const {
return vimpl_;
}
void grad() {
mad::grad(vimpl_);
}
private:
VimplPtr vimpl_;
};
struct OpVimpl : public Vimpl {
OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { }
VimplPtr a_;
};
struct LogVimpl : public OpVimpl {
LogVimpl(VimplPtr a) : OpVimpl(std::log(a->val()), a) { }
void chain() {
a_->adj() += adj_ / a_->val();
}
};
inline Var log(const Var& a) {
return Var(VimplPtr(new LogVimpl(a.vimpl())));
}
///////////////////////////////////////////////////
struct OpVimplVV : public Vimpl {
VimplPtr a_;
VimplPtr b_;
OpVimplVV(Tensor t, VimplPtr a, VimplPtr b)
: Vimpl(t), a_(a), b_(b) { }
};
struct PlusVimplVV : public OpVimplVV {
PlusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() + b->val(), a, b) { }
void chain() {
a_->adj() += adj_;
b_->adj() += adj_;
}
};
inline Var operator+(const Var& a, const Var& b) {
return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl())));
}
}

23
src/test.cpp Normal file
View File

@ -0,0 +1,23 @@
#include <iostream>
#include "mad.h"
int main(int argc, char** argv) {
using namespace mad;
{
Var x0 = 1, x1 = 2, x2 = 3;
auto y = x0 + x0 + log(x2) + x1;
std::vector<Var> x = { x0, x1, x2 };
set_zero_all_adjoints();
y.grad();
std::cerr << "y = " << y.val() << std::endl;
for(int i = 0; i < x.size(); ++i)
std::cerr << "dy/dx_" << i << " = " << x[i].adj() << std::endl;
}
}