From 73a71a1fcb2e422f68d4ff2a6251c22f23930f23 Mon Sep 17 00:00:00 2001
From: Tomasz Dwojak
Date: Tue, 21 Mar 2017 13:50:04 +0000
Subject: [PATCH] Add layer normalization matrices to cpu model

---
 src/cpu/dl4mt/model.cpp | 50 ++++++++++++++++++++++++-----------------
 src/cpu/dl4mt/model.h   | 10 +++++++++
 2 files changed, 39 insertions(+), 21 deletions(-)

diff --git a/src/cpu/dl4mt/model.cpp b/src/cpu/dl4mt/model.cpp
index 6e793b3c..02e77a13 100644
--- a/src/cpu/dl4mt/model.cpp
+++ b/src/cpu/dl4mt/model.cpp
@@ -25,7 +25,8 @@ Weights::GRU::GRU(const NpzConverter& model, const std::vector<std::string> &key
 
 Weights::DecInit::DecInit(const NpzConverter& model)
 : Wi_(model["ff_state_W"]),
-  Bi_(model("ff_state_b", true))
+  Bi_(model("ff_state_b", true)),
+  Gamma_(model["ff_state_gamma"])
 {}
 
 Weights::DecGRU2::DecGRU2(const NpzConverter& model)
@@ -35,17 +36,21 @@ Weights::DecGRU2::DecGRU2(const NpzConverter& model)
 : W_(model["decoder_Wc"]),
   B_(model("decoder_b_nl", true)),
   U_(model["decoder_U_nl"]),
   Wx_(model["decoder_Wcx"]),
   Bx2_(model("decoder_bx_nl", true)),
   Bx1_(Bx2_.rows(), Bx2_.columns()),
-  Ux_(model["decoder_Ux_nl"])
+  Ux_(model["decoder_Ux_nl"]),
+  Gamma_1_(model["decoder_cell2_gamma1"]),
+  Gamma_2_(model["decoder_cell2_gamma2"])
 {
   const_cast<mblas::Matrix&>(Bx1_) = 0.0f;
 }
 
 Weights::DecAttention::DecAttention(const NpzConverter& model)
 : V_(model("decoder_U_att", true)),
-W_(model["decoder_W_comb_att"]),
-B_(model("decoder_b_att", true)),
-U_(model["decoder_Wc_att"]),
-C_(model["decoder_c_tt"]) // scalar?
+  W_(model["decoder_W_comb_att"]),
+  B_(model("decoder_b_att", true)),
+  U_(model["decoder_Wc_att"]),
+  C_(model["decoder_c_tt"]), // scalar?
+  Gamma_1_(model["decoder_att_gamma1"]),
+  Gamma_2_(model["decoder_att_gamma2"])
 {}
 
 Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
@@ -56,25 +61,28 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
   W3_(model["ff_logit_ctx_W"]),
   B3_(model("ff_logit_ctx_b", true)),
   W4_(model["ff_logit_W"]),
-  B4_(model("ff_logit_b", true))
+  B4_(model("ff_logit_b", true)),
+  Gamma_0_(model["ff_logit_l1_gamma0"]),
+  Gamma_1_(model["ff_logit_l1_gamma1"]),
+  Gamma_2_(model["ff_logit_l1_gamma2"])
 {}
 
 //////////////////////////////////////////////////////////////////////////////
 
 Weights::Weights(const NpzConverter& model, size_t)
 : encEmbeddings_(model, "Wemb"),
-encForwardGRU_(model, {"encoder_W", "encoder_b", "encoder_U", "encoder_Wx", "encoder_bx", "encoder_Ux"}),
-encBackwardGRU_(model, {"encoder_r_W", "encoder_r_b", "encoder_r_U", "encoder_r_Wx", "encoder_r_bx", "encoder_r_Ux"}),
-decEmbeddings_(model, "Wemb_dec"),
-decInit_(model),
-decGru1_(model, {"decoder_W", "decoder_b", "decoder_U", "decoder_Wx", "decoder_bx", "decoder_Ux"}),
-decGru2_(model),
-decAttention_(model),
-decSoftmax_(model)
-{
-  //cerr << *this << endl;
-}
-
-}
-}
+  encForwardGRU_(model, {"encoder_W", "encoder_b", "encoder_U", "encoder_Wx", "encoder_bx",
+                         "encoder_Ux", "encoder_gamma1", "encoder_gamma2"}),
+  encBackwardGRU_(model, {"encoder_r_W", "encoder_r_b", "encoder_r_U", "encoder_r_Wx",
+                          "encoder_r_bx", "encoder_r_Ux", "encoder_r_gamma1", "encoder_r_gamma2"}),
+  decEmbeddings_(model, "Wemb_dec"),
+  decInit_(model),
+  decGru1_(model, {"decoder_W", "decoder_b", "decoder_U", "decoder_Wx", "decoder_bx", "decoder_Ux",
+                   "decoder_cell1_gamma1", "decoder_cell1_gamma2"}),
+  decGru2_(model),
+  decAttention_(model),
+  decSoftmax_(model)
+{}
+
+} // namespace cpu
+} // namespace amunmt
diff --git a/src/cpu/dl4mt/model.h b/src/cpu/dl4mt/model.h
index b740e99e..d89287b4 100644
--- a/src/cpu/dl4mt/model.h
+++ b/src/cpu/dl4mt/model.h
@@ -31,6 +31,8 @@ struct Weights {
     const mblas::Matrix Bx1_;
     const mblas::Matrix Bx2_;
     const mblas::Matrix Ux_;
+    const mblas::Matrix Gamma_1_;
+    const mblas::Matrix Gamma_2_;
   };
 
   //////////////////////////////////////////////////////////////////////////////
@@ -40,6 +42,7 @@ struct Weights {
     const mblas::Matrix Wi_;
     const mblas::Matrix Bi_;
+    const mblas::Matrix Gamma_;
   };
 
   struct DecGRU2 {
@@ -52,6 +55,8 @@ struct Weights {
     const mblas::Matrix Bx2_;
     const mblas::Matrix Bx1_;
    const mblas::Matrix Ux_;
+    const mblas::Matrix Gamma_1_;
+    const mblas::Matrix Gamma_2_;
   };
 
   struct DecAttention {
@@ -62,6 +67,8 @@ struct Weights {
     const mblas::Matrix W_;
     const mblas::Matrix B_;
     const mblas::Matrix U_;
     const mblas::Matrix C_;
+    const mblas::Matrix Gamma_1_;
+    const mblas::Matrix Gamma_2_;
   };
 
   struct DecSoftmax {
@@ -75,6 +82,9 @@ struct Weights {
     const mblas::Matrix B3_;
     const mblas::Matrix W4_;
     const mblas::Matrix B4_;
+    const mblas::Matrix Gamma_0_;
+    const mblas::Matrix Gamma_1_;
+    const mblas::Matrix Gamma_2_;
   };
 
   //////////////////////////////////////////////////////////////////////////////
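
Background note (not part of the patch): the gamma matrices loaded above (encoder_gamma1/2, decoder_cell*_gamma1/2, decoder_att_gamma1/2, ff_logit_l1_gamma0/1/2, ff_state_gamma) are the per-feature scale parameters of layer normalization, with the bias vectors the constructors already read acting as the shifts. The sketch below shows how such a gamma vector is typically applied to one row of pre-activations; it is an illustration only, and the function name, the use of std::vector, and the epsilon value are assumptions for the example rather than the mblas API used by this code.

// Illustration only (not amunmt/mblas code): layer normalization of one row of
// pre-activations x, rescaled by gamma and shifted by beta (the layer bias).
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> LayerNormRow(const std::vector<float>& x,
                                const std::vector<float>& gamma,
                                const std::vector<float>& beta,
                                float eps = 1e-5f) {
  const std::size_t n = x.size();

  // Mean of the row.
  float mean = 0.0f;
  for (float v : x) mean += v;
  mean /= n;

  // Variance of the row.
  float var = 0.0f;
  for (float v : x) var += (v - mean) * (v - mean);
  var /= n;

  // Normalize, then scale by gamma and shift by beta, element-wise.
  const float invStd = 1.0f / std::sqrt(var + eps);
  std::vector<float> y(n);
  for (std::size_t i = 0; i < n; ++i) {
    y[i] = gamma[i] * (x[i] - mean) * invStd + beta[i];
  }
  return y;
}

Because the normalization fixes each row to zero mean and unit variance, gamma and the existing bias are the only learned degrees of freedom that set the scale and shift of the normalized pre-activations.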