mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
Add layer normalization to dl4mt model - only GPU
This commit is contained in:
parent
ff1d6bd309
commit
3fb9faedd1
@ -59,7 +59,11 @@ class Decoder {
|
|||||||
Temp2_.Resize(batchSize, SourceContext.Cols());
|
Temp2_.Resize(batchSize, SourceContext.Cols());
|
||||||
Mean(Temp2_, SourceContext, mapping);
|
Mean(Temp2_, SourceContext, mapping);
|
||||||
Prod(State, Temp2_, w_.Wi_);
|
Prod(State, Temp2_, w_.Wi_);
|
||||||
BroadcastVec(Tanh(_1 + _2), State, w_.Bi_);
|
if (w_.Gamma_) {
|
||||||
|
Normalization(State, State, w_.Gamma_, w_.Bi_, 1e-9);
|
||||||
|
} else {
|
||||||
|
BroadcastVec(Tanh(_1 + _2), State, w_.Bi_);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetNextState(mblas::Matrix& NextState,
|
void GetNextState(mblas::Matrix& NextState,
|
||||||
@ -108,6 +112,9 @@ class Decoder {
|
|||||||
void Init(const mblas::Matrix& SourceContext) {
|
void Init(const mblas::Matrix& SourceContext) {
|
||||||
using namespace mblas;
|
using namespace mblas;
|
||||||
Prod(/*h_[0],*/ SCU_, SourceContext, w_.U_);
|
Prod(/*h_[0],*/ SCU_, SourceContext, w_.U_);
|
||||||
|
if (w_.Gamma_1_) {
|
||||||
|
Normalization(SCU_, SCU_, w_.Gamma_1_, w_.B_, 1e-9);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetAlignedSourceContext(mblas::Matrix& AlignedSourceContext,
|
void GetAlignedSourceContext(mblas::Matrix& AlignedSourceContext,
|
||||||
@ -129,7 +136,11 @@ class Decoder {
|
|||||||
const size_t srcSize = mapping.size() / beamSizes.size();
|
const size_t srcSize = mapping.size() / beamSizes.size();
|
||||||
|
|
||||||
Prod(/*h_[1],*/ Temp2_, HiddenState, w_.W_);
|
Prod(/*h_[1],*/ Temp2_, HiddenState, w_.W_);
|
||||||
BroadcastVec(_1 + _2, Temp2_, w_.B_/*, s_[1]*/);
|
if (w_.Gamma_2_) {
|
||||||
|
Normalization(Temp2_, Temp2_, w_.Gamma_2_, 1e-9);
|
||||||
|
} else {
|
||||||
|
BroadcastVec(_1 + _2, Temp2_, w_.B_/*, s_[1]*/);
|
||||||
|
}
|
||||||
|
|
||||||
Copy(Temp1_, SCU_);
|
Copy(Temp1_, SCU_);
|
||||||
Broadcast(Tanh(_1 + _2), Temp1_, Temp2_, dBatchMapping_, srcSize);
|
Broadcast(Tanh(_1 + _2), Temp1_, Temp2_, dBatchMapping_, srcSize);
|
||||||
@ -192,12 +203,26 @@ class Decoder {
|
|||||||
using namespace mblas;
|
using namespace mblas;
|
||||||
|
|
||||||
Prod(/*h_[0],*/ T1_, State, w_.W1_);
|
Prod(/*h_[0],*/ T1_, State, w_.W1_);
|
||||||
Prod(/*h_[1],*/ T2_, Embedding, w_.W2_);
|
|
||||||
Prod(/*h_[2],*/ T3_, AlignedSourceContext, w_.W3_);
|
|
||||||
|
|
||||||
BroadcastVec(_1 + _2, T1_, w_.B1_ /*,s_[0]*/);
|
if (w_.Gamma_1_) {
|
||||||
BroadcastVec(_1 + _2, T2_, w_.B2_ /*,s_[1]*/);
|
Normalization(T1_, T1_, w_.Gamma_1_, w_.B1_, 1e-9);
|
||||||
BroadcastVec(_1 + _2, T3_, w_.B3_ /*,s_[2]*/);
|
} else {
|
||||||
|
BroadcastVec(_1 + _2, T1_, w_.B1_ /*,s_[0]*/);
|
||||||
|
}
|
||||||
|
|
||||||
|
Prod(/*h_[1],*/ T2_, Embedding, w_.W2_);
|
||||||
|
if (w_.Gamma_0_) {
|
||||||
|
Normalization(T2_, T2_, w_.Gamma_0_, w_.B2_, 1e-9);
|
||||||
|
} else {
|
||||||
|
BroadcastVec(_1 + _2, T2_, w_.B2_ /*,s_[1]*/);
|
||||||
|
}
|
||||||
|
|
||||||
|
Prod(/*h_[2],*/ T3_, AlignedSourceContext, w_.W3_);
|
||||||
|
if (w_.Gamma_2_) {
|
||||||
|
Normalization(T3_, T3_, w_.Gamma_2_, w_.B3_, 1e-9);
|
||||||
|
} else {
|
||||||
|
BroadcastVec(_1 + _2, T3_, w_.B3_ /*,s_[2]*/);
|
||||||
|
}
|
||||||
|
|
||||||
Element(Tanh(_1 + _2 + _3), T1_, T2_, T3_);
|
Element(Tanh(_1 + _2 + _3), T1_, T2_, T3_);
|
||||||
|
|
||||||
|
@ -70,8 +70,7 @@ class Encoder {
|
|||||||
if(invert) {
|
if(invert) {
|
||||||
mblas::MapMatrix(State_, *mapping, n - i - 1);
|
mblas::MapMatrix(State_, *mapping, n - i - 1);
|
||||||
mblas::PasteRows(Context, State_, (n - i - 1), gru_.GetStateLength(), n);
|
mblas::PasteRows(Context, State_, (n - i - 1), gru_.GetStateLength(), n);
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
mblas::PasteRows(Context, State_, i, 0, n);
|
mblas::PasteRows(Context, State_, i, 0, n);
|
||||||
}
|
}
|
||||||
++i;
|
++i;
|
||||||
@ -83,11 +82,8 @@ class Encoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Model matrices
|
|
||||||
const GRU<Weights> gru_;
|
const GRU<Weights> gru_;
|
||||||
|
|
||||||
mblas::Matrix State_;
|
mblas::Matrix State_;
|
||||||
|
|
||||||
RNN(const RNN&) = delete;
|
RNN(const RNN&) = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -104,9 +104,16 @@ class FastGRU {
|
|||||||
const mblas::Matrix& Context) const {
|
const mblas::Matrix& Context) const {
|
||||||
using namespace mblas;
|
using namespace mblas;
|
||||||
|
|
||||||
// const size_t cols = GetStateLength();
|
|
||||||
Prod(RUH_, Context, WWx_);
|
Prod(RUH_, Context, WWx_);
|
||||||
|
if (w_.Gamma_1_) {
|
||||||
|
Normalization(RUH_, RUH_, w_.Gamma_1_, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
Prod(Temp_, State, UUx_);
|
Prod(Temp_, State, UUx_);
|
||||||
|
if (w_.Gamma_2_) {
|
||||||
|
Normalization(Temp_, Temp_, w_.Gamma_2_, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
ElementwiseOps(NextState, State, RUH_, Temp_);
|
ElementwiseOps(NextState, State, RUH_, Temp_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user