mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
Add layer normalization to dl4mt model - only GPU
This commit is contained in:
parent
ff1d6bd309
commit
3fb9faedd1
@ -59,7 +59,11 @@ class Decoder {
|
||||
Temp2_.Resize(batchSize, SourceContext.Cols());
|
||||
Mean(Temp2_, SourceContext, mapping);
|
||||
Prod(State, Temp2_, w_.Wi_);
|
||||
BroadcastVec(Tanh(_1 + _2), State, w_.Bi_);
|
||||
if (w_.Gamma_) {
|
||||
Normalization(State, State, w_.Gamma_, w_.Bi_, 1e-9);
|
||||
} else {
|
||||
BroadcastVec(Tanh(_1 + _2), State, w_.Bi_);
|
||||
}
|
||||
}
|
||||
|
||||
void GetNextState(mblas::Matrix& NextState,
|
||||
@ -108,6 +112,9 @@ class Decoder {
|
||||
void Init(const mblas::Matrix& SourceContext) {
|
||||
using namespace mblas;
|
||||
Prod(/*h_[0],*/ SCU_, SourceContext, w_.U_);
|
||||
if (w_.Gamma_1_) {
|
||||
Normalization(SCU_, SCU_, w_.Gamma_1_, w_.B_, 1e-9);
|
||||
}
|
||||
}
|
||||
|
||||
void GetAlignedSourceContext(mblas::Matrix& AlignedSourceContext,
|
||||
@ -129,7 +136,11 @@ class Decoder {
|
||||
const size_t srcSize = mapping.size() / beamSizes.size();
|
||||
|
||||
Prod(/*h_[1],*/ Temp2_, HiddenState, w_.W_);
|
||||
BroadcastVec(_1 + _2, Temp2_, w_.B_/*, s_[1]*/);
|
||||
if (w_.Gamma_2_) {
|
||||
Normalization(Temp2_, Temp2_, w_.Gamma_2_, 1e-9);
|
||||
} else {
|
||||
BroadcastVec(_1 + _2, Temp2_, w_.B_/*, s_[1]*/);
|
||||
}
|
||||
|
||||
Copy(Temp1_, SCU_);
|
||||
Broadcast(Tanh(_1 + _2), Temp1_, Temp2_, dBatchMapping_, srcSize);
|
||||
@ -192,12 +203,26 @@ class Decoder {
|
||||
using namespace mblas;
|
||||
|
||||
Prod(/*h_[0],*/ T1_, State, w_.W1_);
|
||||
Prod(/*h_[1],*/ T2_, Embedding, w_.W2_);
|
||||
Prod(/*h_[2],*/ T3_, AlignedSourceContext, w_.W3_);
|
||||
|
||||
BroadcastVec(_1 + _2, T1_, w_.B1_ /*,s_[0]*/);
|
||||
BroadcastVec(_1 + _2, T2_, w_.B2_ /*,s_[1]*/);
|
||||
BroadcastVec(_1 + _2, T3_, w_.B3_ /*,s_[2]*/);
|
||||
if (w_.Gamma_1_) {
|
||||
Normalization(T1_, T1_, w_.Gamma_1_, w_.B1_, 1e-9);
|
||||
} else {
|
||||
BroadcastVec(_1 + _2, T1_, w_.B1_ /*,s_[0]*/);
|
||||
}
|
||||
|
||||
Prod(/*h_[1],*/ T2_, Embedding, w_.W2_);
|
||||
if (w_.Gamma_0_) {
|
||||
Normalization(T2_, T2_, w_.Gamma_0_, w_.B2_, 1e-9);
|
||||
} else {
|
||||
BroadcastVec(_1 + _2, T2_, w_.B2_ /*,s_[1]*/);
|
||||
}
|
||||
|
||||
Prod(/*h_[2],*/ T3_, AlignedSourceContext, w_.W3_);
|
||||
if (w_.Gamma_2_) {
|
||||
Normalization(T3_, T3_, w_.Gamma_2_, w_.B3_, 1e-9);
|
||||
} else {
|
||||
BroadcastVec(_1 + _2, T3_, w_.B3_ /*,s_[2]*/);
|
||||
}
|
||||
|
||||
Element(Tanh(_1 + _2 + _3), T1_, T2_, T3_);
|
||||
|
||||
|
@ -70,8 +70,7 @@ class Encoder {
|
||||
if(invert) {
|
||||
mblas::MapMatrix(State_, *mapping, n - i - 1);
|
||||
mblas::PasteRows(Context, State_, (n - i - 1), gru_.GetStateLength(), n);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
mblas::PasteRows(Context, State_, i, 0, n);
|
||||
}
|
||||
++i;
|
||||
@ -83,11 +82,8 @@ class Encoder {
|
||||
}
|
||||
|
||||
private:
|
||||
// Model matrices
|
||||
const GRU<Weights> gru_;
|
||||
|
||||
mblas::Matrix State_;
|
||||
|
||||
RNN(const RNN&) = delete;
|
||||
};
|
||||
|
||||
|
@ -104,9 +104,16 @@ class FastGRU {
|
||||
const mblas::Matrix& Context) const {
|
||||
using namespace mblas;
|
||||
|
||||
// const size_t cols = GetStateLength();
|
||||
Prod(RUH_, Context, WWx_);
|
||||
if (w_.Gamma_1_) {
|
||||
Normalization(RUH_, RUH_, w_.Gamma_1_, 1e-9);
|
||||
}
|
||||
|
||||
Prod(Temp_, State, UUx_);
|
||||
if (w_.Gamma_2_) {
|
||||
Normalization(Temp_, Temp_, w_.Gamma_2_, 1e-9);
|
||||
}
|
||||
|
||||
ElementwiseOps(NextState, State, RUH_, Temp_);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user