more operators

Marcin Junczys-Dowmunt 2016-05-04 11:07:44 +01:00
parent b9e26509dd
commit 28ba2628f9
3 changed files with 85 additions and 3 deletions


@ -1,5 +1,7 @@
 # Marian*
-Parallel Automatic Differentiation Library
+A C++ GPU-specific parallel automatic differentiation library
+with operator overloading.
 
-`*` = in honour of Marian Rejewski, a Polish mathematician and cryptologist.
+'*' = in honour of Marian Rejewski, a Polish mathematician and
+cryptologist.
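
For context, a minimal sketch of the operator-overloading style the new description refers to, using the operators added in this commit. The float constructor for Var and the val()/grad() accessors are assumptions not shown in this diff; set_zero_all_adjoints() and calc_gradients() do appear in the test program further down.

#include <iostream>

int main() {
  // Leaves carry values; every node also stores an adjoint (gradient).
  Var a = 2.0f;                    // assumed float constructor
  Var b = 3.0f;
  Var y = sigma(log(a) / log(b));  // overloaded operators build the tape

  set_zero_all_adjoints();
  y.calc_gradients();              // reverse sweep starting from y

  std::cout << y.val() << " " << a.grad() << std::endl;  // assumed accessors
  return 0;
}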


@ -89,6 +89,8 @@ class Var {
  VimplPtr vimpl_;
};

///////////////////////////////////////////////////

struct OpVimpl : public Vimpl {
  OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { }
@ -108,6 +110,44 @@ inline Var log(const Var& a) {
  return Var(VimplPtr(new LogVimpl(a.vimpl())));
}
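
Every chain() implementation that follows applies the same reverse-mode rule: the adjoint of the argument is incremented by the node's own adjoint times the local partial derivative. For a unary node $y = f(a)$, with bars denoting adjoints:

\bar{a} \mathrel{+}= \bar{y} \cdot f'(a)

Here adj_ holds $\bar{y}$ and a_->grad() accumulates $\bar{a}$; accumulating with += rather than assigning is what keeps gradients correct when one variable feeds several nodes.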
struct ExpVimpl : public OpVimpl {
  ExpVimpl(VimplPtr a) : OpVimpl(std::exp(a->val()), a) { }

  // d/dx exp(x) = exp(x); recomputed here, though the cached
  // forward value could be reused instead.
  void chain() {
    a_->grad() += adj_ * std::exp(a_->val());
  }
};

inline Var exp(const Var& a) {
  return Var(VimplPtr(new ExpVimpl(a.vimpl())));
}
struct NegVimpl : public OpVimpl {
  NegVimpl(VimplPtr a) : OpVimpl(-a->val(), a) { }

  // d/dx (-x) = -1
  void chain() {
    a_->grad() -= adj_;
  }
};

inline Var operator-(const Var& a) {
  return Var(VimplPtr(new NegVimpl(a.vimpl())));
}
// @TODO: take care of large exponents
struct SigmaVimpl : public OpVimpl {
  SigmaVimpl(VimplPtr a) : OpVimpl(1.f / (1.f + std::exp(-a->val())), a) { }

  // d/dx sigma(x) = sigma(x) * (1 - sigma(x))
  void chain() {
    Tensor l = 1.f / (1.f + std::exp(-a_->val()));
    a_->grad() += adj_ * l * (1.f - l);
  }
};

inline Var sigma(const Var& a) {
  return Var(VimplPtr(new SigmaVimpl(a.vimpl())));
}
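
The @TODO above the sigma node concerns overflow: for large negative inputs, std::exp(-x) overflows in single precision. A sketch of the usual remedy, branching on the sign so the exponent is never positive; this helper is illustrative only and not part of the commit:

#include <cmath>

inline float stable_sigma(float x) {
  if (x >= 0.f) {
    return 1.f / (1.f + std::exp(-x));  // exp(-x) <= 1, cannot overflow
  } else {
    float e = std::exp(x);              // exp(x) < 1, cannot overflow
    return e / (1.f + e);               // algebraically the same sigmoid
  }
}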
///////////////////////////////////////////////////
@ -132,4 +172,44 @@ inline Var operator+(const Var& a, const Var& b) {
  return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl())));
}
struct MinusVimplVV : public OpVimplVV {
  MinusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() - b->val(), a, b) { }

  // d/da (a - b) = 1, d/db (a - b) = -1
  void chain() {
    a_->grad() += adj_;
    b_->grad() -= adj_;
  }
};

inline Var operator-(const Var& a, const Var& b) {
  return Var(VimplPtr(new MinusVimplVV(a.vimpl(), b.vimpl())));
}
struct MultVimplVV : public OpVimplVV {
  MultVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() * b->val(), a, b) { }

  // product rule: d/da (a * b) = b, d/db (a * b) = a
  void chain() {
    a_->grad() += adj_ * b_->val();
    b_->grad() += adj_ * a_->val();
  }
};

inline Var operator*(const Var& a, const Var& b) {
  return Var(VimplPtr(new MultVimplVV(a.vimpl(), b.vimpl())));
}
struct DivVimplVV : public OpVimplVV {
  DivVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() / b->val(), a, b) { }

  // quotient rule: d/da (a / b) = 1 / b, d/db (a / b) = -a / b^2
  void chain() {
    a_->grad() += adj_ / b_->val();
    b_->grad() -= adj_ * (a_->val() / (b_->val() * b_->val()));
  }
};

inline Var operator/(const Var& a, const Var& b) {
  return Var(VimplPtr(new DivVimplVV(a.vimpl(), b.vimpl())));
}
}
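
A quick way to validate the signs in the chain() rules above (the subtraction and division cases are the easy ones to get wrong) is a central finite-difference check. A sketch under the same assumptions as before: scalar Vars, a float constructor, val()/grad() accessors, and calc_gradients() seeding the output adjoint with 1.

#include <cassert>
#include <cmath>

float f_ref(float a, float b) { return (a - b) / (a * b); }

void check_gradients() {
  Var a = 1.5f, b = 0.5f;      // assumed float constructor
  Var y = (a - b) / (a * b);   // exercises operator-, operator*, operator/

  set_zero_all_adjoints();
  y.calc_gradients();

  const float h = 1e-3f;
  float num = (f_ref(1.5f + h, 0.5f) - f_ref(1.5f - h, 0.5f)) / (2.f * h);
  assert(std::abs(a.grad() - num) < 1e-2f);  // reverse mode vs. numeric
}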


@ -29,7 +29,7 @@ int main(int argc, char** argv) {
   Var y1 = layer(10, x1);
   Var y2 = layer(rand() % 20 + 1, x2);
-  Var y = y1 + log(y2);
+  Var y = sigma(log(y1) / log(y2));
 
   set_zero_all_adjoints();
   y.calc_gradients();
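
For reference, the gradient the reverse sweep accumulates for the changed line, treating y1 and y2 as scalars and writing $u = \log(y_1)/\log(y_2)$ — a worked chain-rule expansion, not part of the commit:

\frac{\partial y}{\partial y_1} = \sigma(u)\bigl(1 - \sigma(u)\bigr)\,\frac{1}{\log(y_2)}\,\frac{1}{y_1},
\qquad
\frac{\partial y}{\partial y_2} = -\sigma(u)\bigl(1 - \sigma(u)\bigr)\,\frac{\log(y_1)}{\log(y_2)^2}\,\frac{1}{y_2}

Each factor corresponds to one chain() call along the tape: SigmaVimpl, then DivVimplVV, then LogVimpl.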