mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
very cool
This commit is contained in:
parent
8a5f319bfb
commit
6a7c9316fc
134
src/mad.h
Normal file
134
src/mad.h
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <functional>
|
||||||
|
#include <vector>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include <boost/pool/pool.hpp>
|
||||||
|
|
||||||
|
namespace mad {
|
||||||
|
|
||||||
|
typedef float Tensor;
|
||||||
|
|
||||||
|
boost::pool<> p(sizeof(char));
|
||||||
|
|
||||||
|
struct Chainable {
|
||||||
|
Chainable() { }
|
||||||
|
virtual ~Chainable() { }
|
||||||
|
|
||||||
|
virtual void chain() { }
|
||||||
|
virtual void init_dependent() { }
|
||||||
|
virtual void set_zero_adjoint() { }
|
||||||
|
|
||||||
|
static inline void* operator new(size_t nbytes) {
|
||||||
|
return p.ordered_malloc(nbytes);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<Chainable*> stack;
|
||||||
|
|
||||||
|
class Vimpl : public Chainable {
|
||||||
|
public:
|
||||||
|
Vimpl(const Tensor& t) : val_{std::move(t)}, adj_{0} {
|
||||||
|
stack.push_back(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
~Vimpl() {};
|
||||||
|
|
||||||
|
virtual void init_dependent() { adj_ = 1; }
|
||||||
|
virtual void set_zero_adjoint() { adj_ = 0; }
|
||||||
|
|
||||||
|
const Tensor& val() const { return val_; };
|
||||||
|
Tensor& adj() { return adj_; };
|
||||||
|
|
||||||
|
protected:
|
||||||
|
const Tensor val_;
|
||||||
|
Tensor adj_;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef Vimpl* VimplPtr;
|
||||||
|
|
||||||
|
static void set_zero_all_adjoints() {
|
||||||
|
for(auto&& v : stack)
|
||||||
|
v->set_zero_adjoint();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void grad(Chainable* v) {
|
||||||
|
typedef std::vector<Chainable*>::reverse_iterator It;
|
||||||
|
v->init_dependent();
|
||||||
|
for(It it = stack.rbegin(); it != stack.rend(); ++it) {
|
||||||
|
(*it)->chain();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class Var {
|
||||||
|
public:
|
||||||
|
Var() : vimpl_{nullptr} {}
|
||||||
|
Var(const Tensor& t) : vimpl_{new Vimpl{t}} {}
|
||||||
|
Var(const VimplPtr& vimpl) : vimpl_{vimpl} {}
|
||||||
|
|
||||||
|
const Tensor& val() const {
|
||||||
|
return vimpl_->val();
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor& adj() {
|
||||||
|
return vimpl_->adj();
|
||||||
|
}
|
||||||
|
|
||||||
|
VimplPtr vimpl() const {
|
||||||
|
return vimpl_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void grad() {
|
||||||
|
mad::grad(vimpl_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
VimplPtr vimpl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct OpVimpl : public Vimpl {
|
||||||
|
OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { }
|
||||||
|
|
||||||
|
VimplPtr a_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct LogVimpl : public OpVimpl {
|
||||||
|
LogVimpl(VimplPtr a) : OpVimpl(std::log(a->val()), a) { }
|
||||||
|
|
||||||
|
void chain() {
|
||||||
|
a_->adj() += adj_ / a_->val();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
inline Var log(const Var& a) {
|
||||||
|
return Var(VimplPtr(new LogVimpl(a.vimpl())));
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
|
struct OpVimplVV : public Vimpl {
|
||||||
|
VimplPtr a_;
|
||||||
|
VimplPtr b_;
|
||||||
|
|
||||||
|
OpVimplVV(Tensor t, VimplPtr a, VimplPtr b)
|
||||||
|
: Vimpl(t), a_(a), b_(b) { }
|
||||||
|
};
|
||||||
|
|
||||||
|
struct PlusVimplVV : public OpVimplVV {
|
||||||
|
PlusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() + b->val(), a, b) { }
|
||||||
|
|
||||||
|
void chain() {
|
||||||
|
a_->adj() += adj_;
|
||||||
|
b_->adj() += adj_;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
inline Var operator+(const Var& a, const Var& b) {
|
||||||
|
return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl())));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
23
src/test.cpp
Normal file
23
src/test.cpp
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "mad.h"
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
|
||||||
|
using namespace mad;
|
||||||
|
{
|
||||||
|
Var x0 = 1, x1 = 2, x2 = 3;
|
||||||
|
|
||||||
|
auto y = x0 + x0 + log(x2) + x1;
|
||||||
|
|
||||||
|
std::vector<Var> x = { x0, x1, x2 };
|
||||||
|
|
||||||
|
|
||||||
|
set_zero_all_adjoints();
|
||||||
|
y.grad();
|
||||||
|
|
||||||
|
std::cerr << "y = " << y.val() << std::endl;
|
||||||
|
for(int i = 0; i < x.size(); ++i)
|
||||||
|
std::cerr << "dy/dx_" << i << " = " << x[i].adj() << std::endl;
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user