Add layer normalization to dl4mt model - only GPU

This commit is contained in:
Tomasz Dwojak 2017-03-21 12:33:49 +00:00
parent ff1d6bd309
commit 3fb9faedd1
3 changed files with 41 additions and 13 deletions

View File

@ -59,7 +59,11 @@ class Decoder {
Temp2_.Resize(batchSize, SourceContext.Cols());
Mean(Temp2_, SourceContext, mapping);
Prod(State, Temp2_, w_.Wi_);
BroadcastVec(Tanh(_1 + _2), State, w_.Bi_);
if (w_.Gamma_) {
Normalization(State, State, w_.Gamma_, w_.Bi_, 1e-9);
} else {
BroadcastVec(Tanh(_1 + _2), State, w_.Bi_);
}
}
void GetNextState(mblas::Matrix& NextState,
@ -108,6 +112,9 @@ class Decoder {
void Init(const mblas::Matrix& SourceContext) {
using namespace mblas;
Prod(/*h_[0],*/ SCU_, SourceContext, w_.U_);
if (w_.Gamma_1_) {
Normalization(SCU_, SCU_, w_.Gamma_1_, w_.B_, 1e-9);
}
}
void GetAlignedSourceContext(mblas::Matrix& AlignedSourceContext,
@ -129,7 +136,11 @@ class Decoder {
const size_t srcSize = mapping.size() / beamSizes.size();
Prod(/*h_[1],*/ Temp2_, HiddenState, w_.W_);
BroadcastVec(_1 + _2, Temp2_, w_.B_/*, s_[1]*/);
if (w_.Gamma_2_) {
Normalization(Temp2_, Temp2_, w_.Gamma_2_, 1e-9);
} else {
BroadcastVec(_1 + _2, Temp2_, w_.B_/*, s_[1]*/);
}
Copy(Temp1_, SCU_);
Broadcast(Tanh(_1 + _2), Temp1_, Temp2_, dBatchMapping_, srcSize);
@ -192,12 +203,26 @@ class Decoder {
using namespace mblas;
Prod(/*h_[0],*/ T1_, State, w_.W1_);
Prod(/*h_[1],*/ T2_, Embedding, w_.W2_);
Prod(/*h_[2],*/ T3_, AlignedSourceContext, w_.W3_);
BroadcastVec(_1 + _2, T1_, w_.B1_ /*,s_[0]*/);
BroadcastVec(_1 + _2, T2_, w_.B2_ /*,s_[1]*/);
BroadcastVec(_1 + _2, T3_, w_.B3_ /*,s_[2]*/);
if (w_.Gamma_1_) {
Normalization(T1_, T1_, w_.Gamma_1_, w_.B1_, 1e-9);
} else {
BroadcastVec(_1 + _2, T1_, w_.B1_ /*,s_[0]*/);
}
Prod(/*h_[1],*/ T2_, Embedding, w_.W2_);
if (w_.Gamma_0_) {
Normalization(T2_, T2_, w_.Gamma_0_, w_.B2_, 1e-9);
} else {
BroadcastVec(_1 + _2, T2_, w_.B2_ /*,s_[1]*/);
}
Prod(/*h_[2],*/ T3_, AlignedSourceContext, w_.W3_);
if (w_.Gamma_2_) {
Normalization(T3_, T3_, w_.Gamma_2_, w_.B3_, 1e-9);
} else {
BroadcastVec(_1 + _2, T3_, w_.B3_ /*,s_[2]*/);
}
Element(Tanh(_1 + _2 + _3), T1_, T2_, T3_);

View File

@ -70,8 +70,7 @@ class Encoder {
if(invert) {
mblas::MapMatrix(State_, *mapping, n - i - 1);
mblas::PasteRows(Context, State_, (n - i - 1), gru_.GetStateLength(), n);
}
else {
} else {
mblas::PasteRows(Context, State_, i, 0, n);
}
++i;
@ -83,11 +82,8 @@ class Encoder {
}
private:
// Model matrices
const GRU<Weights> gru_;
mblas::Matrix State_;
RNN(const RNN&) = delete;
};

View File

@ -104,9 +104,16 @@ class FastGRU {
const mblas::Matrix& Context) const {
using namespace mblas;
// const size_t cols = GetStateLength();
Prod(RUH_, Context, WWx_);
if (w_.Gamma_1_) {
Normalization(RUH_, RUH_, w_.Gamma_1_, 1e-9);
}
Prod(Temp_, State, UUx_);
if (w_.Gamma_2_) {
Normalization(Temp_, Temp_, w_.Gamma_2_, 1e-9);
}
ElementwiseOps(NextState, State, RUH_, Temp_);
}