mirror of
https://github.com/marian-nmt/marian.git
synced 2024-09-17 09:47:34 +03:00
more operators
This commit is contained in:
parent
b9e26509dd
commit
28ba2628f9
@ -1,5 +1,7 @@
|
|||||||
# Marian*
|
# Marian*
|
||||||
|
|
||||||
Parallel Automatic Differentiation Library
|
A C++ GPU-specific parallel automatic differentiation library
|
||||||
|
with operator overloading.
|
||||||
|
|
||||||
`*` = in honour of Marian Rejewski, a Polish mathematician and cryptologist.
|
'*' = in honour of Marian Rejewski, a Polish mathematician and
|
||||||
|
cryptologist.
|
||||||
|
80
src/marian.h
80
src/marian.h
@ -89,6 +89,8 @@ class Var {
|
|||||||
VimplPtr vimpl_;
|
VimplPtr vimpl_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////
|
||||||
|
|
||||||
struct OpVimpl : public Vimpl {
|
struct OpVimpl : public Vimpl {
|
||||||
OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { }
|
OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { }
|
||||||
|
|
||||||
@ -108,6 +110,44 @@ inline Var log(const Var& a) {
|
|||||||
return Var(VimplPtr(new LogVimpl(a.vimpl())));
|
return Var(VimplPtr(new LogVimpl(a.vimpl())));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ExpVimpl : public OpVimpl {
|
||||||
|
ExpVimpl(VimplPtr a) : OpVimpl(std::exp(a->val()), a) { }
|
||||||
|
|
||||||
|
void chain() {
|
||||||
|
a_->grad() += adj_ * std::exp(a_->val());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Builds a Var backed by a new ExpVimpl node wrapping the operand's vimpl.
inline Var exp(const Var& a) {
  VimplPtr node(new ExpVimpl(a.vimpl()));
  return Var(node);
}
|
||||||
|
|
||||||
|
struct NegVimpl : public OpVimpl {
|
||||||
|
NegVimpl(VimplPtr a) : OpVimpl(-a->val(), a) { }
|
||||||
|
|
||||||
|
void chain() {
|
||||||
|
a_->grad() -= adj_;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Unary minus on a Var: builds a Var backed by a new NegVimpl node
// wrapping the operand's vimpl.
inline Var operator-(const Var& a) {
  VimplPtr node(new NegVimpl(a.vimpl()));
  return Var(node);
}
|
||||||
|
|
||||||
|
// @TODO: take care of large exponents
|
||||||
|
struct SigmaVimpl : public OpVimpl {
|
||||||
|
SigmaVimpl(VimplPtr a) : OpVimpl(1.f / (1.f + std::exp(-a->val())), a) { }
|
||||||
|
|
||||||
|
void chain() {
|
||||||
|
Tensor l = 1.f / (1.f + std::exp(-a_->val()));
|
||||||
|
a_->grad() += adj_ * l * (1 - l);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Builds a Var backed by a new SigmaVimpl node wrapping the operand's vimpl.
inline Var sigma(const Var& a) {
  VimplPtr node(new SigmaVimpl(a.vimpl()));
  return Var(node);
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////
|
///////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
@ -132,4 +172,44 @@ inline Var operator+(const Var& a, const Var& b) {
|
|||||||
return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl())));
|
return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl())));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct MinusVimplVV : public OpVimplVV {
|
||||||
|
MinusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() - b->val(), a, b) { }
|
||||||
|
|
||||||
|
void chain() {
|
||||||
|
a_->grad() -= adj_;
|
||||||
|
b_->grad() -= adj_;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Binary minus on two Vars: builds a Var backed by a new MinusVimplVV node
// over both operands' vimpls.
inline Var operator-(const Var& a, const Var& b) {
  VimplPtr node(new MinusVimplVV(a.vimpl(), b.vimpl()));
  return Var(node);
}
|
||||||
|
|
||||||
|
struct MultVimplVV : public OpVimplVV {
|
||||||
|
MultVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() * b->val(), a, b) { }
|
||||||
|
|
||||||
|
void chain() {
|
||||||
|
a_->grad() += adj_ * b_->val();
|
||||||
|
b_->grad() += adj_ * a_->val();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Multiplication of two Vars: builds a Var backed by a new MultVimplVV node
// over both operands' vimpls.
inline Var operator*(const Var& a, const Var& b) {
  VimplPtr node(new MultVimplVV(a.vimpl(), b.vimpl()));
  return Var(node);
}
|
||||||
|
|
||||||
|
struct DivVimplVV : public OpVimplVV {
|
||||||
|
DivVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() / b->val(), a, b) { }
|
||||||
|
|
||||||
|
void chain() {
|
||||||
|
a_->grad() += adj_ / b_->val();
|
||||||
|
b_->grad() += adj_ * (a_->val() / (b_->val() * b_->val()));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Division of two Vars: builds a Var backed by a new DivVimplVV node
// over both operands' vimpls.
inline Var operator/(const Var& a, const Var& b) {
  VimplPtr node(new DivVimplVV(a.vimpl(), b.vimpl()));
  return Var(node);
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
@ -29,7 +29,7 @@ int main(int argc, char** argv) {
|
|||||||
Var y1 = layer(10, x1);
|
Var y1 = layer(10, x1);
|
||||||
Var y2 = layer(rand() % 20 + 1, x2);
|
Var y2 = layer(rand() % 20 + 1, x2);
|
||||||
|
|
||||||
Var y = y1 + log(y2);
|
Var y = sigma(log(y1) / log(y2));
|
||||||
|
|
||||||
set_zero_all_adjoints();
|
set_zero_all_adjoints();
|
||||||
y.calc_gradients();
|
y.calc_gradients();
|
||||||
|
Loading…
Reference in New Issue
Block a user