mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-05 01:31:46 +03:00
more operators
This commit is contained in:
parent
b9e26509dd
commit
28ba2628f9
@ -1,5 +1,7 @@
|
||||
# Marian*
|
||||
|
||||
Parallel Automatic Differentiation Library
|
||||
A C++ gpu-specific parallel automatic differentiation library
|
||||
with operator overloading.
|
||||
|
||||
`*` = in honour of Marian Rejewski, a Polish mathematician and cryptologist.
|
||||
'*' = in honour of Marian Rejewski, a Polish mathematician and
|
||||
cryptologist.
|
||||
|
80
src/marian.h
80
src/marian.h
@ -89,6 +89,8 @@ class Var {
|
||||
VimplPtr vimpl_;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////
|
||||
|
||||
struct OpVimpl : public Vimpl {
|
||||
OpVimpl(const Tensor& t, VimplPtr a) : Vimpl(t), a_(a) { }
|
||||
|
||||
@ -108,6 +110,44 @@ inline Var log(const Var& a) {
|
||||
return Var(VimplPtr(new LogVimpl(a.vimpl())));
|
||||
}
|
||||
|
||||
struct ExpVimpl : public OpVimpl {
|
||||
ExpVimpl(VimplPtr a) : OpVimpl(std::exp(a->val()), a) { }
|
||||
|
||||
void chain() {
|
||||
a_->grad() += adj_ * std::exp(a_->val());
|
||||
}
|
||||
};
|
||||
|
||||
inline Var exp(const Var& a) {
|
||||
return Var(VimplPtr(new ExpVimpl(a.vimpl())));
|
||||
}
|
||||
|
||||
struct NegVimpl : public OpVimpl {
|
||||
NegVimpl(VimplPtr a) : OpVimpl(-a->val(), a) { }
|
||||
|
||||
void chain() {
|
||||
a_->grad() -= adj_;
|
||||
}
|
||||
};
|
||||
|
||||
inline Var operator-(const Var& a) {
|
||||
return Var(VimplPtr(new NegVimpl(a.vimpl())));
|
||||
}
|
||||
|
||||
// @TODO: take care of large exponents
|
||||
struct SigmaVimpl : public OpVimpl {
|
||||
SigmaVimpl(VimplPtr a) : OpVimpl(1.f / (1.f + std::exp(-a->val())), a) { }
|
||||
|
||||
void chain() {
|
||||
Tensor l = 1.f / (1.f + std::exp(-a_->val()));
|
||||
a_->grad() += adj_ * l * (1 - l);
|
||||
}
|
||||
};
|
||||
|
||||
inline Var sigma(const Var& a) {
|
||||
return Var(VimplPtr(new SigmaVimpl(a.vimpl())));
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////
|
||||
|
||||
|
||||
@ -132,4 +172,44 @@ inline Var operator+(const Var& a, const Var& b) {
|
||||
return Var(VimplPtr(new PlusVimplVV(a.vimpl(), b.vimpl())));
|
||||
}
|
||||
|
||||
struct MinusVimplVV : public OpVimplVV {
|
||||
MinusVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() - b->val(), a, b) { }
|
||||
|
||||
void chain() {
|
||||
a_->grad() -= adj_;
|
||||
b_->grad() -= adj_;
|
||||
}
|
||||
};
|
||||
|
||||
inline Var operator-(const Var& a, const Var& b) {
|
||||
return Var(VimplPtr(new MinusVimplVV(a.vimpl(), b.vimpl())));
|
||||
}
|
||||
|
||||
struct MultVimplVV : public OpVimplVV {
|
||||
MultVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() * b->val(), a, b) { }
|
||||
|
||||
void chain() {
|
||||
a_->grad() += adj_ * b_->val();
|
||||
b_->grad() += adj_ * a_->val();
|
||||
}
|
||||
};
|
||||
|
||||
inline Var operator*(const Var& a, const Var& b) {
|
||||
return Var(VimplPtr(new MultVimplVV(a.vimpl(), b.vimpl())));
|
||||
}
|
||||
|
||||
struct DivVimplVV : public OpVimplVV {
|
||||
DivVimplVV(VimplPtr a, VimplPtr b) : OpVimplVV(a->val() / b->val(), a, b) { }
|
||||
|
||||
void chain() {
|
||||
a_->grad() += adj_ / b_->val();
|
||||
b_->grad() += adj_ * (a_->val() / (b_->val() * b_->val()));
|
||||
}
|
||||
};
|
||||
|
||||
inline Var operator/(const Var& a, const Var& b) {
|
||||
return Var(VimplPtr(new DivVimplVV(a.vimpl(), b.vimpl())));
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -29,7 +29,7 @@ int main(int argc, char** argv) {
|
||||
Var y1 = layer(10, x1);
|
||||
Var y2 = layer(rand() % 20 + 1, x2);
|
||||
|
||||
Var y = y1 + log(y2);
|
||||
Var y = sigma(log(y1) / log(y2));
|
||||
|
||||
set_zero_all_adjoints();
|
||||
y.calc_gradients();
|
||||
|
Loading…
Reference in New Issue
Block a user