mirror of
https://github.com/marian-nmt/marian.git
synced 2024-11-27 10:33:14 +03:00
Add marian style tie-embs
This commit is contained in:
parent
9aee648fa1
commit
3cce98c356
@ -6,10 +6,6 @@ namespace amunmt {
|
||||
namespace CPU {
|
||||
namespace dl4mt {
|
||||
|
||||
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::string &key)
|
||||
: E_(model[key])
|
||||
{}
|
||||
|
||||
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys)
|
||||
: E_(model.getFirstOfMany(keys))
|
||||
{}
|
||||
@ -67,8 +63,10 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
|
||||
B2_(model("ff_logit_prev_b", true)),
|
||||
W3_(model["ff_logit_ctx_W"]),
|
||||
B3_(model("ff_logit_ctx_b", true)),
|
||||
W4_(model.getFirstOfMany({std::pair<std::string, bool>(std::string("ff_logit_W"), false),
|
||||
std::make_pair(std::string("Wemb_dec"), true)})),
|
||||
W4_(model.getFirstOfMany({std::pair<std::string, bool>(
|
||||
std::string("ff_logit_W"), false),
|
||||
std::make_pair(std::string("Wemb_dec"), true),
|
||||
std::make_pair(std::string("Wemb"), true)})),
|
||||
B4_(model("ff_logit_b", true)),
|
||||
Gamma_0_(model["ff_logit_l1_gamma0"]),
|
||||
Gamma_1_(model["ff_logit_l1_gamma1"]),
|
||||
@ -78,12 +76,15 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Weights::Weights(const NpzConverter& model, size_t)
|
||||
: encEmbeddings_(model, "Wemb"),
|
||||
: encEmbeddings_(model, std::vector<std::pair<std::string, bool>>({
|
||||
std::make_pair(std::string("Wemb"), false),
|
||||
std::make_pair(std::string("Wemb_dec"), false)})),
|
||||
encForwardGRU_(model, {"encoder_W", "encoder_b", "encoder_U", "encoder_Wx", "encoder_bx",
|
||||
"encoder_Ux", "encoder_gamma1", "encoder_gamma2"}),
|
||||
encBackwardGRU_(model, {"encoder_r_W", "encoder_r_b", "encoder_r_U", "encoder_r_Wx",
|
||||
"encoder_r_bx", "encoder_r_Ux", "encoder_r_gamma1", "encoder_r_gamma2"}),
|
||||
decEmbeddings_(model, std::vector<std::pair<std::string, bool>>({std::make_pair(std::string("Wemb_dec"), false),
|
||||
decEmbeddings_(model, std::vector<std::pair<std::string, bool>>({
|
||||
std::make_pair(std::string("Wemb_dec"), false),
|
||||
std::make_pair(std::string("Wemb"), false)})),
|
||||
decInit_(model),
|
||||
decGru1_(model, {"decoder_W", "decoder_b", "decoder_U", "decoder_Wx", "decoder_bx", "decoder_Ux",
|
||||
|
@ -16,7 +16,6 @@ struct Weights {
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Embeddings {
|
||||
Embeddings(const NpzConverter& model, const std::string &key);
|
||||
Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys);
|
||||
|
||||
const mblas::Matrix E_;
|
||||
|
@ -60,10 +60,6 @@ std::string Weights::Transition::name(const std::string& prefix, std::string nam
|
||||
return prefix + name + infix + "_drt_" + std::to_string(index) + suffix;
|
||||
}
|
||||
|
||||
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::string &key)
|
||||
: E_(model[key])
|
||||
{}
|
||||
|
||||
Weights::Embeddings::Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys)
|
||||
: E_(model.getFirstOfMany(keys))
|
||||
{}
|
||||
@ -143,7 +139,8 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
|
||||
W3_(model["ff_logit_ctx_W"]),
|
||||
B3_(model("ff_logit_ctx_b", true)),
|
||||
W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
|
||||
std::make_pair(std::string("Wemb_dec"), true)})),
|
||||
std::make_pair(std::string("Wemb_dec"), true),
|
||||
std::make_pair(std::string("Wemb"), true)})),
|
||||
B4_(model("ff_logit_b", true)),
|
||||
lns_1_(model["ff_logit_lstm_ln_s"]),
|
||||
lns_2_(model["ff_logit_prev_ln_s"]),
|
||||
@ -156,7 +153,9 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Weights::Weights(const NpzConverter& model, size_t)
|
||||
: encEmbeddings_(model, "Wemb"),
|
||||
: encEmbeddings_(model, std::vector<std::pair<std::string, bool>>(
|
||||
{std::make_pair(std::string("Wemb"), false),
|
||||
std::make_pair(std::string("Wemb_dec"), false)})),
|
||||
decEmbeddings_(model, std::vector<std::pair<std::string, bool>>(
|
||||
{std::make_pair(std::string("Wemb_dec"), false),
|
||||
std::make_pair(std::string("Wemb"), false)})),
|
||||
|
@ -49,7 +49,6 @@ struct Weights {
|
||||
};
|
||||
|
||||
struct Embeddings {
|
||||
Embeddings(const NpzConverter& model, const std::string &key);
|
||||
Embeddings(const NpzConverter& model, const std::vector<std::pair<std::string, bool>> keys);
|
||||
|
||||
const mblas::Matrix E_;
|
||||
|
@ -5,7 +5,8 @@ namespace GPU {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
Weights::EncEmbeddings::EncEmbeddings(const NpzConverter& model)
|
||||
: E_(model.get("Wemb", true))
|
||||
: E_(model.getFirstOfMany({std::make_pair("Wemb", false),
|
||||
std::make_pair("Wemb_dec", false)}, true))
|
||||
{}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -141,7 +142,8 @@ Weights::DecSoftmax::DecSoftmax(const NpzConverter& model)
|
||||
W3_(model.get("ff_logit_ctx_W", true)),
|
||||
B3_(model.get("ff_logit_ctx_b", true, true)),
|
||||
W4_(model.getFirstOfMany({std::make_pair(std::string("ff_logit_W"), false),
|
||||
std::make_pair(std::string("Wemb_dec"), true)}, true)),
|
||||
std::make_pair(std::string("Wemb_dec"), true),
|
||||
std::make_pair(std::string("Wemb"), true)}, true)),
|
||||
B4_(model.get("ff_logit_b", true, true)),
|
||||
Gamma_0_(model.get("ff_logit_l1_gamma0", false)),
|
||||
Gamma_1_(model.get("ff_logit_l1_gamma1", false)),
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit f0f06468e1633af503b81eaa8b1fbca881b95257
|
||||
Subproject commit 7f3802e97f08e91382c6609cd772c4e80488b4d8
|
Loading…
Reference in New Issue
Block a user