mirror of
https://github.com/marian-nmt/marian.git
synced 2024-12-03 15:52:00 +03:00
Add layer normalization matrices to model
This commit is contained in:
parent
1d86121c76
commit
d3615d2481
@ -33,7 +33,9 @@ struct Weights {
|
|||||||
Wx_(model["encoder_Wx"]),
|
Wx_(model["encoder_Wx"]),
|
||||||
Bx1_(model("encoder_bx", true)),
|
Bx1_(model("encoder_bx", true)),
|
||||||
Bx2_(Bx1_.Rows(), Bx1_.Cols(), 0.0f),
|
Bx2_(Bx1_.Rows(), Bx1_.Cols(), 0.0f),
|
||||||
Ux_(model["encoder_Ux"])
|
Ux_(model["encoder_Ux"]),
|
||||||
|
Gamma_1_(model["encoder_gamma1"]),
|
||||||
|
Gamma_2_(model["encoder_gamma2"])
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
const mblas::Matrix W_;
|
const mblas::Matrix W_;
|
||||||
@ -43,6 +45,8 @@ struct Weights {
|
|||||||
const mblas::Matrix Bx1_;
|
const mblas::Matrix Bx1_;
|
||||||
const mblas::Matrix Bx2_;
|
const mblas::Matrix Bx2_;
|
||||||
const mblas::Matrix Ux_;
|
const mblas::Matrix Ux_;
|
||||||
|
const mblas::Matrix Gamma_1_;
|
||||||
|
const mblas::Matrix Gamma_2_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct EncBackwardGRU {
|
struct EncBackwardGRU {
|
||||||
@ -55,7 +59,9 @@ struct Weights {
|
|||||||
Wx_(model["encoder_r_Wx"]),
|
Wx_(model["encoder_r_Wx"]),
|
||||||
Bx1_(model("encoder_r_bx", true)),
|
Bx1_(model("encoder_r_bx", true)),
|
||||||
Bx2_(Bx1_.Rows(), Bx1_.Cols(), 0.0f),
|
Bx2_(Bx1_.Rows(), Bx1_.Cols(), 0.0f),
|
||||||
Ux_(model["encoder_r_Ux"])
|
Ux_(model["encoder_r_Ux"]),
|
||||||
|
Gamma_1_(model["encoder_r_gamma1"]),
|
||||||
|
Gamma_2_(model["encoder_r_gamma2"])
|
||||||
{}
|
{}
|
||||||
|
|
||||||
const mblas::Matrix W_;
|
const mblas::Matrix W_;
|
||||||
@ -65,6 +71,8 @@ struct Weights {
|
|||||||
const mblas::Matrix Bx1_;
|
const mblas::Matrix Bx1_;
|
||||||
const mblas::Matrix Bx2_;
|
const mblas::Matrix Bx2_;
|
||||||
const mblas::Matrix Ux_;
|
const mblas::Matrix Ux_;
|
||||||
|
const mblas::Matrix Gamma_1_;
|
||||||
|
const mblas::Matrix Gamma_2_;
|
||||||
};
|
};
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
@ -84,11 +92,13 @@ struct Weights {
|
|||||||
|
|
||||||
DecInit(const NpzConverter& model)
|
DecInit(const NpzConverter& model)
|
||||||
: Wi_(model["ff_state_W"]),
|
: Wi_(model["ff_state_W"]),
|
||||||
Bi_(model("ff_state_b", true))
|
Bi_(model("ff_state_b", true)),
|
||||||
|
Gamma_(model["ff_state_gamma"])
|
||||||
{}
|
{}
|
||||||
|
|
||||||
const mblas::Matrix Wi_;
|
const mblas::Matrix Wi_;
|
||||||
const mblas::Matrix Bi_;
|
const mblas::Matrix Bi_;
|
||||||
|
const mblas::Matrix Gamma_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DecGRU1 {
|
struct DecGRU1 {
|
||||||
@ -101,7 +111,9 @@ struct Weights {
|
|||||||
Wx_(model["decoder_Wx"]),
|
Wx_(model["decoder_Wx"]),
|
||||||
Bx1_(model("decoder_bx", true)),
|
Bx1_(model("decoder_bx", true)),
|
||||||
Bx2_(Bx1_.Rows(), Bx1_.Cols(), 0.0f),
|
Bx2_(Bx1_.Rows(), Bx1_.Cols(), 0.0f),
|
||||||
Ux_(model["decoder_Ux"])
|
Ux_(model["decoder_Ux"]),
|
||||||
|
Gamma_1_(model["decoder_cell1_gamma1"]),
|
||||||
|
Gamma_2_(model["decoder_cell1_gamma2"])
|
||||||
{}
|
{}
|
||||||
|
|
||||||
const mblas::Matrix W_;
|
const mblas::Matrix W_;
|
||||||
@ -111,6 +123,8 @@ struct Weights {
|
|||||||
const mblas::Matrix Bx1_;
|
const mblas::Matrix Bx1_;
|
||||||
const mblas::Matrix Bx2_;
|
const mblas::Matrix Bx2_;
|
||||||
const mblas::Matrix Ux_;
|
const mblas::Matrix Ux_;
|
||||||
|
const mblas::Matrix Gamma_1_;
|
||||||
|
const mblas::Matrix Gamma_2_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DecGRU2 {
|
struct DecGRU2 {
|
||||||
@ -123,7 +137,9 @@ struct Weights {
|
|||||||
Wx_(model["decoder_Wcx"]),
|
Wx_(model["decoder_Wcx"]),
|
||||||
Bx2_(model("decoder_bx_nl", true)),
|
Bx2_(model("decoder_bx_nl", true)),
|
||||||
Bx1_(Bx2_.Rows(), Bx2_.Cols(), 0.0f),
|
Bx1_(Bx2_.Rows(), Bx2_.Cols(), 0.0f),
|
||||||
Ux_(model["decoder_Ux_nl"])
|
Ux_(model["decoder_Ux_nl"]),
|
||||||
|
Gamma_1_(model["decoder_cell2_gamma1"]),
|
||||||
|
Gamma_2_(model["decoder_cell2_gamma2"])
|
||||||
{}
|
{}
|
||||||
|
|
||||||
const mblas::Matrix W_;
|
const mblas::Matrix W_;
|
||||||
@ -133,6 +149,8 @@ struct Weights {
|
|||||||
const mblas::Matrix Bx2_;
|
const mblas::Matrix Bx2_;
|
||||||
const mblas::Matrix Bx1_;
|
const mblas::Matrix Bx1_;
|
||||||
const mblas::Matrix Ux_;
|
const mblas::Matrix Ux_;
|
||||||
|
const mblas::Matrix Gamma_1_;
|
||||||
|
const mblas::Matrix Gamma_2_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DecAlignment {
|
struct DecAlignment {
|
||||||
@ -143,7 +161,9 @@ struct Weights {
|
|||||||
W_(model["decoder_W_comb_att"]),
|
W_(model["decoder_W_comb_att"]),
|
||||||
B_(model("decoder_b_att", true)),
|
B_(model("decoder_b_att", true)),
|
||||||
U_(model["decoder_Wc_att"]),
|
U_(model["decoder_Wc_att"]),
|
||||||
C_(model["decoder_c_tt"]) // scalar?
|
C_(model["decoder_c_tt"]), // scalar?
|
||||||
|
Gamma_1_(model["decoder_att_gamma1"]),
|
||||||
|
Gamma_2_(model["decoder_att_gamma2"])
|
||||||
{}
|
{}
|
||||||
|
|
||||||
const mblas::Matrix V_;
|
const mblas::Matrix V_;
|
||||||
@ -151,6 +171,8 @@ struct Weights {
|
|||||||
const mblas::Matrix B_;
|
const mblas::Matrix B_;
|
||||||
const mblas::Matrix U_;
|
const mblas::Matrix U_;
|
||||||
const mblas::Matrix C_;
|
const mblas::Matrix C_;
|
||||||
|
const mblas::Matrix Gamma_1_;
|
||||||
|
const mblas::Matrix Gamma_2_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DecSoftmax {
|
struct DecSoftmax {
|
||||||
@ -164,7 +186,10 @@ struct Weights {
|
|||||||
W3_(model["ff_logit_ctx_W"]),
|
W3_(model["ff_logit_ctx_W"]),
|
||||||
B3_(model("ff_logit_ctx_b", true)),
|
B3_(model("ff_logit_ctx_b", true)),
|
||||||
W4_(model["ff_logit_W"]),
|
W4_(model["ff_logit_W"]),
|
||||||
B4_(model("ff_logit_b", true))
|
B4_(model("ff_logit_b", true)),
|
||||||
|
Gamma_0_(model["ff_logit_l1_gamma0"]),
|
||||||
|
Gamma_1_(model["ff_logit_l1_gamma1"]),
|
||||||
|
Gamma_2_(model["ff_logit_l1_gamma2"])
|
||||||
{}
|
{}
|
||||||
|
|
||||||
const mblas::Matrix W1_;
|
const mblas::Matrix W1_;
|
||||||
@ -175,6 +200,9 @@ struct Weights {
|
|||||||
const mblas::Matrix B3_;
|
const mblas::Matrix B3_;
|
||||||
const mblas::Matrix W4_;
|
const mblas::Matrix W4_;
|
||||||
const mblas::Matrix B4_;
|
const mblas::Matrix B4_;
|
||||||
|
const mblas::Matrix Gamma_0_;
|
||||||
|
const mblas::Matrix Gamma_1_;
|
||||||
|
const mblas::Matrix Gamma_2_;
|
||||||
};
|
};
|
||||||
|
|
||||||
Weights(const std::string& npzFile, size_t device)
|
Weights(const std::string& npzFile, size_t device)
|
||||||
|
Loading…
Reference in New Issue
Block a user