mirror of
https://github.com/marian-nmt/marian.git
synced 2024-12-02 16:39:38 +03:00
Add layer normalization matrices to model
This commit is contained in:
parent
1d86121c76
commit
d3615d2481
@ -33,7 +33,9 @@ struct Weights {
|
||||
Wx_(model["encoder_Wx"]),
|
||||
Bx1_(model("encoder_bx", true)),
|
||||
Bx2_(Bx1_.Rows(), Bx1_.Cols(), 0.0f),
|
||||
Ux_(model["encoder_Ux"])
|
||||
Ux_(model["encoder_Ux"]),
|
||||
Gamma_1_(model["encoder_gamma1"]),
|
||||
Gamma_2_(model["encoder_gamma2"])
|
||||
{ }
|
||||
|
||||
const mblas::Matrix W_;
|
||||
@ -43,6 +45,8 @@ struct Weights {
|
||||
const mblas::Matrix Bx1_;
|
||||
const mblas::Matrix Bx2_;
|
||||
const mblas::Matrix Ux_;
|
||||
const mblas::Matrix Gamma_1_;
|
||||
const mblas::Matrix Gamma_2_;
|
||||
};
|
||||
|
||||
struct EncBackwardGRU {
|
||||
@ -55,7 +59,9 @@ struct Weights {
|
||||
Wx_(model["encoder_r_Wx"]),
|
||||
Bx1_(model("encoder_r_bx", true)),
|
||||
Bx2_(Bx1_.Rows(), Bx1_.Cols(), 0.0f),
|
||||
Ux_(model["encoder_r_Ux"])
|
||||
Ux_(model["encoder_r_Ux"]),
|
||||
Gamma_1_(model["encoder_r_gamma1"]),
|
||||
Gamma_2_(model["encoder_r_gamma2"])
|
||||
{}
|
||||
|
||||
const mblas::Matrix W_;
|
||||
@ -65,6 +71,8 @@ struct Weights {
|
||||
const mblas::Matrix Bx1_;
|
||||
const mblas::Matrix Bx2_;
|
||||
const mblas::Matrix Ux_;
|
||||
const mblas::Matrix Gamma_1_;
|
||||
const mblas::Matrix Gamma_2_;
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
@ -84,11 +92,13 @@ struct Weights {
|
||||
|
||||
DecInit(const NpzConverter& model)
|
||||
: Wi_(model["ff_state_W"]),
|
||||
Bi_(model("ff_state_b", true))
|
||||
Bi_(model("ff_state_b", true)),
|
||||
Gamma_(model["ff_state_gamma"])
|
||||
{}
|
||||
|
||||
const mblas::Matrix Wi_;
|
||||
const mblas::Matrix Bi_;
|
||||
const mblas::Matrix Gamma_;
|
||||
};
|
||||
|
||||
struct DecGRU1 {
|
||||
@ -101,7 +111,9 @@ struct Weights {
|
||||
Wx_(model["decoder_Wx"]),
|
||||
Bx1_(model("decoder_bx", true)),
|
||||
Bx2_(Bx1_.Rows(), Bx1_.Cols(), 0.0f),
|
||||
Ux_(model["decoder_Ux"])
|
||||
Ux_(model["decoder_Ux"]),
|
||||
Gamma_1_(model["decoder_cell1_gamma1"]),
|
||||
Gamma_2_(model["decoder_cell1_gamma2"])
|
||||
{}
|
||||
|
||||
const mblas::Matrix W_;
|
||||
@ -111,6 +123,8 @@ struct Weights {
|
||||
const mblas::Matrix Bx1_;
|
||||
const mblas::Matrix Bx2_;
|
||||
const mblas::Matrix Ux_;
|
||||
const mblas::Matrix Gamma_1_;
|
||||
const mblas::Matrix Gamma_2_;
|
||||
};
|
||||
|
||||
struct DecGRU2 {
|
||||
@ -123,7 +137,9 @@ struct Weights {
|
||||
Wx_(model["decoder_Wcx"]),
|
||||
Bx2_(model("decoder_bx_nl", true)),
|
||||
Bx1_(Bx2_.Rows(), Bx2_.Cols(), 0.0f),
|
||||
Ux_(model["decoder_Ux_nl"])
|
||||
Ux_(model["decoder_Ux_nl"]),
|
||||
Gamma_1_(model["decoder_cell2_gamma1"]),
|
||||
Gamma_2_(model["decoder_cell2_gamma2"])
|
||||
{}
|
||||
|
||||
const mblas::Matrix W_;
|
||||
@ -133,6 +149,8 @@ struct Weights {
|
||||
const mblas::Matrix Bx2_;
|
||||
const mblas::Matrix Bx1_;
|
||||
const mblas::Matrix Ux_;
|
||||
const mblas::Matrix Gamma_1_;
|
||||
const mblas::Matrix Gamma_2_;
|
||||
};
|
||||
|
||||
struct DecAlignment {
|
||||
@ -143,7 +161,9 @@ struct Weights {
|
||||
W_(model["decoder_W_comb_att"]),
|
||||
B_(model("decoder_b_att", true)),
|
||||
U_(model["decoder_Wc_att"]),
|
||||
C_(model["decoder_c_tt"]) // scalar?
|
||||
C_(model["decoder_c_tt"]), // scalar?
|
||||
Gamma_1_(model["decoder_att_gamma1"]),
|
||||
Gamma_2_(model["decoder_att_gamma2"])
|
||||
{}
|
||||
|
||||
const mblas::Matrix V_;
|
||||
@ -151,6 +171,8 @@ struct Weights {
|
||||
const mblas::Matrix B_;
|
||||
const mblas::Matrix U_;
|
||||
const mblas::Matrix C_;
|
||||
const mblas::Matrix Gamma_1_;
|
||||
const mblas::Matrix Gamma_2_;
|
||||
};
|
||||
|
||||
struct DecSoftmax {
|
||||
@ -164,7 +186,10 @@ struct Weights {
|
||||
W3_(model["ff_logit_ctx_W"]),
|
||||
B3_(model("ff_logit_ctx_b", true)),
|
||||
W4_(model["ff_logit_W"]),
|
||||
B4_(model("ff_logit_b", true))
|
||||
B4_(model("ff_logit_b", true)),
|
||||
Gamma_0_(model["ff_logit_l1_gamma0"]),
|
||||
Gamma_1_(model["ff_logit_l1_gamma1"]),
|
||||
Gamma_2_(model["ff_logit_l1_gamma2"])
|
||||
{}
|
||||
|
||||
const mblas::Matrix W1_;
|
||||
@ -175,6 +200,9 @@ struct Weights {
|
||||
const mblas::Matrix B3_;
|
||||
const mblas::Matrix W4_;
|
||||
const mblas::Matrix B4_;
|
||||
const mblas::Matrix Gamma_0_;
|
||||
const mblas::Matrix Gamma_1_;
|
||||
const mblas::Matrix Gamma_2_;
|
||||
};
|
||||
|
||||
Weights(const std::string& npzFile, size_t device)
|
||||
|
Loading…
Reference in New Issue
Block a user