mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-30 21:39:52 +03:00
Add layer normalization matrices to cpu model
This commit is contained in:
parent
3fb9faedd1
commit
73a71a1fcb
@ -25,7 +25,8 @@ Weights::GRU::GRU(const NpzConverter& model, const std::vector<std::string> &key
|
||||
|
||||
// Constructs the decoder-initialisation weights by looking each member up in `model`:
//   Wi_    <- model["ff_state_W"]
//   Bi_    <- model("ff_state_b", true)
//   Gamma_ <- model["ff_state_gamma"]
// NOTE(review): this span is a rendered diff hunk without +/- markers. The two
// Bi_ initializers look like the removed line (no trailing comma) and its
// replacement (trailing comma so that Gamma_ can follow); presumably only the
// second one exists in the resulting file - confirm against the repository.
// NOTE(review): the meaning of the `true` flag passed to model(...) is not
// visible here - confirm against NpzConverter.
Weights::DecInit::DecInit(const NpzConverter& model)
|
||||
: Wi_(model["ff_state_W"]),
|
||||
Bi_(model("ff_state_b", true))
|
||||
Bi_(model("ff_state_b", true)),
|
||||
Gamma_(model["ff_state_gamma"])
|
||||
{}
|
||||
|
||||
Weights::DecGRU2::DecGRU2(const NpzConverter& model)
|
||||
@ -35,17 +36,21 @@ Weights::DecGRU2::DecGRU2(const NpzConverter& model)
|
||||
Wx_(model["decoder_Wcx"]),
|
||||
Bx2_(model("decoder_bx_nl", true)),
|
||||
Bx1_(Bx2_.rows(), Bx2_.columns()),
|
||||
Ux_(model["decoder_Ux_nl"])
|
||||
Ux_(model["decoder_Ux_nl"]),
|
||||
Gamma_1_(model["decoder_cell2_gamma1"]),
|
||||
Gamma_2_(model["decoder_cell2_gamma2"])
|
||||
{
|
||||
const_cast<mblas::Matrix&>(Bx1_) = 0.0f;
|
||||
}
|
||||
|
||||
// Constructs the decoder attention weights by looking each member up in `model`:
//   V_       <- model("decoder_U_att", true)
//   W_       <- model["decoder_W_comb_att"]
//   B_       <- model("decoder_b_att", true)
//   U_       <- model["decoder_Wc_att"]
//   C_       <- model["decoder_c_tt"]
//   Gamma_1_ <- model["decoder_att_gamma1"]
//   Gamma_2_ <- model["decoder_att_gamma2"]
// NOTE(review): this span is a rendered diff hunk without +/- markers. The
// W_/B_/U_/C_ initializers appear twice; presumably the first group (C_ without
// a trailing comma) is the removed text and the second group (C_ with a
// trailing comma, followed by the two Gamma members) is its replacement.
// NOTE(review): the meaning of the `true` flag passed to model(...) is not
// visible here - confirm against NpzConverter.
Weights::DecAttention::DecAttention(const NpzConverter& model)
|
||||
: V_(model("decoder_U_att", true)),
|
||||
W_(model["decoder_W_comb_att"]),
|
||||
B_(model("decoder_b_att", true)),
|
||||
U_(model["decoder_Wc_att"]),
|
||||
C_(model["decoder_c_tt"]) // scalar?
|
||||
W_(model["decoder_W_comb_att"]),
|
||||
B_(model("decoder_b_att", true)),
|
||||
U_(model["decoder_Wc_att"]),
|
||||
C_(model["decoder_c_tt"]), // scalar?
|
||||
Gamma_1_(model["decoder_att_gamma1"]),
|
||||
Gamma_2_(model["decoder_att_gamma2"])
|
||||
{}
|
||||
|
||||
Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
|
||||
@ -56,25 +61,28 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
|
||||
W3_(model["ff_logit_ctx_W"]),
|
||||
B3_(model("ff_logit_ctx_b", true)),
|
||||
W4_(model["ff_logit_W"]),
|
||||
B4_(model("ff_logit_b", true))
|
||||
B4_(model("ff_logit_b", true)),
|
||||
Gamma_0_(model["ff_logit_l1_gamma0"]),
|
||||
Gamma_1_(model["ff_logit_l1_gamma1"]),
|
||||
Gamma_2_(model["ff_logit_l1_gamma2"])
|
||||
{}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Top-level constructor: builds every sub-weight group from the same `model`.
// The second parameter (an unnamed size_t) is not used in the visible code.
// Embeddings are keyed by "Wemb" / "Wemb_dec"; the GRU groups receive a list of
// key names, while decInit_, decGru2_, decAttention_ and decSoftmax_ receive
// only `model`.
// NOTE(review): this span is a rendered diff without +/- markers. The first
// initializer list (six keys per GRU, body holding a commented-out debug print,
// then two bare closing braces) looks like the removed text; the second list
// (eight keys per GRU, adding the "*_gamma1" / "*_gamma2" names, with an empty
// body) looks like its replacement. Confirm against the repository before
// treating either half as the current implementation.
Weights::Weights(const NpzConverter& model, size_t)
|
||||
: encEmbeddings_(model, "Wemb"),
|
||||
encForwardGRU_(model, {"encoder_W", "encoder_b", "encoder_U", "encoder_Wx", "encoder_bx", "encoder_Ux"}),
|
||||
encBackwardGRU_(model, {"encoder_r_W", "encoder_r_b", "encoder_r_U", "encoder_r_Wx", "encoder_r_bx", "encoder_r_Ux"}),
|
||||
decEmbeddings_(model, "Wemb_dec"),
|
||||
decInit_(model),
|
||||
decGru1_(model, {"decoder_W", "decoder_b", "decoder_U", "decoder_Wx", "decoder_bx", "decoder_Ux"}),
|
||||
decGru2_(model),
|
||||
decAttention_(model),
|
||||
decSoftmax_(model)
|
||||
{
|
||||
//cerr << *this << endl;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
encForwardGRU_(model, {"encoder_W", "encoder_b", "encoder_U", "encoder_Wx", "encoder_bx",
|
||||
"encoder_Ux", "encoder_gamma1", "encoder_gamma2"}),
|
||||
encBackwardGRU_(model, {"encoder_r_W", "encoder_r_b", "encoder_r_U", "encoder_r_Wx",
|
||||
"encoder_r_bx", "encoder_r_Ux", "encoder_r_gamma1", "encoder_r_gamma2"}),
|
||||
decEmbeddings_(model, "Wemb_dec"),
|
||||
decInit_(model),
|
||||
decGru1_(model, {"decoder_W", "decoder_b", "decoder_U", "decoder_Wx", "decoder_bx", "decoder_Ux",
|
||||
"decoder_cell1_gamma1", "decoder_cell1_gamma2"}),
|
||||
decGru2_(model),
|
||||
decAttention_(model),
|
||||
decSoftmax_(model)
|
||||
{}
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace amunmt
|
||||
|
@ -31,6 +31,8 @@ struct Weights {
|
||||
const mblas::Matrix Bx1_;
|
||||
const mblas::Matrix Bx2_;
|
||||
const mblas::Matrix Ux_;
|
||||
const mblas::Matrix Gamma_1_;
|
||||
const mblas::Matrix Gamma_2_;
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
@ -40,6 +42,7 @@ struct Weights {
|
||||
|
||||
const mblas::Matrix Wi_;
|
||||
const mblas::Matrix Bi_;
|
||||
const mblas::Matrix Gamma_;
|
||||
};
|
||||
|
||||
struct DecGRU2 {
|
||||
@ -52,6 +55,8 @@ struct Weights {
|
||||
const mblas::Matrix Bx2_;
|
||||
const mblas::Matrix Bx1_;
|
||||
const mblas::Matrix Ux_;
|
||||
const mblas::Matrix Gamma_1_;
|
||||
const mblas::Matrix Gamma_2_;
|
||||
};
|
||||
|
||||
struct DecAttention {
|
||||
@ -62,6 +67,8 @@ struct Weights {
|
||||
const mblas::Matrix B_;
|
||||
const mblas::Matrix U_;
|
||||
const mblas::Matrix C_;
|
||||
const mblas::Matrix Gamma_1_;
|
||||
const mblas::Matrix Gamma_2_;
|
||||
};
|
||||
|
||||
struct DecSoftmax {
|
||||
@ -75,6 +82,9 @@ struct Weights {
|
||||
const mblas::Matrix B3_;
|
||||
const mblas::Matrix W4_;
|
||||
const mblas::Matrix B4_;
|
||||
const mblas::Matrix Gamma_0_;
|
||||
const mblas::Matrix Gamma_1_;
|
||||
const mblas::Matrix Gamma_2_;
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
Loading…
Reference in New Issue
Block a user